신기한 방법으로 풀었던 문제 2

최근 참가했던 코드포스에서 또 그 방법을 사용하고 말았다.

Count the Arrays

N, M(2 ≤ N < M ≤ 2 × 105)이 주어졌을 때, 아래 조건을 만족하는 배열의 개수를 구하는 문제이다.

  • 배열의 크기는 N
  • 배열에 들어있는 수는 1보다 크거나 같고, M보다 작거나 같은 정수
  • 배열에서 수 하나는 두 번 들어 있어야 함
  • j < i일때 A[j] < A[j+1]이고, j > i일때 A[j] > A[j+1]을 만족하는 i가 존재해야 한다.

N = 3, M = 4인 경우 가능한 정답은 6개이고, 다음과 같다.

  • [1, 2, 1]
  • [1, 3, 1]
  • [1, 4, 1]
  • [2, 3, 2]
  • [2, 4, 2]
  • [3, 4, 4]

항상 하던대로 브루트 포스를 이용해서 모든 정답을 다 만들고, 정답에서 규칙을 찾아보기로 했다.

#include <iostream>
#include <vector>
using namespace std;
int check(vector<int> a) {
    int n = a.size();
    bool acc = true;
    for (int i=0; i<n-1; i++) {
        if (a[i] == a[i+1]) return 0;
        if (acc) {
            if (a[i] < a[i+1]) {
            } else {
                acc = false;
            }
        } else {
            if (a[i] < a[i+1]) return 0;
        }
    }
    return 1;
}
long long go(int n, int m, vector<int> &a, int same, int index) {
    if (a.size() == n) {
        long long sum = 0;
        vector<int> b(a);
        sort(b.begin(),b.end());
        do {
            sum += check(b);
        } while(next_permutation(b.begin(), b.end()));
        return sum;
    }
    if (index > m) return 0;
    if (same == index) {
        return go(n, m, a, same, index+1);
    }
    int t1 = go(n, m, a,same, index+1);
    a.push_back(index);
    int t2 = go(n, m, a, same, index+1);
    a.pop_back();
    return t1 + t2;
}
long long go(int n, int m) {
    long long cnt = 0;
    for (int same=1; same<=m; same++) {
        vector<int> a;
        a.push_back(same);
        a.push_back(same);
        cnt += go(n, m, a, same, 1);
    }
    return cnt;
}
int main() {
    for (int n=2; n<=10; n++) {
        for (int m=n; m<=10; m++) {
            cout << n << ' ' << m << ' ' << go(n, m) << '\n';
        }
    }
    return 0;
}

go(n, m)은 정답을 구하는 소스이고, 문제에서 수 하나는 두 번 들어있다고 했기 때문에, same을 이용해서 그 수를 미리 정해놓았다.

그 다음, go(n, m, a, same, index)를 이용해서 배열에 들어가야 할 수를 모두 정했다.

samego(n, m)에서 미리 두 번 넣었기 때문에, 이 경우는 go(n, m, a, same, index)에서 제외하게 구현했다. 수를 모두 고른 후(a.size() == n)에는 모든 순서를 next_permutation을 통해서 만들어보면서 문제의 정답을 구했다.

2 ≤ N < M ≤ 10일때 답을 모두 출력해보니 다음과 같았다.

2 2 0
2 3 0
2 4 0
2 5 0
2 6 0
2 7 0
2 8 0
2 9 0
2 10 0

3 3 3
3 4 6
3 5 10
3 6 15
3 7 21
3 8 28
3 9 36
3 10 45

4 4 16
4 5 40
4 6 80
4 7 140
4 9 336
4 10 480

5 5 60
5 6 180
5 7 420
5 8 840
5 9 1512
5 10 2520

6 6 192
6 7 672
6 8 1792
6 9 4032
6 10 8064

7 7 560
7 8 2240
7 9 6720
7 10 16800

8 8 1536
8 9 6912
8 10 23040

9 9 4032
9 10 20160

