抱歉,您的浏览器无法访问本站

本页面需要浏览器支持(启用)JavaScript


了解详情 >

题意

给定两个长度为\(n\)数组\(A,B\),下标范围\([0,n-1]\)

求所有整数\(k\in [0,n-1]\),满足存在一个\(m\)次多项式\(C\),使得对于所有\(i\in [0,n-1]\),都有\(C(i)\equiv A_i-B_{(i+k)\bmod n} \pmod{998244353}\)

题解

先附上PPT:

This is a picture without description

This is a picture without description

这完全看不懂啊

首先,\(B\)数组可以看成一个环,每次选环上的长度为\(n\)的一段。所以我们直接把B数组复制接在末尾,变为原来的两倍长度。

然后,我们来理解这个差分。

对于任意一个\(m\)次多项式\(f(x)\),当\(x\)分别取\(0,1,2,\cdots,n(n\ge m+1)\)时,将会得到\(n+1\)个点值。将这\(n+1\)个点值不断差分,\(m+1\)次后会都变成\(0\)

例如,当\(f(x)=x^3+x^2-2x+1\)时,分别取\(0,1,2,3,4,5\)代入,得到\(6\)个点值(第\(0\)行)。

1
2
3
4
5
6
x   0   1   2   3   4   5
0 1 1 9 31 73 141
1 0 8 22 42 68
2 8 14 20 26
3 6 6 6
4 0 0

然后可以发现,按表格中的排列,\(t\)次差分以后(第\(k\)行第\(i\)列)\(c_i=\sum\limits_{j=0}^i (-1)^j C_k^j f(i-j)\)

例如,表中第二行第四个数\(20=(73-31)-(31-9)=73-2\times 31+9=C_2^0\times 73-C_2^1\times 31+C_2^2\times 9\)

又可以发现,\(c_i=\sum\limits_{j=0}^i (-1)^j C_k^j f(i-j)\)是卷积的形式,所以直接用FFT/NTT优化。

这样,我们对\(A\)\(B\)分别做一遍卷积,然后把\(A[m+1,n-1]\)\(B[m+1,2n-1]\)做一次\(KMP\)即可。

由于数据较水,直接\(hash\)也能过。

注意特判\(m\ge n-1\)的情况,此时任意\(k\)都满足条件。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include <cstdio>
#include <cctype>
#include <algorithm>
int read(){
register int x = 0;
register char ch = getchar();
for (; !isdigit(ch); ch = getchar()) ;
for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ '0');
return x;
}
#define N 1100005
#define P 998244353
int n, m, a[N], b[N], c[N], fac[N], inv[N], cnt, ans[N];
int qpow(int a, int b){
int s = 1;
for (; b; b >>= 1, a = 1ll * a * a % P) if (b & 1) s = 1ll * s * a % P;
return s;
}
void pre(int n){
fac[0] = 1;
for (register int i = 1; i <= n; ++i) fac[i] = 1ll * fac[i - 1] * i % P;
inv[n] = qpow(fac[n], P - 2);
for (register int i = n; i; --i) inv[i - 1] = 1ll * inv[i] * i % P;
}
int C(int n, int m){
return n < m ? 0 : 1ll * fac[n] * inv[m] % P * inv[n - m] % P;
}
namespace Polynomial{
int omega[N], rev[N];
void init(int n){
register int k = 0;
while ((1 << k) < n) ++k;
for (register int i = 0; i < n; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << k >> 1;
}
void NTT(int n, int *a, int o = 1){
for (register int i = 0; i < n; ++i) if (i < rev[i]) std :: swap(a[i], a[rev[i]]);
for (register int m = 1, l; m < n; m = l){
l = m << 1, omega[0] = 1, omega[1] = qpow(~o ? 3 : 332748118, (P - 1) / l);
for (register int i = 2; i < m; ++i) omega[i] = 1ll * omega[i - 1] * omega[1] % P;
for (register int *p = a, t; p < a + n; p += l)
for (register int k = 0; k < m; ++k)
t = 1ll * omega[k] * p[m + k] % P, (p[m + k] = p[k] - t) < 0 ? p[m + k] += P : 0,
(p[k] += t) >= P ? p[k] -= P : 0;
}
if (o == -1)
for (register int i = 0, _n = qpow(n, P - 2); i < n; ++i) a[i] = 1ll * a[i] * _n % P;
}
void Trans(int nt, int *a, int *b, int *c){
int n = 1;
while (n < nt) n <<= 1;
init(n), NTT(n, a), NTT(n, b), NTT(n, c);
for (register int i = 0; i < n; ++i)
a[i] = 1ll * a[i] * c[i] % P, b[i] = 1ll * b[i] * c[i] % P;
NTT(n, a, -1), NTT(n, b, -1);
}
}
int fail[N];
int KMP(int n, int *S, int m, int *T){
fail[1] = 0;
for (register int i = 2, j = 0; i <= m; ++i){
while (j && T[j] != T[i - 1]) j = fail[j];
fail[i] = j += (T[j] == T[i - 1]);
}
for (register int i = 1, j = 0; i <= n; ++i){
while (j && T[j] != S[i - 1]) j = fail[j];
j += (T[j] == S[i - 1]);
if (j == m) ans[++cnt] = i - m, j = fail[j];
}
}
int main(){
freopen("pine.in", "r", stdin);
freopen("pine.out", "w", stdout);
n = read(), m = read() + 1;
if (m >= n){
printf("%d\n", n);
for (register int i = 0; i < n; ++i) printf("%d\n", i);
return 0;
}
pre(n);
for (register int i = 0; i < n; ++i) a[i] = read();
for (register int i = 0; i < n; ++i) b[i] = read(), b[i + n] = b[i];
for (register int i = 0; i < n; ++i) c[i] = i & 1 ? P - C(m, i) : C(m, i);
Polynomial :: Trans(2 * n, a, b, c);
// for (register int i = 0; i < n; ++i) printf("%d ", a[i]); putchar('\n');
// for (register int i = 0; i < n; ++i) printf("%d ", b[i]); putchar('\n');
KMP(2 * n - 1 - m, b + m, n - m, a + m);
printf("%d\n", cnt);
for (register int i = 1; i <= cnt; ++i) printf("%d\n", ans[i]);
}

评论