抱歉,您的浏览器无法访问本站

本页面需要浏览器支持(启用)JavaScript


了解详情 >

题目传送门

题意

求满足以下条件的长度为 \(n\) 的非负整数序列 \(a_1,a_2,a_3,\cdots,a_n\) 的方案数 \(\bmod 10^9+7\) 的值:

  1. \(l\le \sum_{i=1}^n a_i\le r\)
  2. 将序列从大到小排序后,记为 \(a_1',a_2',a_3',\cdots,a_n'\),满足 \(a_m'=a_{m+1}'\)

\(1\le m < n\le 3\times 10^5,1\le l,r\le 3\times 10^5\)

题解

显然第一个条件可以差分,假设总和的上界为 \(S\),用隔板法得到不考虑第二个条件的方案数为 \(C_{S+n}^n\)。直接算满足第二个条件的不好算,我们考虑算不满足第二个条件的方案数,然后用总方案数减去即可。

我们枚举从大到小排序后的第 \(m\) 个数为 \(x\),那么不满足条件的方案数等于在 \(n\) 个位置中选出 \(m\) 个位置使得这 \(m\) 个位置的值 \(\ge x\),其余 \(n-m\) 个位置的值 \(< x\)\(n\) 个数总和 \(\le S\) 的方案数,减去在 \(n\) 个位置中选出 \(m\) 个位置使得这 \(m\) 个位置的值 \(\ge x+1\),其余 \(n-m\) 个位置的值 \(< x\)\(n\) 个数总和 \(\le S\) 的方案数。

于是问题变成了求在 \(n\) 个位置中选出 \(m\) 个位置使得这 \(m\) 个位置的值 \(\ge a\),其余 \(n-m\) 个位置的值 \(< b\)\(n\) 个数总和 \(\le S\) 的方案数。

发现隔板法可以解决的问题的条件是形如 \(a_i\ge lim_i\) 这样的,于是我们把 \(< b\) 的部分容斥成这个形式。于是我们强制 \(i\) 个位置 \(\ge b\),其他位置随便选。

问题变成了求在 \(n\) 个位置中选出 \(m\) 个位置使得这 \(m\) 个位置的值 \(\ge a\),在剩余 \(n-m\) 个位置中选 \(i\) 个,使得这 \(i\) 个位置的值 \(\ge b\),其他 \(n-m-i\) 个位置的值 \(\ge 0\),总和 \(\le S\) 的方案数。这个问题的答案就是 \(C_{n-m}^i\times C_{S-ma-ib+n}^n\)。再乘上容斥系数 \((-1)^i\) 即可。

由于需要保证 \(ix\le S\),复杂度是一个调和级数的形式,所以复杂度是 \(O(S\log S)\)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
int read(){
register int x = 0;
register char f = 1, ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') f ^= 1;
for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ '0');
return f ? x : -x;
}
#define N 600005
#define P 1000000007
int n, m, l, r, fac[N], inv[N];
void inc(int &a, int b){ (a += b) >= P ? a -= P : 0; }
void dec(int &a, int b){ (a -= b) < 0 ? a += P : 0; }
int plus(int a, int b){ return (a += b) >= P ? a - P : a; }
int minus(int a, int b){ return (a -= b) < 0 ? a + P : a; }
int qpow(int a, int b = P - 2){
int s = 1;
for (; b; b >>= 1, a = 1ll * a * a % P) if (b & 1) s = 1ll * s * a % P;
return s;
}
void init(int n){
fac[0] = 1;
for (register int i = 1; i <= n; ++i) fac[i] = 1ll * fac[i - 1] * i % P;
inv[n] = qpow(fac[n]);
for (register int i = n; i; --i) inv[i - 1] = 1ll * inv[i] * i % P;
}
int C(int n, int m){
if (n < m) return 0;
return 1ll * fac[n] * inv[m] % P * inv[n - m] % P;
}
int get(int S, int r){
if (S < 0) return 0;
int res = 0;
for (register int i = 0, t = 0; i <= n - m && t <= S; t += r, ++i){
int s = 1ll * C(n - m, i) * C(S - t + n, n) % P;
if (i & 1) dec(res, s); else inc(res, s);
}
return 1ll * res * C(n, m) % P;
}
int solve(int S){
int res = C(S + n, n);
for (register int i = 1; i * m <= S; ++i)
dec(res, minus(get(S - i * m, i), get(S - (i + 1) * m, i)));
return res;
}
int main(){
n = read(), m = read(), l = read(), r = read();
init(n + r);
printf("%d\n", minus(solve(r), solve(l - 1)));
}

评论