10 10 10240

N = 2인 경우에는 답이 0이라는 사실을 알 수 있다. 3 ≤ N에 대해서 답을 구해야 한다. 수를 쳐다보니 (N, M)일때 정답은 (N, M-1)의 정답보다 크고, 곱해서 만들 수 있을 것 같았다.

3 3 3
3 4 6 = 3 * 2
3 5 10 = 6 * 1.6666666667
3 6 15 = 10 * 1.5
3 7 21 = 15 * 1.4
3 8 28 = 21 * 1.3333333333
3 9 36 = 28 * 1.2857142857
3 10 45 = 36 * 1.25

4 4 16
4 5 40
4 6 80
4 7 140
4 8 224
4 9 336
4 10 480

5 5 60
5 6 180 = 60 * 3
5 7 420 = 180 * 2.3333333333
5 8 840 = 420 * 2
5 9 1512 = 840 * 1.8
5 10 2520 = 1512 * 1.6666666667

6 6 192
6 7 672 = 192 * 3.5
6 8 1792 = 672 * 2.6666666667
6 9 4032 = 1792 * 2.25
6 10 8064 = 4032 * 2

7 7 560
7 8 2240 = 560 * 4
7 9 6720 = 2240 * 3
7 10 16800 = 6720 * 2.5

8 8 1536
8 9 6912
8 10 23040

9 9 4032
9 10 20160

10 10 10240

곱해지는 값이 무작위 값이 아니라는 사실을 알 수 있다. 정수가 곱해지기도 하고, 소수도 나오는데 이 소수는 분수로 표현할 수 있는 소수이다. 곱해지는 값을 분수로 바꿔보기로 했다.

3 3 3
3 4 6 = 3 * 2
3 5 10 = 6 * 1.6666666667 = 6 * 5/3
3 6 15 = 10 * 1.5 = 10 * 3/2
3 7 21 = 15 * 1.4 = 15 * 7/5
3 8 28 = 21 * 1.3333333333 = 21 * 4/3
3 9 36 = 28 * 1.2857142857 = 28 * 9/7
3 10 45 = 36 * 1.25 = 36 * 5/4

5 5 60
5 6 180 = 60 * 3
5 7 420 = 180 * 2.3333333333 = 180 * 7/3
5 8 840 = 420 * 2
5 9 1512 = 840 * 1.8 = 840 * 9/5
5 10 2520 = 1512 * 1.6666666667 = 1512 * 5/3

6 6 192
6 7 672 = 192 * 3.5 = 192 * 7/2
6 8 1792 = 672 * 2.6666666667 = 672 * 8/3
6 9 4032 = 1792 * 2.25 = 1792 * 9/4
6 10 8064 = 4032 * 2

항상 분수의 형태가 나오는 것을 확인했다. N = 3일때를 보면 5/3, 7/5, 9/7가 두 번씩 건너가면서 등장한다. N = 6일때는 7/2, 8/3, 9/4이다.

3 3 3
3 4 6 = 3 * 2 = 3 * 4/2
3 5 10 = 6 * 1.6666666667 = 6 * 5/3
3 6 15 = 10 * 1.5 = 10 * 3/2 = 10 * 6/4
3 7 21 = 15 * 1.4 = 15 * 7/5
3 8 28 = 21 * 1.3333333333 = 21 * 4/3 = 21 * 8/6
3 9 36 = 28 * 1.2857142857 = 28 * 9/7
3 10 45 = 36 * 1.25 = 36 * 5/4 = 36 * 10/8

(N, M) = (3, 5)인 경우 정답은 (3, 3)의 정답에 4/2, 5/3을 곱하면 되고, (3, 9)일때 정답은 (3, 3)의 정답에 4/2, 5/3, 6/4, 7/5, 8/6, 9/7을 곱하면 된다.

