给你一个排列 p p p,对于每一个 i i i,我们在平面上,放置一个点 ( i , p i ) (i,p_i) (i,pi)。对于坐标上下限都在 1 ∼ n 1\sim n 1∼n内的全体 ( n ( n + 1 ) 2 ) 2 (\frac{n(n+1)}{2})^2 (2n(n+1))2矩形,求每个矩形内部点数的 k k k次方之和。
形式化地,请你计算
∑ 1 ≤ l ≤ r ≤ n ∑ 1 ≤ d ≤ u ≤ n ∣ { i ∣ l ≤ i ≤ r ∨ d ≤ p i ≤ u } ∣ \sum\limits_{1\leq l\leq r\leq n}\sum\limits_{1\leq d\leq u\leq n}|\{i|l\leq i\leq r\vee d\leq p_i\leq u\}| 1≤l≤r≤n∑1≤d≤u≤n∑∣{i∣l≤i≤r∨d≤pi≤u}∣
1 ≤ n ≤ 1 0 5 , 1 ≤ k ≤ 3 1\leq n\leq 10^5,1\leq k\leq 3 1≤n≤105,1≤k≤3
我们可以考虑拆贡献,点数的 k k k次方可以看成选 k k k个点的方案的线性组合。
什么意思呢?就是在这 n n n个点中有序地可重地选择 k k k个点,将所有包含这 k k k个点的矩形的贡献 + 1 +1 +1,注意所有从 n n n个点中有序地可重地选 k k k个点的方案都要被计算贡献。
为什么可以这样呢?对于每个矩形,设这个矩形内的点数为 t t t,在这个矩形中有序地可重地选 k k k个点的方案数为 t k t^k tk,也就是说这个矩形在上面计算贡献的时候将贡献加了 t k t^k tk次一。
下面,我们来求 k k k为不同的值时的答案。
对每个点 ( x , p x ) (x,p_x) (x,px),答案的贡献增加 x × ( n − x + 1 ) × p x × ( n − p x + 1 ) x\times (n-x+1)\times p_x\times (n-p_x+1) x×(n−x+1)×px×(n−px+1)。
我们考虑选的两个点相同的情况和两个点不同的情况。
对于两个点相同的情况,这其实就是 k = 1 k=1 k=1的情况,每种情况会被算一次。
对于两个点不同的情况,我们可以分为顺序对和逆序对来考虑:
因为选点是有序的,每种顺序对和逆序对都用两种选法被选到,所以两个点不同的情况的贡献要乘 2 2 2。
将 k = 1 k=1 k=1的贡献计算一次(三次选择同一个点), k = 2 k=2 k=2的贡献计算两次(三次选择两个不同的点),下面再考虑三次选择三个不同的点的贡献。
分为两种本质不同的情况:
------
|* |
| * |
| * |
------
这种情况出现了 2 2 2次(按 i i i左右翻转,总共有 2 2 2次),用两个树状数组维护即可。
------
| * |
|* |
| * |
------
这种情况总共出现了 4 4 4次(按 i i i左右翻转,按 p i p_i pi上下翻转,四个角度各一次,总共有 4 4 4次),用线段树来维护即可。可以在加入第一个点时直接在对应位置上加数,在加入第二个点时将其后缀乘上对应的数,再加入第三个点时查询前缀和。
因为选点是有序的,每种顺序对和逆序对都用六种选法被选到,所以两个点不同的情况的贡献要乘 6 6 6。
时间复杂度为 O ( n log n ) O(n\log n) O(nlogn)。
#include
#define lc k<<1
#define rc k<<1|1
using namespace std;
const long long mod=998244353;
int n,K,p[100005];
long long tr1[100005],tr2[100005];
long long s[500005],hv[500005],ly[500005];
int lb(int i){
return i&(-i);
}
void pt1(int i,long long v){
while(i<=n){
tr1[i]=(tr1[i]+v)%mod;
i+=lb(i);
}
}
long long find1(int i){
long long re=0;
while(i){
re=(re+tr1[i])%mod;
i-=lb(i);
}
return re;
}
void pt2(int i,long long v){
while(i<=n){
tr2[i]=(tr2[i]+v)%mod;
i+=lb(i);
}
}
long long find2(int i){
long long re=0;
while(i){
re=(re+tr2[i])%mod;
i-=lb(i);
}
return re;
}
void build(int k,int l,int r){
s[k]=hv[k]=ly[k]=0;
if(l==r) return;
int mid=l+r>>1;
build(lc,l,mid);
build(rc,mid+1,r);
}
void down(int k){
s[lc]=(s[lc]+hv[lc]*ly[k])%mod;
s[rc]=(s[rc]+hv[rc]*ly[k])%mod;
ly[lc]=(ly[lc]+ly[k])%mod;
ly[rc]=(ly[rc]+ly[k])%mod;
ly[k]=0;
}
void ch(int k,int l,int r,int x,long long y){
if(l==r&&l==x){
hv[k]=y;
s[k]=ly[k]=0;
return;
}
if(ly[k]) down(k);
int mid=l+r>>1;
if(x<=mid) ch(lc,l,mid,x,y);
else ch(rc,mid+1,r,x,y);
hv[k]=(hv[lc]+hv[rc])%mod;
s[k]=(s[lc]+s[rc])%mod;
}
void ts(int k,int l,int r,int x,int y,long long v){
if(l>=x&&r<=y){
ly[k]=(ly[k]+v)%mod;
s[k]=(s[k]+v*hv[k])%mod;
return;
}
if(ly[k]) down(k);
int mid=l+r>>1;
if(x<=mid) ts(lc,l,mid,x,y,v);
if(y>mid) ts(rc,mid+1,r,x,y,v);
s[k]=(s[lc]+s[rc])%mod;
}
long long find(int k,int l,int r,int x,int y){
if(l>=x&&r<=y) return s[k];
if(ly[k]) down(k);
int mid=l+r>>1;
long long re=0;
if(x<=mid) re=(re+find(lc,l,mid,x,y))%mod;
if(y>mid) re=(re+find(rc,mid+1,r,x,y))%mod;
return re;
}
long long gt(){
long long re=0;
build(1,1,n);
for(int i=1;i<=n;i++){
re=(re+find(1,1,n,1,p[i])*(n-i+1)%mod*(n-p[i]+1)%mod)%mod;
ts(1,1,n,p[i],n,p[i]);
ch(1,1,n,p[i],i);
}
return re;
}
long long gt1(){
long long re=0;
for(int i=1;i<=n;i++){
re=(re+1ll*i*(n-i+1)%mod*p[i]%mod*(n-p[i]+1)%mod)%mod;
}
return re;
}
long long gt2(){
long long re=0;
for(int i=1;i<=n;i++){
re=(re+find1(p[i])*(n-i+1)%mod*(n-p[i]+1)%mod)%mod;
pt1(p[i],1ll*i*p[i]%mod);
}
for(int i=1;i<=n;i++){
tr1[i]=0;
if(i<n-i+1) swap(p[i],p[n-i+1]);
}
for(int i=1;i<=n;i++){
re=(re+find1(p[i])*(n-i+1)%mod*(n-p[i]+1)%mod)%mod;
pt1(p[i],1ll*i*p[i]%mod);
}
for(int i=1;i<=n;i++){
tr1[i]=0;
if(i<n-i+1) swap(p[i],p[n-i+1]);
}
return re;
}
long long gt3(){
long long re=0;
for(int i=1;i<=n;i++){
long long now=find1(p[i]);
pt1(p[i],1ll*i*p[i]%mod);
re=(re+find2(p[i])*(n-i+1)%mod*(n-p[i]+1)%mod)%mod;
pt2(p[i],now);
}
for(int i=1;i<=n;i++){
tr1[i]=tr2[i]=0;
if(i<n-i+1) swap(p[i],p[n-i+1]);
}
for(int i=1;i<=n;i++){
long long now=find1(p[i]);
pt1(p[i],1ll*i*p[i]%mod);
re=(re+find2(p[i])*(n-i+1)%mod*(n-p[i]+1)%mod)%mod;
pt2(p[i],now);
}
for(int i=1;i<=n;i++){
tr1[i]=tr2[i]=0;
if(i<n-i+1) swap(p[i],p[n-i+1]);
}
re=(re+gt())%mod;
for(int i=1;i<=n;i++)
if(i<n-i+1) swap(p[i],p[n-i+1]);
re=(re+gt())%mod;
for(int i=1;i<=n;i++)
p[i]=n-p[i]+1;
re=(re+gt())%mod;
for(int i=1;i<=n;i++)
if(i<n-i+1) swap(p[i],p[n-i+1]);
re=(re+gt())%mod;
return re;
}
int main()
{
freopen("points.in","r",stdin);
freopen("points.out","w",stdout);
scanf("%d%d",&n,&K);
for(int i=1;i<=n;i++){
scanf("%d",&p[i]);
}
if(K==1) printf("%lld",gt1());
else if(K==2) printf("%lld",(gt1()+2*gt2())%mod);
else{
printf("%lld",(gt1()+6*gt2()+6*gt3())%mod);
}
return 0;
}