抱歉,您的浏览器无法访问本站

本页面需要浏览器支持(启用)JavaScript


了解详情 >

LOJ 6059

题解

一眼 DP。并且很显然可以倍增优化。\(dp_{i,j,k}\) 表示 \(2^i\) 位,数字之和恰好\(j\),模 \(p\)\(k\) 时的方案数。有如下转移:

\[dp_{i,j,k}=\sum_{j_1+j_2=j}\sum_{k_1\times 10^{2^{i-1}}+k_2\equiv k\pmod p}dp_{i-1,j_1,k_1}\times dp_{i-1,j_2,k_2}\]

直接这样做是 \(O(p^2m^2\log n)\) 的。

只考虑 \(j\) 这一维,很显然可以用 FFT 优化,于是复杂度变为 \(O(p^2m\log m\log n)\),但是常数较大,比较难卡过。

如果我们记

\[tmp_{j,k}=\sum_{t\times 10^{2^{i-1}}\equiv k\pmod p} dp_{i-1,j,t}\]

则原式变成了

\[dp_{i,j,k}=\sum_{j_1+j_2=j}\sum_{(k_1+k_2)\bmod p=k}tmp_{j_1,k_1}\times dp_{i-1,j_2,k_2}\]

然而这个 \(\bmod\) 还是不太舒服,我们考虑把第二维值域扩充到 \([0,2p)\),然后再把 \([p,2p)\) 这部分加到 \([0,p)\)。则式子变得十分美观:

\[dp_{i,j,k}=\sum_{j_1+j_2=j}\sum_{k_1+k_2=k}tmp_{j_1,k_1}\times dp_{i-1,j_2,k_2}\]

我们把 \(j,k\) 两维拍到一起,发现这个式子仍然是个卷积的形式。于是直接用 FFT 优化。

时间复杂度 \(O(mp\log (mp)\log n)\)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <cstdio>
#include <algorithm>
#define N 70005
#define P 998244353
int n, p, m, pw[35], dp[35][N], s[N], tmp[N], len;
void upd(int &x, int y){
(x += y) >= P ? x -= P : 0;
}
int add(int x, int y){
return (x += y) >= P ? x - P : x;
}
int del(int x, int y){
return (x -= y) < 0 ? x + P : x;
}
int qpow(int a, int b = P - 2, int p = P){
int s = 1;
for (; b; b >>= 1, a = 1ll * a * a % p) if (b & 1) s = 1ll * s * a % p;
return s;
}
struct Number_Theory_Transform{
int n, rev[N], omega[N];
void init(int m){
n = m;
register int k = 0;
while ((1 << k) < n) ++k;
for (register int i = 1; i < n; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << k >> 1;
}
void NTT(int* a, int o){
for (register int i = 0; i < n; ++i) i < rev[i] ? std :: swap(a[i], a[rev[i]]), 0 : 0;
for (register int m = 1; m < n; m <<= 1){
register int l = m << 1, omega1 = qpow(o == 1 ? 3 : 332748118, (P - 1) / l);
omega[0] = 1;
for (register int i = 1; i < m; ++i) omega[i] = 1ll * omega[i - 1] * omega1 % P;
for (register int* p = a; p < a + n; p += l)
for (register int i = 0; i < m; ++i){
register int t = 1ll * omega[i] * p[m + i] % P;
p[m + i] = del(p[i], t), upd(p[i], t);
}
}
if (o == -1){
register int _n = qpow(n);
for (register int i = 0; i < n; ++i) a[i] = 1ll * a[i] * _n % P;
}
}
}T;
void multiply(int na, int *A, int nb, int *B, int *C){
int n;
if (na <= 20 && nb <= 20){
n = std :: max(na, nb);
for (register int i = 0; i < n; ++i) C[i] = 0;
for (register int i = 0; i < na; ++i)
for (register int j = 0; j < nb; ++j)
if (i + j < n) upd(C[i + j], 1ll * A[i] * B[j] % P);
return;
}
int a[N], b[N];
n = 1;
while (n < na + nb - 1) n <<= 1;
for (register int i = 0; i < na; ++i) a[i] = A[i];
for (register int i = 0; i < nb; ++i) b[i] = B[i];
for (register int i = na; i < n; ++i) a[i] = 0;
for (register int i = nb; i < n; ++i) b[i] = 0;
T.init(n), T.NTT(a, 1), T.NTT(b, 1);
for (register int i = 0; i < n; ++i) a[i] = 1ll * a[i] * b[i] % P;
T.NTT(a, -1);
for (register int i = 0; i < std :: max(na, nb); ++i) C[i] = a[i];
}
int main(){
scanf("%d%d%d", &n, &p, &m);
len = (m + 1) * p * 2;
for (register int i = 0; i <= 9 && i <= m; ++i) ++dp[0][i * p * 2 + i % p];
pw[0] = 1;
for (register int i = 1; i <= 30; ++i) pw[i] = pw[i - 1] << 1;
for (register int i = 0; i <= 30; ++i) pw[i] = qpow(10, pw[i], p);
for (register int i = 0; i < 30; ++i){
for (register int j = 0; j < len; ++j) tmp[j] = 0;
for (register int j = 0; j <= m; ++j)
for (register int k = 0; k < p; ++k)
upd(tmp[j * p * 2 + 1ll * k * pw[i] % p], dp[i][j * p * 2 + k]);
multiply(len, tmp, len, dp[i], dp[i + 1]);
for (register int j = 0; j <= m; ++j)
for (register int k = 0; k < p; ++k)
upd(dp[i + 1][j * p * 2 + k], dp[i + 1][j * p * 2 + k + p]), dp[i + 1][j * p * 2 + k + p] = 0;
}
s[0] = 1;
for (register int i = 30; ~i; --i)
if ((1 << i) <= n){
for (register int j = 0; j < len; ++j) tmp[j] = 0;
for (register int j = 0; j <= m; ++j)
for (register int k = 0; k < p; ++k)
upd(tmp[j * p * 2 + 1ll * k * pw[i] % p], s[j * p * 2 + k]), s[j * p * 2 + k] = 0;
multiply(len, tmp, len, dp[i], s);
for (register int j = 0; j <= m; ++j)
for (register int k = 0; k < p; ++k)
upd(s[j * p * 2 + k], s[j * p * 2 + k + p]), s[j * p * 2 + k + p] = 0;
n -= (1 << i);
}
int sum = 0;
for (register int i = 0; i <= m; ++i) upd(sum, s[i * p * 2]), printf("%d ", sum);
}

评论