이렇게 일관성을 갖는다면 N == M일때도 규칙을 만족할 수 있어야 한다. (3, 4)가 (3, 3)의 정답에 4/2를 곱해서 구하는 것이니, (3, 3)도 어떤 정답에 3/1을 곱해서 구할 수 있어야 한다.

N == M인 경우에 대해서 모두 구해보았다.

3 3 3 = 1 * 3

4 4 16 = 4 * 4

5 5 60 = 12 * 5

6 6 192 = 32 * 6

7 7 560 = 80 * 7

8 8 1536 = 192 * 8

9 9 4032 = 448 * 9

10 10 10240 = 1024 * 10

예상대로 어떤 값에 N/1을 곱해서 N == M인 경우도 구할 수 있었다.

OEIS를 켜고 1, 14, 12, 32, 80, 192, 448, 1024를 입력했더니 일반항을 구할 수 있었다.

구한 일반항 a(n) = n*2^(n-1)이고, 실제로 n = 3부터 시작하니 모든 n에서 2를 빼야 한다.

N == M인 경우 (N-2)*2^(N-3)에 N을 곱한 것이 정답이다.

일반화해보면 N ≤ M의 경우에는 (N-2)*2^(N-3) * (N/1) * ((N+1)/2) * … * (M/(M-N+1)) 이 정답이다.

#!/usr/bin/env python3

MOD = 998244353
n,m = map(int,input().split())

def go(a, b, c):
    ans = 1
    while b > 0:
        if b%2 == 1:
            ans *= a
            ans %= c
        b //= 2
        a *= a
        a %= c
    return ans

if n == 2:
    print(0)
else:
    start = (n-2)*go(2, n-3, MOD)
    start %= MOD
    bunja = 1
    bunmo = 1
    for i in range(n, m+1):
        bunja *= i
        bunmo *= (i-n+1)
        bunja %= MOD
        bunmo %= MOD

    #ans = start * bunja // bunmo
    start *= bunja
    temp = go(bunmo, MOD-2, MOD)
    ans = start * temp
    ans %= MOD
    print(ans)

go(a, b, c)는 a의 b제곱을 c로 나눈 나머지를 구하는 함수이다.

start에 OEIS로 구한 초기값을 넣어주고, 분자(bunja), 분모(bunmo)의 값을 모두 구한 다음, ans = start * bunja / bunmo 로 올바른 답을 구할 수 있다.

A / B mod C는 C가 소수라면 A * B^(C-2) mod C와 같기 때문에, go(bunmo, MOD-2, MOD)를 이용해서 나머지 연산을 수행했다.

Ayoub’s function

0과 1로 이루어진 문자열 s가 있을 때, f(s)는 ‘1’을 포함하고 있는 부분 문자열의 개수이다. 예를 들어, s = “01010”인 경우 f(s) = 12이다.

길이가 n이고, ‘1’이 m개 들어있는 문자열 s 중에서 f(s)의 최댓값을 찾는 문제이다. 1 ≤ n ≤ 109, 0 ≤ m ≤ n이다.

일단, n ≤ 10에 대해서 모든 답을 다 구해보기로 했다.

#include <iostream>
#include <string>
using namespace std;
int go(int n, int m, string s) {
    if (m < 0) {
        return 0;
    }
    if (n == 0) {
        int cnt = 0;
        for (int i=0; i<s.length(); i++) {
            bool ok = false;
            for (int j=i; j<s.length(); j++) {
                if (s[j] == '1') ok = true;
                if (ok) cnt += 1;
            }
        }
        return cnt;
    }
    int t1 = go(n-1, m, s+"0");
    int t2 = go(n-1, m-1, s+"1");
    return max(t1,t2);
}
int main() {
    for (int n=1; n<=10; n++) {
        for (int m=0; m<=n; m++) {
            cout << n << ' ' << m << ' ' << go(n, m, "") << '\n';
        }
    }
    return 0;
}
1 0 0
1 1 1

2 0 0
2 1 2
2 2 3

3 0 0
3 1 4
3 2 5
3 3 6

