본문 바로가기

PS - 알고리즘

SOS dp

 

sum over subset이라는 이름의 dp입니다.

 

음이 아닌 정수로 이루어진 길이  $N$의 수열  $A_1$,  $A_2$, ... , $A_N$이 주어졌을 때 $0 \le mask < (2^k)$ 인 mask에 대해 $mask | A_i = mask$인 $A_i$의 개수를 구합니다(  |  는 bit연산의 or입니다.). 만약 수열 $A$가 0,1,5,6이고 $mask$가 5면 $0|5=5, 1|5=5, 5|5=5, 6|5=7$이므로 답은 3입니다. $mask$가 0일 때 부터 $(2^k) -1$일때 까지의 답을 모두 구하는 문제입니다. $k$는 원하는 값으로 잡으시면 됩니다.

 

$A=(0,1,5,5,6), k=3$이라고 합시다. 숫자끼리 그룹을 지어 줄 것입니다. 먼저, i번째 그룹에는 숫자 i만 들어 있습니다.

 

$DP[i] = $i에 해당하는 그룹 내에 있는 숫자들 중 i와 OR 연산을 했을 때 i가 되는 수들의 개수

라고 정의합시다. 처음에는 i밖에 없으므로 숫자의 개수가 DP값이 될 겁니다.

 

초기 상태

이제 인접한 그룹끼리 묶어 줄 겁니다. 

 

첫번째 단계

 

그룹 0-1에 집중합시다. 0 | 1 = 1이므로 DP[1]에 DP[0]을 더해줍니다. 0은 아무 변화 없습니다.

이것을 그룹 2-3,4-5,6-7에 대해서도 해주면 2|3=3, 4|5=5,6|7=7이므로 dp[3], dp[5], dp[7]에 각각 dp[2], dp[4], dp[6]을 더해줍니다.

 

현재 dp에 들어있는 값 : 

$c[i] = A$에 들어있는 i의 개수라고 할 때

dp[0]=c[0]

dp[1]=c[0] + c[1]

dp[2]=c[2]

dp[3]=c[2] + c[3]

dp[4]=c[4]

dp[5]=c[4] + c[5]

dp[6]=c[6]

dp[7]=c[6] + c[7]

 

이 행동을 모든 숫자가 한 그룹으로 묶일 때 까지 진행합니다.

 

두번째 단계

OR 연산을 조금 더 잘 보이게 하기 위해 비트를 추가해 보았습니다. 첫번째 자리가 2개씩 반복되는 것을 볼 수 있습니다.

0|2=2, 1|3=3, 4|6=6, 5|7=7이므로 dp의 2,3,6,7번째 위치에 dp의 0,1,4,5번째의 값을 더해줍니다. 0|3=3이지만 이 값은 이미 이 전 단계에서 dp[1]에 포함되었으므로 고려하지 않아도 됩니다.

현재 dp에 들어있는 값 : 

$c[i] = A$에 들어있는 i의 개수라고 할 때

dp[0]=c[0]

dp[1]=c[0] + c[1]

dp[2]=c[0] + c[2]

dp[3]=c[0] + c[1] + c[2] + c[3]

dp[4]=c[4]

dp[5]=c[4] + c[5]

dp[6]=c[4] + c[6]

dp[7]=c[4] + c[5] + c[6] + c[7]

마지막 단계

마지막 단계는 처음 두 자리가 4개씩 반복됩니다.

dp[6]에 집중해 봅시다. 지금 단계에서 dp[6]의 정의는

 

$A$에 있는 숫자들 중 0~7인 숫자들만 남긴 다음 6과 OR 연산을 했을 때 결과가 6인 수들의 개수

 

입니다. 이것은

 

(1) $A$에 있는 숫자들 중 0~3인 숫자들만 남긴 다음 6과 OR 연산을 했을 때 결과가 6인 수들의 개수

(2) $A$에 있는 숫자들 중 4~7인 숫자들만 남긴 다음 6과 OR 연산을 했을 때 결과가 6인 수들의 개수

 

의 합으로 구할 수 있습니다. (2)는 그냥 dp[6]입니다. 0~3까지의 숫자들은 3번째 비트가 없습니다. 따라서

 

(1) $A$에 있는 숫자들 중 0~3인 숫자들만 남긴 다음 6과 OR 연산을 했을 때 결과가 2인 수들의 개수

 

로 바꿀 수 있고 이는 dp[2]입니다.

 

같은 원리로 dp[4~7]을 구할 수 있습니다.

 

#define _CRT_SECURE_NO_WARNINGS
#include <iostream>
#include <cstring>
#include <map>
#include <algorithm>
using namespace std;
int b[(1 << 20)];
int main() {
    int n, i,j,k;
    cin >> n >> k;
    for (i = 0;i < n;i++) {
        int val;
        cin >> val;
        b[val]++;
    }

    for (i = 1;i < (1 << k);i <<= 1) {
        for (j = 0;j < (1 << k);j++) {
            if ((j & i) != 0) b[j] += b[j - i];
        }
    }

    for (i = 0;i < (1 << k);i++) {
        printf("%d ", b[i]);
    }
    return 0;
}

'PS - 알고리즘' 카테고리의 다른 글

백준 1129 - 키  (0) 2024.04.12
백준 13573 - 동전 뒤집기 3  (1) 2024.03.23
백준 14854 - 이항 계수 6  (0) 2024.03.16
x/y (mod n) 계산  (3) 2024.03.16
x^n 을 O(log n)에 구하기  (0) 2024.03.16