최근 참가했던 코드포스에서 또 그 방법을 사용하고 말았다.
Count the Arrays
- 문제 링크: Codeforces
N, M(2 ≤ N < M ≤ 2 × 105)이 주어졌을 때, 아래 조건을 만족하는 배열의 개수를 구하는 문제이다.
- 배열의 크기는 N
- 배열에 들어있는 수는 1보다 크거나 같고, M보다 작거나 같은 정수
- 배열에서 수 하나는 두 번 들어 있어야 함
- j < i일때 A[j] < A[j+1]이고, j > i일때 A[j] > A[j+1]을 만족하는 i가 존재해야 한다.
N = 3, M = 4인 경우 가능한 정답은 6개이고, 다음과 같다.
- [1, 2, 1]
- [1, 3, 1]
- [1, 4, 1]
- [2, 3, 2]
- [2, 4, 2]
- [3, 4, 4]
항상 하던대로 브루트 포스를 이용해서 모든 정답을 다 만들고, 정답에서 규칙을 찾아보기로 했다.
#include <iostream> #include <vector> using namespace std; int check(vector<int> a) { int n = a.size(); bool acc = true; for (int i=0; i<n-1; i++) { if (a[i] == a[i+1]) return 0; if (acc) { if (a[i] < a[i+1]) { } else { acc = false; } } else { if (a[i] < a[i+1]) return 0; } } return 1; } long long go(int n, int m, vector<int> &a, int same, int index) { if (a.size() == n) { long long sum = 0; vector<int> b(a); sort(b.begin(),b.end()); do { sum += check(b); } while(next_permutation(b.begin(), b.end())); return sum; } if (index > m) return 0; if (same == index) { return go(n, m, a, same, index+1); } int t1 = go(n, m, a,same, index+1); a.push_back(index); int t2 = go(n, m, a, same, index+1); a.pop_back(); return t1 + t2; } long long go(int n, int m) { long long cnt = 0; for (int same=1; same<=m; same++) { vector<int> a; a.push_back(same); a.push_back(same); cnt += go(n, m, a, same, 1); } return cnt; } int main() { for (int n=2; n<=10; n++) { for (int m=n; m<=10; m++) { cout << n << ' ' << m << ' ' << go(n, m) << '\n'; } } return 0; }
go(n, m)
은 정답을 구하는 소스이고, 문제에서 수 하나는 두 번 들어있다고 했기 때문에, same
을 이용해서 그 수를 미리 정해놓았다.
그 다음, go(n, m, a, same, index)
를 이용해서 배열에 들어가야 할 수를 모두 정했다.
same
은 go(n, m)
에서 미리 두 번 넣었기 때문에, 이 경우는 go(n, m, a, same, index)
에서 제외하게 구현했다. 수를 모두 고른 후(a.size() == n
)에는 모든 순서를 next_permutation
을 통해서 만들어보면서 문제의 정답을 구했다.
2 ≤ N < M ≤ 10일때 답을 모두 출력해보니 다음과 같았다.
2 2 0 2 3 0 2 4 0 2 5 0 2 6 0 2 7 0 2 8 0 2 9 0 2 10 0 3 3 3 3 4 6 3 5 10 3 6 15 3 7 21 3 8 28 3 9 36 3 10 45 4 4 16 4 5 40 4 6 80 4 7 140 4 9 336 4 10 480 5 5 60 5 6 180 5 7 420 5 8 840 5 9 1512 5 10 2520 6 6 192 6 7 672 6 8 1792 6 9 4032 6 10 8064 7 7 560 7 8 2240 7 9 6720 7 10 16800 8 8 1536 8 9 6912 8 10 23040 9 9 4032 9 10 20160 10 10 10240
N = 2인 경우에는 답이 0이라는 사실을 알 수 있다. 3 ≤ N에 대해서 답을 구해야 한다. 수를 쳐다보니 (N, M)일때 정답은 (N, M-1)의 정답보다 크고, 곱해서 만들 수 있을 것 같았다.
3 3 3 3 4 6 = 3 * 2 3 5 10 = 6 * 1.6666666667 3 6 15 = 10 * 1.5 3 7 21 = 15 * 1.4 3 8 28 = 21 * 1.3333333333 3 9 36 = 28 * 1.2857142857 3 10 45 = 36 * 1.25 4 4 16 4 5 40 4 6 80 4 7 140 4 8 224 4 9 336 4 10 480 5 5 60 5 6 180 = 60 * 3 5 7 420 = 180 * 2.3333333333 5 8 840 = 420 * 2 5 9 1512 = 840 * 1.8 5 10 2520 = 1512 * 1.6666666667 6 6 192 6 7 672 = 192 * 3.5 6 8 1792 = 672 * 2.6666666667 6 9 4032 = 1792 * 2.25 6 10 8064 = 4032 * 2 7 7 560 7 8 2240 = 560 * 4 7 9 6720 = 2240 * 3 7 10 16800 = 6720 * 2.5 8 8 1536 8 9 6912 8 10 23040 9 9 4032 9 10 20160 10 10 10240
곱해지는 값이 무작위 값이 아니라는 사실을 알 수 있다. 정수가 곱해지기도 하고, 소수도 나오는데 이 소수는 분수로 표현할 수 있는 소수이다. 곱해지는 값을 분수로 바꿔보기로 했다.
3 3 3 3 4 6 = 3 * 2 3 5 10 = 6 * 1.6666666667 = 6 * 5/3 3 6 15 = 10 * 1.5 = 10 * 3/2 3 7 21 = 15 * 1.4 = 15 * 7/5 3 8 28 = 21 * 1.3333333333 = 21 * 4/3 3 9 36 = 28 * 1.2857142857 = 28 * 9/7 3 10 45 = 36 * 1.25 = 36 * 5/4 5 5 60 5 6 180 = 60 * 3 5 7 420 = 180 * 2.3333333333 = 180 * 7/3 5 8 840 = 420 * 2 5 9 1512 = 840 * 1.8 = 840 * 9/5 5 10 2520 = 1512 * 1.6666666667 = 1512 * 5/3 6 6 192 6 7 672 = 192 * 3.5 = 192 * 7/2 6 8 1792 = 672 * 2.6666666667 = 672 * 8/3 6 9 4032 = 1792 * 2.25 = 1792 * 9/4 6 10 8064 = 4032 * 2
항상 분수의 형태가 나오는 것을 확인했다. N = 3일때를 보면 5/3, 7/5, 9/7가 두 번씩 건너가면서 등장한다. N = 6일때는 7/2, 8/3, 9/4이다.
3 3 3 3 4 6 = 3 * 2 = 3 * 4/2 3 5 10 = 6 * 1.6666666667 = 6 * 5/3 3 6 15 = 10 * 1.5 = 10 * 3/2 = 10 * 6/4 3 7 21 = 15 * 1.4 = 15 * 7/5 3 8 28 = 21 * 1.3333333333 = 21 * 4/3 = 21 * 8/6 3 9 36 = 28 * 1.2857142857 = 28 * 9/7 3 10 45 = 36 * 1.25 = 36 * 5/4 = 36 * 10/8
(N, M) = (3, 5)인 경우 정답은 (3, 3)의 정답에 4/2, 5/3을 곱하면 되고, (3, 9)일때 정답은 (3, 3)의 정답에 4/2, 5/3, 6/4, 7/5, 8/6, 9/7을 곱하면 된다.
이렇게 일관성을 갖는다면 N == M일때도 규칙을 만족할 수 있어야 한다. (3, 4)가 (3, 3)의 정답에 4/2를 곱해서 구하는 것이니, (3, 3)도 어떤 정답에 3/1을 곱해서 구할 수 있어야 한다.
N == M인 경우에 대해서 모두 구해보았다.
3 3 3 = 1 * 3 4 4 16 = 4 * 4 5 5 60 = 12 * 5 6 6 192 = 32 * 6 7 7 560 = 80 * 7 8 8 1536 = 192 * 8 9 9 4032 = 448 * 9 10 10 10240 = 1024 * 10
예상대로 어떤 값에 N/1을 곱해서 N == M인 경우도 구할 수 있었다.
OEIS를 켜고 1, 14, 12, 32, 80, 192, 448, 1024를 입력했더니 일반항을 구할 수 있었다.
구한 일반항 a(n) = n*2^(n-1)이고, 실제로 n = 3부터 시작하니 모든 n에서 2를 빼야 한다.
N == M인 경우 (N-2)*2^(N-3)에 N을 곱한 것이 정답이다.
일반화해보면 N ≤ M의 경우에는 (N-2)*2^(N-3) * (N/1) * ((N+1)/2) * … * (M/(M-N+1)) 이 정답이다.
#!/usr/bin/env python3 MOD = 998244353 n,m = map(int,input().split()) def go(a, b, c): ans = 1 while b > 0: if b%2 == 1: ans *= a ans %= c b //= 2 a *= a a %= c return ans if n == 2: print(0) else: start = (n-2)*go(2, n-3, MOD) start %= MOD bunja = 1 bunmo = 1 for i in range(n, m+1): bunja *= i bunmo *= (i-n+1) bunja %= MOD bunmo %= MOD #ans = start * bunja // bunmo start *= bunja temp = go(bunmo, MOD-2, MOD) ans = start * temp ans %= MOD print(ans)
go(a, b, c)
는 a의 b제곱을 c로 나눈 나머지를 구하는 함수이다.
start
에 OEIS로 구한 초기값을 넣어주고, 분자(bunja
), 분모(bunmo
)의 값을 모두 구한 다음, ans = start * bunja / bunmo
로 올바른 답을 구할 수 있다.
A / B mod C는 C가 소수라면 A * B^(C-2) mod C와 같기 때문에, go(bunmo, MOD-2, MOD)
를 이용해서 나머지 연산을 수행했다.
Ayoub’s function
- 문제 링크: Codeforces
0과 1로 이루어진 문자열 s가 있을 때, f(s)는 ‘1’을 포함하고 있는 부분 문자열의 개수이다. 예를 들어, s = “01010”인 경우 f(s) = 12이다.
길이가 n이고, ‘1’이 m개 들어있는 문자열 s 중에서 f(s)의 최댓값을 찾는 문제이다. 1 ≤ n ≤ 109, 0 ≤ m ≤ n이다.
일단, n ≤ 10에 대해서 모든 답을 다 구해보기로 했다.
#include <iostream> #include <string> using namespace std; int go(int n, int m, string s) { if (m < 0) { return 0; } if (n == 0) { int cnt = 0; for (int i=0; i<s.length(); i++) { bool ok = false; for (int j=i; j<s.length(); j++) { if (s[j] == '1') ok = true; if (ok) cnt += 1; } } return cnt; } int t1 = go(n-1, m, s+"0"); int t2 = go(n-1, m-1, s+"1"); return max(t1,t2); } int main() { for (int n=1; n<=10; n++) { for (int m=0; m<=n; m++) { cout << n << ' ' << m << ' ' << go(n, m, "") << '\n'; } } return 0; }
1 0 0 1 1 1 2 0 0 2 1 2 2 2 3 3 0 0 3 1 4 3 2 5 3 3 6 4 0 0 4 1 6 4 2 8 4 3 9 4 4 10 5 0 0 5 1 9 5 2 12 5 3 13 5 4 14 5 5 15 6 0 0 6 1 12 6 2 16 6 3 18 6 4 19 6 5 20 6 6 21 7 0 0 7 1 16 7 2 21 7 3 24 7 4 25 7 5 26 7 6 27 7 7 28 8 0 0 8 1 20 8 2 27 8 3 30 8 4 32 8 5 33 8 6 34 8 7 35 8 8 36 9 0 0 9 1 25 9 2 33 9 3 37 9 4 40 9 5 41 9 6 42 9 7 43 9 8 44 9 9 45 10 0 0 10 1 30 10 2 40 10 3 45 10 4 48 10 5 50 10 6 51 10 7 52 10 8 53 10 9 54 10 10 55
M = 0인 경우에는 답이 0인 것을 알 수 있다. 이 사실은 굳이 답을 모두 구해보지 않고도 알 수 있다. 놀라운 사실을 하나 알 수 있다. 바로 N/2 ≤ M인 경우이다. N/2의 정답부터 점점 1씩 증가한다는 사실을 알 수 있다.
N = 10인 경우를 보면, (N, M)이 (10, 5)일때 정답에 1을 더하면 (10, 6)의 답, 여기서 다시 1을 더하면 (10, 7)의 답, … 임을 알 수 있다.
N = 9와 같이 홀수인 경우일때도 N/2 ≤ M을 만족한다. N == M일때를 보면 익숙한 수가 눈에 보인다.
1 1 1 2 2 3 3 3 6 4 4 10 5 5 15 6 6 21 7 7 28 8 8 36 9 9 45 10 10 55
1, 3, 6, 10, 15, 21, 28, 36, 45, 55는 1부터 N까지의 합과 같다. N == M일때의 정답은 N(N+1)/2와 같다는 사실을 알 수 있다.
위에서 N/2 ≤ M일때의 규칙과 합쳐보면, N/2 ≤ M일때 정답은 N(N+1)/2 – (N-M)과 같다는 사실을 알 수 있다.
이제 우리가 해야 하는 것은 N/2 > M인 경우이다. 규칙을 찾으려면 식이 많이 있는 것이 좋다. 8 ≤ N ≤ 15인 경우에 대해서 답을 모두 구해보기로 했다.
8 1 20 8 2 27 8 3 30 8 4 32 9 1 25 9 2 33 9 3 37 9 4 40 10 1 30 10 2 40 10 3 45 10 4 48 10 5 50 11 1 36 11 2 48 11 3 54 11 4 57 11 5 60 12 1 42 12 2 56 12 3 63 12 4 67 12 5 70 12 6 72 13 1 49 13 2 65 13 3 73 13 4 78 13 5 81 13 6 84 14 1 56 14 2 75 14 3 84 14 4 90 14 5 93 14 6 96 14 7 98 15 1 64 15 2 85 15 3 96 15 4 102 15 5 106 15 6 109 15 7 112
M = 1일때 규칙을 찾아냈는데, 이후 M > 1일때 이 값을 어떻게 활용해야 하는지 계산하기가 매우 어렵다.
8 1 20 = 4 * 5 9 1 25 = 5 * 5 10 1 30 = 5 * 6 11 1 36 = 6 * 6 12 1 42 = 6 * 7 13 1 49 = 7 * 7 14 1 56 = 7 * 8 15 1 64 = 8 * 8
N = 14, 15인 경우에 M의 변화에 따른 답의 차이를 구해보았다.
14 1 56 14 2 75 = 56 + 19 14 3 84 = 75 + 9 14 4 90 = 84 + 6 14 5 93 = 90 + 3 14 6 96 = 93 + 3 14 7 98 = 96 + 2 15 1 64 15 2 85 = 64 + 21 15 3 96 = 85 + 11 15 4 102 = 96 + 6 15 5 106 = 102 + 4 15 6 109 = 106 + 3 15 7 112 = 109 + 3
답의 차이로는 마땅한 규칙을 찾을 수 없다. 정답일때 문자열을 함께 출력해보기로 했다.
#include <iostream> #include <string> using namespace std; pair<int,string> go(int n, int m, string s) { if (m < 0) { return make_pair(0,""); } if (n == 0) { int cnt = 0; for (int i=0; i<s.length(); i++) { bool ok = false; for (int j=i; j<s.length(); j++) { if (s[j] == '1') ok = true; if (ok) cnt += 1; } } return make_pair(cnt,s); } auto t1 = go(n-1, m, s+"0"); auto t2 = go(n-1, m-1, s+"1"); return max(t1,t2); } int main() { for (int n=8; n<=15; n++) { for (int m=1; m<=n/2; m++) { auto p = go(n, m, ""); cout << n << ' ' << m << ' ' << p.first << '\n'; cout << p.second << '\n'; } cout << '\n'; } return 0; }
13 ≤ N ≤ 15일때를 보니 다음과 같다.
13 1 49 0000001000000 13 2 65 0001000010000 13 3 73 0010010001000 13 4 78 0100100100100 13 5 81 0101010100100 13 6 84 0101010101010 14 1 56 00000010000000 14 2 75 00001000010000 14 3 84 00100010001000 14 4 90 00100100100100 14 5 93 01010100100100 14 6 96 01010101010100 14 7 98 10101010101010 15 1 64 000000010000000 15 2 85 000010000100000 15 3 96 000100010001000 15 4 102 001001001001000 15 5 106 010100100100100 15 6 109 010101010100100 15 7 112 010101010101010
f(s)의 값이 최대가 될때 1의 위치가 규칙성이 있다.
N = 15일때를 보면
- M = 1:
000000010000000
, (0 7개) 1 (0 7개) - M = 2:
000010000100000
, (0 4개) 1 (0 4개) 1 (0 5개) - M = 3:
000100010001000
, (0 3개) 1 (0 3개) 1 (0 3개) 1 (0 3개) - M = 4:
001001001001000
, (0 2개) 1 (0 2개) 1 (0 2개) 1 (0 2개) 1 (0 3개) - …
문자열은 0과 1로만 이루어져 있고, 1이 0을 나누는 역할을 한다고 볼 수 있다. 이때 연속하는 0의 개수가 같거나 1 차이가 나야 한다.
1로인해서 나누어지는 0을 그룹이라고 했을 때, f(s)가 최대가 되는 경우는 0 N-M개를 M+1개의 그룹으로 나누는 경우와 같다고 볼 수 있다.
N-M이 M+1로 나누어 떨어지는 경우가 아니면 뒤쪽에 있는 그룹의 크기가 1 크면 된다.
위에서 구현한 브루트 포스 알고리즘에서 f(s)값을 구하는 것을 조금 개선해볼 수 있다.
int cnt = 0; for (int i=0; i<s.length(); i++) { bool ok = false; for (int j=i; j<s.length(); j++) { if (s[j] == '1') ok = true; if (ok) cnt += 1; } }
위의 소스는 s의 (i, j)에 1이 있으면 cnt
에 1을 증가시키는 소스이다. 한 번 1이 등장하면, 그 뒤의 위치에도 1이 등장한다고 볼 수 있다.
int cnt = 0; for (int i=0; i<s.length(); i++) { for (int j=i; j<s.length(); j++) { if (s[j] == '1') { cnt += s.length()-j; break; } } }
위의 계산 아이디어와 위에서 발견한 규칙을 합하면 문제를 해결할 수 있다.
N = 15, M = 2인 경우는
000010000100000
이 문제의 정답이다.
000010000100000
000010000100000
i가 빨간색에 있다며, j는 파란색에 있어야 한다. 여기서 5 * 11개를 찾을 수 있다.
000010000100000
000010000100000
여기에서 5 * 6개를 찾을 수 있고,
000010000100000
000010000100000
여기에서는 찾을 수가 없다. 5*11과 5*6을 더하면 85이고, 이 값은 우리가 위해서 구한 값과 같다.
규칙을 좀 더 명확하게 찾기 위해서 좀 더 긴 N으로 예시를 들어보자.
N = 20, M = 5인 경우
- 00100100100010001000
가 문제의 정답이고, f(s)의 값은 183이다. 위에서 계산한 것 처럼 다시 계산해보자.
00100100100010001000
00100100100010001000
3 * 18
00100100100010001000
00100100100010001000
3 * 15
00100100100010001000
00100100100010001000
3 * 12
00100100100010001000
00100100100010001000
4 * 8
00100100100010001000
00100100100010001000
4 * 4
모두 합해보면 3 * 18 + 3 * 15 + 3 * 12 + 4 * 8 + 4 * 4 = 183이고, 문제의 정답과 같다.
1로 나누어지는 0의 개수는 앞과 뒤로 나눌 수 있고, 앞과 뒤를 따로 계산해야 한다.
0은 N-M개가 있고, 이 0을 총 M+1개의 그룹으로 나누어야 한다. 아래 식에서 /는 정수 나눗셈이다.
- 앞쪽 0 그룹
- (N-M) / (M+1)개의 0이 있음
- 그룹의 개수는 (M+1) – (N-M) % (M+1)개
- 뒤쪽 0 그룹
- (N-M) / (M+1) + 1개의 0이 있음
- 그룹의 개수는 (N-M) % (M+1)개
앞쪽 0 그룹에서 구해지는 (i, j) 쌍을 모두 구해보면 3 * 18, 3 * 15, 3 * 12 이다. 등차수열로 볼 수 있다. 3은 항상 곱해지니, 3을 빼고 합을 계산한 다음, 합에 3을 곱하기로 했다.
식이 너무 길어질 수 있어 K = (N-M) / (M+1), R = (N-M) % (M+1)로 작성했다.
- 첫 항: N-K
- 공차: -(K+1)
- 항의 개수: (M+1) – R
이제 위의 등차수열의 합에서 (K+1)를 곱하면 앞쪽 0 그룹에서 만들 수 있는 (i, j) 쌍의 개수를 구할 수 있다.
뒤쪽 0 그룹도 비슷하게 구할 수 있다.
- 첫 항: (K+2) * (R-1)
- 공차: -(K+2)
- 항의 개수: R – 1
위의 등차수열의 합에서 (K+2)를 곱하면 뒤쪽 0 그룹에서 만들 수 있는 쌍을 모두 구할 수 있다.
#!/usr/bin/env python3 import sys def calc(a0, diff, n): return n * (2*a0 + (n-1)*diff) // 2 def go(n, m): if n//2 <= m: ans = n*(n+1)//2 ans -= (n-m) return ans elif m == 0: return 0 else: k = (n-m)//(m+1) r = (n-m)%(m+1) begin = (k+1) * calc(n-k, -(k+1), (m+1)-r) end = (k+2) * calc((k+2)*(r-1), -(k+2), r-1) return begin + end t = int(sys.stdin.readline()) for _ in range(t): n, m = map(int,sys.stdin.readline().split()) print(go(n, m))
begin
이 앞쪽 0 그룹, end
가 뒤쪽 0 그룹이다.
당연히 위 두 문제의 정해는 여기에 적혀있는 방법이 아니다.