4 0 0
4 1 6
4 2 8
4 3 9
4 4 10

5 0 0
5 1 9
5 2 12
5 3 13
5 4 14
5 5 15

6 0 0
6 1 12
6 2 16
6 3 18
6 4 19
6 5 20
6 6 21

7 0 0
7 1 16
7 2 21
7 3 24
7 4 25
7 5 26
7 6 27
7 7 28

8 0 0
8 1 20
8 2 27
8 3 30
8 4 32
8 5 33
8 6 34
8 7 35
8 8 36

9 0 0
9 1 25
9 2 33
9 3 37
9 4 40
9 5 41
9 6 42
9 7 43
9 8 44
9 9 45

10 0 0
10 1 30
10 2 40
10 3 45
10 4 48
10 5 50
10 6 51
10 7 52
10 8 53
10 9 54
10 10 55

M = 0인 경우에는 답이 0인 것을 알 수 있다. 이 사실은 굳이 답을 모두 구해보지 않고도 알 수 있다. 놀라운 사실을 하나 알 수 있다. 바로 N/2 ≤ M인 경우이다. N/2의 정답부터 점점 1씩 증가한다는 사실을 알 수 있다.

N = 10인 경우를 보면, (N, M)이 (10, 5)일때 정답에 1을 더하면 (10, 6)의 답, 여기서 다시 1을 더하면 (10, 7)의 답, … 임을 알 수 있다.

N = 9와 같이 홀수인 경우일때도 N/2 ≤ M을 만족한다. N == M일때를 보면 익숙한 수가 눈에 보인다.

1 1 1
2 2 3
3 3 6
4 4 10
5 5 15
6 6 21
7 7 28
8 8 36
9 9 45
10 10 55

1, 3, 6, 10, 15, 21, 28, 36, 45, 55는 1부터 N까지의 합과 같다. N == M일때의 정답은 N(N+1)/2와 같다는 사실을 알 수 있다.

위에서 N/2 ≤ M일때의 규칙과 합쳐보면, N/2 ≤ M일때 정답은 N(N+1)/2 – (N-M)과 같다는 사실을 알 수 있다.

이제 우리가 해야 하는 것은 N/2 > M인 경우이다. 규칙을 찾으려면 식이 많이 있는 것이 좋다. 8 ≤ N ≤ 15인 경우에 대해서 답을 모두 구해보기로 했다.

8 1 20
8 2 27
8 3 30
8 4 32

9 1 25
9 2 33
9 3 37
9 4 40

10 1 30
10 2 40
10 3 45
10 4 48
10 5 50

11 1 36
11 2 48
11 3 54
11 4 57
11 5 60

12 1 42
12 2 56
12 3 63
12 4 67
12 5 70
12 6 72

13 1 49
13 2 65
13 3 73
13 4 78
13 5 81
13 6 84

14 1 56
14 2 75
14 3 84
14 4 90
14 5 93
14 6 96
14 7 98

15 1 64
15 2 85
15 3 96
15 4 102
15 5 106
15 6 109
15 7 112

M = 1일때 규칙을 찾아냈는데, 이후 M > 1일때 이 값을 어떻게 활용해야 하는지 계산하기가 매우 어렵다.

8 1 20 = 4 * 5

9 1 25 = 5 * 5

10 1 30 = 5 * 6

11 1 36 = 6 * 6

12 1 42 = 6 * 7

13 1 49 = 7 * 7

14 1 56 = 7 * 8

15 1 64 = 8 * 8

N = 14, 15인 경우에 M의 변화에 따른 답의 차이를 구해보았다.

14 1 56
14 2 75 = 56 + 19
14 3 84 = 75 + 9
14 4 90 = 84 + 6
14 5 93 = 90 + 3
14 6 96 = 93 + 3
14 7 98 = 96 + 2

15 1 64
15 2 85 = 64 + 21
15 3 96 = 85 + 11
15 4 102 = 96 + 6
15 5 106 = 102 + 4
15 6 109 = 106 + 3
15 7 112 = 109 + 3

