想必你一定会用线段树维护等差数列吧?让我们来看看它的升级版。
请你维护一个长度为5×10^5的数组,一开始数组中每个元素都为0,要求支持以下两个操作:
1、区间[l,r]加自然数的平方数组,即al+=1,al+1+=4,al+2+=9,al+3+=16...ar+=(r−l+1)∗(r−l+1)
2、区间[l,r]查询区间和mod 10^9+7
第一行输入n,m,(1≤n,m≤5×105)n,m。 接下来m行,对于每行,先读入一个整数q。 当q的值为1时,还需读入两个整l,r,(1≤l≤r≤n)表示需要对区间[l,r]进行操作,让第一个元素加1,第二个元素加4,第三个元素加9以此类推。 当q的值为2时,还需读入两个整数l,r(1≤l≤r≤n)表示查询l到r的元素和
对于每一个q=2,输出一行一个非负整数,表示l到r的区间和mod 10^9+7。
示例1
4 4 2 1 4 1 1 4 1 3 4 2 1 4
0 35
示例2
10 6 1 1 6 1 8 9 1 3 6 2 1 10 1 1 10 2 1 10
126 511
[l,r]添加平方数列
对于任意位置x属于[l,r]
增加的值应当是( x − ( l − 1 ) ) ^2
展开 :x^2 - 2(l-1)x +(l-1)^2
维护6个系数,分开来求
- #include<bits/stdc++.h>
- using namespace std;
- #define int long long
- const int N=5e5+10;
- const int mod=1e9+7;
- int n,m;
- struct node
- {
- int sum0,sum1,sum2,lz0,lz1,lz2;
- } t[N*4];
- void pushdown(int i,int l,int r)
- {
- if(t[i].lz0)
- {
- int k=t[i].lz0;
- int mid=(l+r)>>1;
- t[i<<1].sum0+=k*(mid-l+1)%mod;
- t[i<<1].sum0%=mod;
- t[i<<1|1].sum0+=k*(r-mid)%mod;
- t[i<<1|1].sum0%=mod;
- t[i<<1].lz0+=k;
- t[i<<1].lz0%=mod;
- t[i<<1|1].lz0+=k;
- t[i<<1|1].lz0%=mod;
- t[i].lz0=0;
- }
- if(t[i].lz1)
- {
- int k=t[i].lz1;
- int mid=(l+r)>>1;
- t[i<<1].sum1+=k*((mid+l)*(mid-l+1)/2%mod)%mod;
- t[i<<1].sum1%=mod;
- t[i<<1|1].sum1+=k*((r+mid+1)*(r-mid)/2%mod)%mod;
- t[i<<1|1].sum1%=mod;
- t[i<<1].lz1+=k;
- t[i<<1].lz1%=mod;
- t[i<<1|1].lz1+=k;
- t[i<<1|1].lz1%=mod;
- t[i].lz1=0;
- }
- if(t[i].lz2)
- {
- int k=t[i].lz2;
- int mid=(l+r)>>1;
- t[i<<1].sum2+=k*((mid*(mid+1)/2*(2*mid+1)/3%mod)-((l-1)*((l-1)+1)/2*(2*(l-1)+1)/3%mod)+mod)%mod%mod;
- t[i<<1].sum2%=mod;
- t[i<<1|1].sum2+=k*((r*(r+1)/2*(2*r+1)/3%mod)-((mid+1-1)*((mid+1-1)+1)/2*(2*(mid+1-1)+1)/3%mod)+mod)%mod%mod;
- t[i<<1|1].sum2%=mod;
- t[i<<1].lz2+=k;
- t[i<<1].lz2%=mod;
- t[i<<1|1].lz2+=k;
- t[i<<1|1].lz2%=mod;
- t[i].lz2=0;
- }
- }
- void build(int i,int l,int r)
- {
- t[i].lz0=t[i].lz1=t[i].lz2=0;
- if(l==r)
- {
- t[i].sum0=t[i].sum1=t[i].sum2=0;
- return ;
- }
- int mid=(l+r)>>1;
- build(i<<1,l,mid);
- build(i<<1|1,mid+1,r);
- //pushup(rt);
- }
- void update0(int rt,int l,int r,int L,int R,int k)
- {
- if(L<=l&&r<=R)
- {
- t[rt].sum0+=k*(r-l+1)%mod;
- t[rt].sum0%=mod;
- t[rt].lz0+=k;
- t[rt].lz0%=mod;
- return ;
- }
- pushdown(rt,l,r);
- int mid=(l+r)>>1;
- if(L<=mid) update0(rt<<1,l,mid,L,R,k);
- if(R>mid)update0(rt<<1|1,mid+1,r,L,R,k);
- t[rt].sum0=(t[rt<<1].sum0+t[rt<<1|1].sum0)%mod;
- return;
- }
- void update1(int rt,int l,int r,int L,int R,int k)
- {
- if(L<=l&&r<=R)
- {
- t[rt].sum1+=k*((r+l)*(r-l+1)/2%mod)%mod;
- t[rt].sum1%=mod;
- t[rt].lz1+=k;
- t[rt].lz1%=mod;
- return;
- }
- pushdown(rt,l,r);
- int mid=(l+r)>>1;
- if(L<=mid) update1(rt<<1,l,mid,L,R,k);
- if(R>mid)update1(rt<<1|1,mid+1,r,L,R,k);
- t[rt].sum1=(t[rt<<1].sum1+t[rt<<1|1].sum1)%mod;
- return;
- }
- void update2(int rt,int l,int r,int L,int R,int k)
- {
- if(L<=l&&r<=R)
- {
- t[rt].sum2+=k*((r*(r+1)/2*(2*r+1)/3%mod)-((l-1)*((l-1)+1)/2*(2*(l-1)+1)/3%mod)+mod)%mod%mod;
- t[rt].sum2%=mod;
- t[rt].lz2+=k;
- t[rt].lz2%=mod;
- return;
-
- }
- pushdown(rt,l,r);
- int mid=(l+r)>>1;
- if(L<=mid) update2(rt<<1,l,mid,L,R,k);
- if(R>mid)update2(rt<<1|1,mid+1,r,L,R,k);
- t[rt].sum2=(t[rt<<1].sum2+t[rt<<1|1].sum2)%mod;
- return;
- }
- int query0(int rt,int l,int r,int L,int R)
- {
- if(L<=l&&R>=r)
- {
- return t[rt].sum0;
- }
- pushdown(rt,l,r);
- int mid=(l+r)>>1;
- int ans=0;
- if(mid >= R) ans=ans+query0(rt<<1, l, mid, L, R),ans%=mod;
- else if(mid < L) ans=ans+query0(rt<<1|1, mid + 1, r, L, R),ans%=mod;
- else
- {
- ans=ans+query0(rt<<1, l, mid, L, mid)+query0(rt<<1|1, mid + 1, r, mid + 1, R),ans%=mod;
- }
- return ans;
- }
- int query1(int rt,int l,int r,int L,int R)
- {
- if(L<=l&&R>=r)
- {
- return t[rt].sum1;
- }
- pushdown(rt,l,r);
- int mid=(l+r)>>1;
- int ans=0;
- if(mid >= R) ans=ans+query1(rt<<1, l, mid, L, R),ans%=mod;
- else if(mid < L) ans=ans+query1(rt<<1|1, mid + 1, r, L, R),ans%=mod;
- else
- {
- ans=ans+query1(rt<<1, l, mid, L, mid)+query1(rt<<1|1, mid + 1, r, mid + 1, R),ans%=mod;
- }
- return ans;
- }
- int query2(int rt,int l,int r,int L,int R)
- {
- if(L<=l&&R>=r)
- {
- return t[rt].sum2;
- }
- pushdown(rt,l,r);
- int mid=(l+r)>>1;
- int ans=0;
- if(mid >= R) ans=ans+query2(rt<<1, l, mid, L, R),ans%=mod;
- else if(mid < L) ans=ans+query2(rt<<1|1, mid + 1, r, L, R),ans%=mod;
- else
- {
- ans=ans+query2(rt<<1, l, mid, L, mid)+query2(rt<<1|1, mid + 1, r, mid + 1, R),ans%=mod;
- }
- return ans;
- }
- signed main()
- {
-
- cin>>n>>m;
- build(1,1,n);
- for(int i=1; i<=m; i++)
- {
- int op;
- cin>>op;
- if(op==1)
- {
- int l,r;
- cin>>l>>r;
- update0(1,1,n,l,r,(l-1)*(l-1)%mod);
- update1(1,1,n,l,r,(-2*(l-1)%mod+mod)%mod);
- update2(1,1,n,l,r,1);
- }
- else
- {
- int l,r;
- cin>>l>>r;
- cout<<(query0(1,1,n,l,r)%mod+query1(1,1,n,l,r)%mod+query2(1,1,n,l,r)%mod)%mod<<"\n";
- }
- }
- return 0;
- }