有一棵n(1<=n<=1e18)个点的树,
点i连着2*i和2*i+1两个点,构成一棵完全二叉树
对于每个点i,记其值为a[i],a[i]可以取[1,m](1<=m<=1e5)的整数
记i到j的简单路径上的最大值为s[i][j],
则一棵权值确定的树对答案的贡献是
现在求所有可能情况下的树的贡献之和,答案对998244353取模
实际t<=200组样例,但保证summ不超过1e5
羊村群小羊
大致的思路就是把每个长度的路径都统计算出来,然后再算贡献
而n个点的树总是可以拆成左子树和右子树继续递归下去的,有子结构的概念
所以可以按子树大小做记忆化,每棵子树暴力维护所有长度的路径进行合并
由于路径长度最长2*logn,这里固定开了128长度的vector,只对这些做合并
dp[i][2]表示当前节点u的子树长度为i的路径的条数
其中dp[i][0]表示两端都位于子树内部的路径,dp[i][1]表示有一端位于根节点的路径
求出路径方案数后求贡献,最大值为i的方案数,首先特判i=1,
然后稍作容斥,方案数等于m个值从[1,i]任取减去m个值从[1,i-1]任取
长为i的路径的方案数*剩下n-i个点任取的方案数*最大值为j的方案数*最大值j,
就是当路径长度为i,而最大值为j时,(i,j)对答案的贡献,统计所有贡献累加即可
int k = std::__lg(n + 1);
ll ls=((1LL << (k - 1)) - 1) + std::min(1LL << (k - 1), n - (1LL << k) + 1);
ll rs=n-1-ls;
求左子树大小这里,抄了一下jiangly的代码,但后来想了想也很好理解
对于倒数第二层往上,是左右子树平分的
而对于最后一层,左子树能拿到的大小,为min(剩下的点数,最后一层的一半)
- #include
- using namespace std;
- #define rep(i,a,b) for(int i=(a);i<=(b);++i)
- #define per(i,a,b) for(int i=(a);i>=(b);--i)
- typedef long long ll;
- typedef double db;
- typedef array<int,2> P;
- #define fi first
- #define se second
- #define pb push_back
- #define dbg(x) cerr<<(#x)<<":"<
" " ; - #define dbg2(x) cerr<<(#x)<<":"<
- #define SZ(a) (int)(a.size())
- #define sci(a) scanf("%d",&(a))
- #define pt(a) printf("%d",a);
- #define pte(a) printf("%d\n",a)
- #define ptlle(a) printf("%lld\n",a)
- #define debug(...) fprintf(stderr, __VA_ARGS__)
- typedef unsigned ui;
- //typedef __uint128_t L;
- typedef unsigned long long L;
- typedef unsigned long long ull;
- const int N=1e5+10,M=128,mod=998244353;
- int t,m,pw[N][M];
- ll n;
- map
>mp;//dp[i][2]表示是否开口的方案数 - void add(int &x,int y){
- x=(x+y)%mod;
- }
- vector
dfs(ll n){
- if(n==0)return vector
(1,{0,0});
- if(n==1)return vector
(1,{0,1});
- if(mp.count(n))return mp[n];
- int k = std::__lg(n + 1);
- ll ls=((1LL << (k - 1)) - 1) + std::min(1LL << (k - 1), n - (1LL << k) + 1);
- ll rs=n-1-ls;
- vector
l=dfs(ls),r=dfs(rs);
- int sl=SZ(l),sr=SZ(r);
- //printf("n:%lld lsz:%d rsz:%d\n",n,sl,sr);
- vector
dp(128,{0,0});
- rep(i,0,sl-1){
- rep(j,0,sr-1){
- if(!l[i][1] || !r[j][1])continue;
- add(dp[i+j+2][0],1ll*l[i][1]*r[j][1]%mod);
- }
- }
- rep(i,0,sl-1){
- add(dp[i][0],l[i][0]);
- add(dp[i][0],l[i][1]);
- add(dp[i+1][1],l[i][1]);
- }
- rep(i,0,sr-1){
- add(dp[i][0],r[i][0]);
- add(dp[i][0],r[i][1]);
- add(dp[i+1][1],r[i][1]);
- }
- add(dp[0][1],1);
- return mp[n]=dp;
- }
- int modpow(int x,ll n,int mod){
- if(!n)return 1;
- int res=1;
- for(;n;n>>=1,x=1ll*x*x%mod){
- if(n&1)res=1ll*res*x%mod;
- }
- return res;
- }
- int cal(int sz,int v){
- if(v==1)return 1;
- return (pw[v][sz]-pw[v-1][sz]+mod)%mod;
- }
- int sol(){
- vector
dp=dfs(n);
- int sz=SZ(dp),res=0;
- rep(j,0,sz-1){
- int cnt=(dp[j][0]+dp[j][1])%mod,len=j+1;
- if(len>n)break;
- int oth=modpow(m,n-len,mod)%mod;
- rep(i,1,m){
- add(res,1ll*cnt*cal(len,i)%mod*i%mod*oth%mod);
- }
- }
- return res;
- }
- int main(){
- rep(i,1,N-1){
- pw[i][0]=1;
- rep(j,1,M-1){
- pw[i][j]=1ll*pw[i][j-1]*i%mod;
- }
- }
- sci(t);
- while(t--){
- scanf("%lld%d",&n,&m);
- printf("%d\n",sol());
- }
- return 0;
- }