답의 차이로는 마땅한 규칙을 찾을 수 없다. 정답일때 문자열을 함께 출력해보기로 했다.

#include <iostream>
#include <string>
using namespace std;
pair<int,string> go(int n, int m, string s) {
    if (m < 0) {
        return make_pair(0,"");
    }
    if (n == 0) {
        int cnt = 0;
        for (int i=0; i<s.length(); i++) {
            bool ok = false;
            for (int j=i; j<s.length(); j++) {
                if (s[j] == '1') ok = true;
                if (ok) cnt += 1;
            }
        }
        return make_pair(cnt,s);
    }
    auto t1 = go(n-1, m, s+"0");
    auto t2 = go(n-1, m-1, s+"1");
    return max(t1,t2);
}
int main() {
    for (int n=8; n<=15; n++) {
        for (int m=1; m<=n/2; m++) {
            auto p = go(n, m, "");
            cout << n << ' ' << m << ' ' << p.first << '\n';
            cout << p.second << '\n';
        }
        cout << '\n';
    }
    return 0;
}

13 ≤ N ≤ 15일때를 보니 다음과 같다.

13 1 49
0000001000000
13 2 65
0001000010000
13 3 73
0010010001000
13 4 78
0100100100100
13 5 81
0101010100100
13 6 84
0101010101010

14 1 56
00000010000000
14 2 75
00001000010000
14 3 84
00100010001000
14 4 90
00100100100100
14 5 93
01010100100100
14 6 96
01010101010100
14 7 98
10101010101010

15 1 64
000000010000000
15 2 85
000010000100000
15 3 96
000100010001000
15 4 102
001001001001000
15 5 106
010100100100100
15 6 109
010101010100100
15 7 112
010101010101010

f(s)의 값이 최대가 될때 1의 위치가 규칙성이 있다.

N = 15일때를 보면

  • M = 1: 000000010000000, (0 7개) 1 (0 7개)
  • M = 2: 000010000100000, (0 4개) 1 (0 4개) 1 (0 5개)
  • M = 3: 000100010001000, (0 3개) 1 (0 3개) 1 (0 3개) 1 (0 3개)
  • M = 4: 001001001001000, (0 2개) 1 (0 2개) 1 (0 2개) 1 (0 2개) 1 (0 3개)

문자열은 0과 1로만 이루어져 있고, 1이 0을 나누는 역할을 한다고 볼 수 있다. 이때 연속하는 0의 개수가 같거나 1 차이가 나야 한다.

1로인해서 나누어지는 0을 그룹이라고 했을 때, f(s)가 최대가 되는 경우는 0 N-M개를 M+1개의 그룹으로 나누는 경우와 같다고 볼 수 있다.

N-M이 M+1로 나누어 떨어지는 경우가 아니면 뒤쪽에 있는 그룹의 크기가 1 크면 된다.

위에서 구현한 브루트 포스 알고리즘에서 f(s)값을 구하는 것을 조금 개선해볼 수 있다.

        int cnt = 0;
        for (int i=0; i<s.length(); i++) {
            bool ok = false;
            for (int j=i; j<s.length(); j++) {
                if (s[j] == '1') ok = true;
                if (ok) cnt += 1;
            }
        }

위의 소스는 s의 (i, j)에 1이 있으면 cnt에 1을 증가시키는 소스이다. 한 번 1이 등장하면, 그 뒤의 위치에도 1이 등장한다고 볼 수 있다.

        int cnt = 0;
        for (int i=0; i<s.length(); i++) {
            for (int j=i; j<s.length(); j++) {
                if (s[j] == '1') {
                    cnt += s.length()-j;
                    break;
                }
            }
        }

위의 계산 아이디어와 위에서 발견한 규칙을 합하면 문제를 해결할 수 있다.

N = 15, M = 2인 경우는

  • 000010000100000

