给定一个长度为n的数组a,1<=ai<=n
初始时每个元素都独自成一个集合。
{a1}, {a2},…,{an}
执行以下操作
执行上述若干操作后,将上述集合的元素数量,统计到 一个可重复集合 M中。
问,M的数量有多少个。对 998244353取模
对于多重集合M,它最多有n个元素。不失一般性,我们假设它有n个元素(实际不足n个元素,可以用0填充){p1,p2,…,pn},且p1>=p2>=…>=pn>=0
结论
集合{p1,p2,…,pn}有效,当前仅当p1+p2+…+pn=n,
∑
i
=
1
k
p
i
<
=
∑
i
=
1
n
m
i
n
(
k
,
c
n
t
i
)
,
1
<
=
k
<
=
n
\sum_{i=1}^kp_i<=\sum_{i=1}^nmin(k,cnt_i),1<=k<=n
∑i=1kpi<=∑i=1nmin(k,cnti),1<=k<=n
第一点很好理解,因为一开始有n个元素,最终的多重集合,元素总和也是n。下面主要证明第二点,这里边和抽屉原理有关。
充分性证明:
已知集合{p1,p2,…,pn}有效。对于任意的k,由于此时有k个集合,那么对于1<=i<=n,最多能贡献次数为cnt[i],但不能超过k个,同一个元素,只能在同一个集合中出现一次。因此,此时i最多贡献min(cnt[i],k)次。这k个集合的大小总和,不能超过所有元素的最多贡献次数。所以
∑
i
=
1
k
p
i
<
=
∑
i
=
1
n
m
i
n
(
k
,
c
n
t
i
)
,
1
<
=
k
<
=
n
\sum_{i=1}^kp_i<=\sum_{i=1}^nmin(k,cnt_i),1<=k<=n
∑i=1kpi<=∑i=1nmin(k,cnti),1<=k<=n
必要性证明:
对于任意的k,同一个元素,只能在同一个集合中出现一次,此时每个元素能贡献的次数为min(cnt[i],k),所有元素能贡献的总和为
∑
i
=
1
n
m
i
n
(
k
,
c
n
t
i
)
\sum_{i=1}^nmin(k,cnt_i)
∑i=1nmin(k,cnti),又因为
∑
i
=
1
k
p
i
<
=
∑
i
=
1
n
m
i
n
(
k
,
c
n
t
i
)
\sum_{i=1}^kp_i<=\sum_{i=1}^nmin(k,cnt_i)
∑i=1kpi<=∑i=1nmin(k,cnti),且
∑
i
=
1
k
p
i
\sum_{i=1}^kp_i
∑i=1kpi是所有大小为k的p集合里边最大的(因为调的是前k大),那么说明 所有元素能贡献的次数 >= 任意一个大小为k的p集合。从而说明p集合是有效集合。
(好吧,数学太菜,证明得有点勉强)
根据上述结论,我们定义
f[pos][sum][last],
我们要从f[pos][sum][last]推f[pos+1],此时,
所以对于f[pos][sum][last] ,它可以转移到f[pos+1][sum+x][x],0<=x<=last
即f[pos][sum][last] += f[pos-1][sum-last][p], sum-last>= p>=last,这个计算过程可以用前缀和优化。
总体复杂度n^3,又因为last=p[pos]<=p[pos-1]<=…<=p1,所以last * pos <= n,所以last <= n / pos
复杂度降为
n
∗
(
n
/
1
+
n
/
2
+
n
/
3
+
.
.
.
+
n
/
n
)
=
n
∗
n
∗
l
o
g
n
n * (n/1+n/2+n/3+...+n/n) = n*n*logn
n∗(n/1+n/2+n/3+...+n/n)=n∗n∗logn
#include
using namespace std;
#define ll long long
#define pcc pair<char, char>
#define pii pair<int, int>
#define inf 0x3f3f3f3f
const int maxn = 2010;
const int mod = 998244353;
int n, x;
// cnt[i] denotes count of i
int cnt[maxn];
// f[pos][sum][last] denotes number of
// {p[1], p[2], ... , p[pos]} which meets as follow:
// 1) p[1] + p[2] + ... + p[pos] == sum,
// 2) p[pos] == last,
// 3) p[1] >= p[2] >= ... >= p[pos]
int f[2][maxn][maxn];
// g[sum][last] =
// f[pre][sum][0] + f[pre][sum][1] + ... + f[pre][sum][last]
int g[maxn][maxn];
// limit[i] denotes max of {p[1] + p[2] + ... + p[i]}
int limit[maxn];
void Add(int &x, int y) {
x += y;
x %= mod;
}
int Sub(int x, int y) {
x = (x + (mod - y)) % mod;
return x;
}
void init_limit() {
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= n; ++j) {
limit[i] += min(i, cnt[j]);
}
}
}
void solve() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
scanf("%d", &x);
++cnt[x];
}
init_limit();
// dp init
for (int s = 0; s <= limit[1]; ++s) {
f[1][s][s] = 1;
}
for (int s = 0; s <= n; ++s) {
g[s][0] = f[1][s][0];
for (int lst = 1; lst <= n; ++lst) {
g[s][lst] = (g[s][lst-1] + f[1][s][lst]) % mod;
}
}
int ans = 0;
Add(ans, f[1][n][n]);
int cur = 0, pre, l, r, tmp;
for (int i = 2; i <= n; ++i, cur = pre) {
pre = 1 - cur;
for (int s = 1; s <= limit[i]; ++s) {
for (int lst = 0; lst <= min(n / i, s); ++lst) {
l = lst;
r = min(s - lst, n / (i - 1));
if (r >= l) {
tmp = Sub(g[s-lst][r], g[s-lst][l-1]);
Add(f[cur][s][lst], tmp);
}
// for (int p = lst; p <= min(s - lst, n / (i - 1)); ++p) {
// Add(f[cur][s][lst], f[pre][s-lst][p]);
// }
}
}
// update ans
for (int lst = 1; lst <= n; ++ lst) {
Add(ans, f[cur][n][lst]);
}
// update pre sum g[][]
for (int s = 0; s <= n; ++s) {
g[s][0] = f[cur][s][0];
for (int lst = 1; lst <= n / i; ++lst) {
g[s][lst] = (g[s][lst-1] + f[cur][s][lst]) % mod;
}
}
// clear f[pre]
for (int s = 0; s <= n; ++s) {
for (int lst = 0; lst <= n / (i - 1); ++lst) {
f[pre][s][lst] = 0;
}
}
}
printf("%d\n", ans);
}
int main() {
int t = 1;
// scanf("%d", &t);
int cas = 1;
while (t--) {
// printf("cas %d:\n", cas++);
solve();
}
}
漫神生快^_^