简单的区间计数问题可能直接推式子就行了。
但有些问题必须要数据结构维护。线段树就是一个比较好的处理区间的数据结构。
满足条件的区间特征:
max
{
a
i
}
−
min
{
a
i
}
+
1
−
c
n
t
=
0
\max\{a_i\}-\min\{a_i\}+1-cnt=0
max{ai}−min{ai}+1−cnt=0,其中
c
n
t
cnt
cnt 代表区间内不同数字的个数。
考虑固定右端点,统计有多少个合法的左端点。
我们可以用线段树维护
m
i
n
v
=
min
{
max
{
a
i
}
−
min
{
a
i
}
−
c
n
t
}
minv=\min\{\max\{a_i\}-\min\{a_i\}-cnt\}
minv=min{max{ai}−min{ai}−cnt} 和
n
u
m
=
有多少个区间左端点可以取到
m
i
n
v
num=有多少个区间左端点可以取到 minv
num=有多少个区间左端点可以取到minv,答案就是
m
i
n
v
=
−
1
minv=-1
minv=−1 时的
n
u
m
num
num
max
{
a
i
}
\max\{a_i\}
max{ai} 和
min
{
a
i
}
\min\{a_i\}
min{ai} 可以用两个单调栈维护。
#include
#define int long long
using namespace std;
const int N=1e6+7,inf=1e18;
struct seg
{
int minv,tag,cnt;
seg()
{
minv=tag=cnt=0;
}
};
vector<seg> tr;
void update(int u)
{
tr[u].minv=min(tr[u<<1].minv,tr[u<<1|1].minv);
if(tr[u<<1].minv==tr[u<<1|1].minv)
{
tr[u].cnt=tr[u<<1].cnt+tr[u<<1|1].cnt;
}
else if(tr[u].minv==tr[u<<1].minv)
{
tr[u].cnt=tr[u<<1].cnt;
}
else if(tr[u].minv==tr[u<<1|1].minv)
{
tr[u].cnt=tr[u<<1|1].cnt;
}
else
{
assert(false);
}
}
void pushdown(int u)
{
if(tr[u].tag)
{
tr[u<<1].minv+=tr[u].tag; tr[u<<1|1].minv+=tr[u].tag;
tr[u<<1].tag+=tr[u].tag; tr[u<<1|1].tag+=tr[u].tag;
tr[u].tag=0;
}
}
void build(int u,int st,int ed)
{
if(st==ed)
{
tr[u].cnt=1;
return;
}
int mid=st+ed>>1;
build(u<<1,st,mid);
build(u<<1|1,mid+1,ed);
update(u);
}
void modify(int u,int st,int ed,int l,int r,int x)
{
if(l<=st&&ed<=r)
{
tr[u].minv+=x;
tr[u].tag+=x;
return;
}
pushdown(u);
int mid=st+ed>>1;
if(mid>=l)
modify(u<<1,st,mid,l,r,x);
if(mid<r)
modify(u<<1|1,mid+1,ed,l,r,x);
update(u);
}
int query(int u,int st,int ed,int l,int r)
{
if(l<=st&&ed<=r)
{
return tr[u].minv==-1?tr[u].cnt:0;
}
pushdown(u);
int mid=st+ed>>1;
int res=0;
if(mid>=l)
res=query(u<<1,st,mid,l,r);
if(mid<r)
res+=query(u<<1|1,mid+1,ed,l,r);
return res;
}
int O_o()
{
int n;
cin>>n;
tr.assign(n+1<<2,seg());
vector<int> a(n+1),ls(n+1);
map<int,int> mp;
for(int i=1; i<=n; i++)
{
cin>>a[i];
ls[i]=mp[a[i]];
mp[a[i]]=i;
}
build(1,1,n);
stack<array<int,2>> sx,sy;// decrease, increase
int ans=0;
for(int i=1; i<=n; i++)
{
int x=a[i];
while(sx.size()&&x>sx.top()[0])
{
auto [v,id]=sx.top(); sx.pop();
modify(1,1,n,sx.size()?(sx.top()[1]+1):1,id,x-v);
}
sx.push({x,i});
while(sy.size()&&x<sy.top()[0])
{
auto [v,id]=sy.top(); sy.pop();
modify(1,1,n,sy.size()?(sy.top()[1]+1):1,id,v-x);
}
sy.push({x,i});
modify(1,1,n,ls[i]+1,i,-1);
ans+=query(1,1,n,1,i);
}
return ans;
}
signed main()
{
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
cout<<fixed<<setprecision(12);
int T=1;
cin>>T;
for(int i=1; i<=T; i++)
{
cout<<"Case #"<<i<<": "<<O_o()<<"\n";
}
}
首先预处理每个点要往后走到哪才会出现
k
k
k 次和
k
+
1
k+1
k+1 次
具体的,令
L
i
L_i
Li 为从点
i
i
i 往后走,出现
k
k
k 次
a
i
a_i
ai 的最近位置;令
R
i
R_i
Ri 为从点
i
i
i 往后走,出现
k
k
k 次
a
i
a_i
ai 的最远位置。
考虑倒着枚举左端点,对于每个左端点考虑有多少个右端点是合法的。
我们定义点 i i i 的合法区间为 [ L i , R i ] ∪ [ 1 , i − 1 ] [L_i,R_i]∪[1,i-1] [Li,Ri]∪[1,i−1] ( [ L i , R i ] [L_i,R_i] [Li,Ri] 中 a i a_i ai 出现了 k k k 次, [ 1 , i − 1 ] [1,i-1] [1,i−1] 不在 i i i 的管辖范围内),那么对于 i i i 为左端点的答案就是 [ i , n ] [i,n] [i,n] 中所有不同的数最前面的合法区间的交集。
也就是我们要维护一棵线段树,支持区间加、区间减、求区间最大值和最大值个数。这样做其实有些麻烦。
不难想到,合法区间的交集 = 不合法区间的并集的反集,求区间的并就完全可以像扫描线那样做。
#include
#define int long long
using namespace std;
const int N=1e6+7,inf=1e18;
struct seg
{
int val,len;
seg()
{
val=len=0;
}
};
vector<seg> tr;
int n;
void update(int u,int st,int ed)
{
if(tr[u].val>0)
{
tr[u].len=ed-st+1;
}
else
{
if(st==ed)
{
tr[u].len=0;
return;
}
tr[u].len=tr[u<<1].len+tr[u<<1|1].len;
}
}
void add(int u,int st,int ed,int l,int r,int x)
{
if(l>r||l>n||r>n) return;
if(l<=st&&ed<=r)
{
tr[u].val+=x;
update(u,st,ed);
return;
}
// pushdown(u);
int mid=st+ed>>1;
if(mid>=l)
add(u<<1,st,mid,l,r,x);
if(mid<r)
add(u<<1|1,mid+1,ed,l,r,x);
update(u,st,ed);
}
int query(int u,int st,int ed,int l,int r)
{
if(l>r||l>n||r>n) return 0;
if(l<=st&&ed<=r)
{
return tr[u].len;
}
int mid=st+ed>>1;
int res=0;
if(mid>=l)
res=query(u<<1,st,mid,l,r);
if(mid<r)
res+=query(u<<1|1,mid+1,ed,l,r);
return res;
}
void O_o()
{
int k;
cin>>n>>k;
map<int,vector<int>> mp;
vector<int> a(n+1);
for(int i=1; i<=n; i++)
{
cin>>a[i];
mp[a[i]].push_back(i);
}
tr.assign((n<<2)+1,seg());
vector<array<int,2>> pos(n+1);
vector<int> p,nxt(n+1);
p.push_back(-1);
for(auto [v,t]:mp)
{
p.push_back(v);
int m=t.size();
for(int i=0; i<m; i++)
{
int l,r;
if(i+k-1>=m)
{
l=n+1;
}
else
l=t[i+k-1];
if(i+k>=m)
{
r=n+1;
}
else
r=t[i+k];
pos[t[i]]={l,r};
if(i==m-1)
nxt[t[i]]=n+1;
else nxt[t[i]]=t[i+1];
}
}
int ans=0;
for(int i=n; i>=1; i--)
{
if(nxt[i]!=n+1)
{
auto [l,r]=pos[nxt[i]];
add(1,1,n,nxt[i],l-1,-1);
add(1,1,n,r,n,-1);
}
auto [l,r]=pos[i];
add(1,1,n,i,l-1,1);
add(1,1,n,r,n,1);
int t=query(1,1,n,i,n);
ans+=(n-i+1)-t;
}
cout<<ans<<"\n";
}
signed main()
{
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
cout<<fixed<<setprecision(12);
int T=1;
cin>>T;
while(T--)
{
O_o();
}
}