이 문제의 정답이다.

  • 000010000100000
  • 000010000100000

i가 빨간색에 있다며, j는 파란색에 있어야 한다. 여기서 5 * 11개를 찾을 수 있다.

  • 000010000100000
  • 000010000100000

여기에서 5 * 6개를 찾을 수 있고,

  • 000010000100000
  • 000010000100000

여기에서는 찾을 수가 없다. 5*11과 5*6을 더하면 85이고, 이 값은 우리가 위해서 구한 값과 같다.

규칙을 좀 더 명확하게 찾기 위해서 좀 더 긴 N으로 예시를 들어보자.

N = 20, M = 5인 경우

  • 00100100100010001000

가 문제의 정답이고, f(s)의 값은 183이다. 위에서 계산한 것 처럼 다시 계산해보자.

  • 00100100100010001000
  • 00100100100010001000

3 * 18

  • 00100100100010001000
  • 00100100100010001000

3 * 15

  • 00100100100010001000
  • 00100100100010001000

3 * 12

  • 00100100100010001000
  • 00100100100010001000

4 * 8

  • 00100100100010001000
  • 00100100100010001000

4 * 4

모두 합해보면 3 * 18 + 3 * 15 + 3 * 12 + 4 * 8 + 4 * 4 = 183이고, 문제의 정답과 같다.

1로 나누어지는 0의 개수는 앞과 뒤로 나눌 수 있고, 앞과 뒤를 따로 계산해야 한다.

0은 N-M개가 있고, 이 0을 총 M+1개의 그룹으로 나누어야 한다. 아래 식에서 /는 정수 나눗셈이다.

  • 앞쪽 0 그룹
    • (N-M) / (M+1)개의 0이 있음
    • 그룹의 개수는 (M+1) – (N-M) % (M+1)개
  • 뒤쪽 0 그룹
    • (N-M) / (M+1) + 1개의 0이 있음
    • 그룹의 개수는 (N-M) % (M+1)개

앞쪽 0 그룹에서 구해지는 (i, j) 쌍을 모두 구해보면 3 * 18, 3 * 15, 3 * 12 이다. 등차수열로 볼 수 있다. 3은 항상 곱해지니, 3을 빼고 합을 계산한 다음, 합에 3을 곱하기로 했다.

식이 너무 길어질 수 있어 K = (N-M) / (M+1), R = (N-M) % (M+1)로 작성했다.

  • 첫 항: N-K
  • 공차: -(K+1)
  • 항의 개수: (M+1) – R

이제 위의 등차수열의 합에서 (K+1)를 곱하면 앞쪽 0 그룹에서 만들 수 있는 (i, j) 쌍의 개수를 구할 수 있다.

뒤쪽 0 그룹도 비슷하게 구할 수 있다.

  • 첫 항: (K+2) * (R-1)
  • 공차: -(K+2)
  • 항의 개수: R – 1

위의 등차수열의 합에서 (K+2)를 곱하면 뒤쪽 0 그룹에서 만들 수 있는 쌍을 모두 구할 수 있다.

#!/usr/bin/env python3
import sys

def calc(a0, diff, n):
    return n * (2*a0 + (n-1)*diff) // 2

def go(n, m):
    if n//2 <= m:
        ans = n*(n+1)//2
        ans -= (n-m)
        return ans
    elif m == 0:
        return 0
    else:
        k = (n-m)//(m+1)
        r = (n-m)%(m+1)
        begin = (k+1) * calc(n-k, -(k+1), (m+1)-r)
        end = (k+2) * calc((k+2)*(r-1), -(k+2), r-1)
        return begin + end

t = int(sys.stdin.readline())
for _ in range(t):
    n, m = map(int,sys.stdin.readline().split())
    print(go(n, m))

begin이 앞쪽 0 그룹, end가 뒤쪽 0 그룹이다.

당연히 위 두 문제의 정해는 여기에 적혀있는 방법이 아니다.