一般的容斥思路:枚举位置集合 { p i } \{p_i\} {pi}表示 S p i ∼ p i + 2 ∈ { A B C , B C A , C A B } S_{p_i\sim p_i+2}\in \{ABC,BCA,CAB\} Spi∼pi+2∈{ABC,BCA,CAB},然后算方案数。这个方法比较通用,但是在这道题中好像做不出来。
考虑上面的容斥,如果将字母看成是数字,本质上是限制了 a i = ( a i − 1 + 1 ) m o d 3 a_i=(a_{i-1}+1)\bmod 3 ai=(ai−1+1)mod3。因此原序列可以被划分成若干个极长的子段,设长度为 k k k的子段对应的容斥系数为 c k c_k ck,可以得到 c k = − c k − 1 − c k − 2 c_k=-c_{k-1}-c_{k-2} ck=−ck−1−ck−2,因此归纳得到:
c k = { − 1 k m o d 3 = 0 1 k m o d 3 = 1 0 k m o d 3 = 2 c_k=\begin{cases}-1& k\bmod3=0\\1&k\bmod3=1\\0&k\bmod3=2\end{cases} ck=⎩ ⎨ ⎧−110kmod3=0kmod3=1kmod3=2
然后可以 O ( n 2 ) O(n^2) O(n2)暴力求出答案。
这个时候组合数就没有什么优势了。考虑用三元 G F GF GF进行化简,反复利用: 1 1 − F ( x ) = ∑ i ≥ 0 F ( x ) i \frac{1}{1-F(x)}=\sum_{i\ge 0}F(x)^i 1−F(x)1=i≥0∑F(x)i
每一段的
G
F
GF
GF:
F
=
−
3
∑
i
≥
1
(
a
b
c
)
i
+
∑
i
≥
0
(
a
b
c
)
i
(
a
+
b
+
c
)
=
−
3
a
b
c
⋅
1
1
−
a
b
c
+
(
a
+
b
+
c
)
⋅
1
1
−
a
b
c
=
(
−
3
a
b
c
+
a
+
b
+
c
)
⋅
1
1
−
a
b
c
\begin{aligned}F&=-3\sum_{i\ge 1}(abc)^i+\sum_{i\ge 0}(abc)^i(a+b+c)\\&=-3abc\cdot \frac{1}{1-abc}+(a+b+c)\cdot \frac{1}{1-abc}\\&=(-3abc+a+b+c)\cdot \frac{1}{1-abc} \end{aligned}
F=−3i≥1∑(abc)i+i≥0∑(abc)i(a+b+c)=−3abc⋅1−abc1+(a+b+c)⋅1−abc1=(−3abc+a+b+c)⋅1−abc1
设答案的生成函数为 G G G。则:
G = ∑ i ≥ 0 F i = 1 1 − F = 1 − a b c 2 a b c − a − b − c + 1 = ( 1 − a b c ) ∑ i ≥ 0 ( a + b + c − 2 a b c ) \begin{aligned}G&=\sum_{i\ge 0}F^i\\&=\frac{1}{1-F}\\&=\frac{1-abc}{2abc-a-b-c+1}\\&=(1-abc)\sum_{i\ge 0}(a+b+c-2abc) \end{aligned} G=i≥0∑Fi=1−F1=2abc−a−b−c+11−abc=(1−abc)i≥0∑(a+b+c−2abc)
暴力枚举即可 O ( n ) O(n) O(n)计算答案。
remark \text{remark} remark 以后遇到这样将子段拼起来的问题可以考虑用 G F GF GF。这个东西 好像见过
#include
#define ll long long
#define fi first
#define se second
#define pb push_back
#define inf 0x3f3f3f3f
using namespace std;
const int N=3e6+5;
const int mod=998244353;
ll fac[N],inv[N],res;
int a,b,c;
ll fpow(ll x,ll y=mod-2){
ll z(1);
for(;y;y>>=1){
if(y&1)z=z*x%mod;
x=x*x%mod;
}return z;
}
void init(int n){
fac[0]=1;for(int i=1;i<=n;i++)fac[i]=fac[i-1]*i%mod;
inv[n]=fpow(fac[n]);for(int i=n;i>=1;i--)inv[i-1]=inv[i]*i%mod;
}
ll binom(int x,int y){
if(x<0||y<0||x<y)return 0;
return fac[x]*inv[y]%mod*inv[x-y]%mod;
}
void add(ll &x,ll y){
x=(x+y)%mod;
}
ll calc(int a,int b,int c){
if(a<0||b<0||c<0)return 0;
ll res=0;
for(int i=0;i<=min({a,b,c});i++){
add(res,fpow(-2,i)*fac[a+b+c-2*i]%mod*inv[a-i]%mod*inv[b-i]%mod*inv[c-i]%mod*inv[i]);
}return res;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
cin>>a>>b>>c,init(a+b+c);
add(res,calc(a,b,c)-calc(a-1,b-1,c-1));
cout<<(res+mod)%mod;
}