不知道这题能不能发出来,如果不能请联系我,我什么都会做的
题意:给一棵 nnn 个结点的树,每个结点有个 ax+bax+bax+b,求所有根到叶子的乘积之和。系数模 998244353998244353998244353。
链的情况就是分治 NTT,所以树上没有弱于这个的做法。
考虑链分治,先对树做长链剖分,然后对根所在的链分治,维护两个多项式,一个链上所有结点的乘积,一个从区间起点往下走,从区间中某个位置拐出去,走到所有叶子的路径乘积之和。递归到分治树的叶子的时候就递归算原树上的轻儿子。
为了保证复杂度,NTT 的长度应该开当前区间所有虚儿子的最大深度和区间长度的较大值,而非区间起点的深度。这样每条链只会在链头的父亲所在的链 分治的时候贡献 O(logn)\Omicron(\log n)O(logn) 次 NTT 的长度,总复杂度是 O(nlog2n)\Omicron(n\log^2n)O(nlog2n),并且上界很松。
第一次写封装多项式,挺舒服的
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#define MAXN ((1<<18)+5)
using namespace std;
inline int read()
{
int ans=0;
char c=getchar();
while (!isdigit(c)) c=getchar();
while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
return ans;
}
const int MOD=998244353;
typedef long long ll;
inline int add(const int& x,const int& y){return x+y>=MOD? x+y-MOD:x+y;}
inline int dec(const int& x,const int& y){return x<y? x-y+MOD:x-y;}
inline int qpow(int a,int p)
{
int ans=1;
while (p)
{
if (p&1) ans=(ll)ans*a%MOD;
a=(ll)a*a%MOD,p>>=1;
}
return ans;
}
#define inv(x) qpow(x,MOD-2)
int rt[2][24];
int r[MAXN],l,lim;
inline void init(){lim=1<<l;for (int i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));}
void ntt(int* a,int type)
{
for (int i=0;i<lim;i++) if (i<r[i]) swap(a[i],a[r[i]]);
for (int L=0;L<l;L++)
{
int mid=1<<L,len=mid<<1;
int Wn=rt[type][L+1];
for (int s=0;s<lim;s+=len)
{
ll w=1;
for (int k=0;k<mid;k++,w=w*Wn%MOD)
{
int x=a[s+k],y=w*a[s+mid+k]%MOD;
a[s+k]=add(x,y),a[s+mid+k]=dec(x,y);
}
}
}
if (type)
{
int t=inv(lim);
for (int i=0;i<lim;i++) a[i]=(ll)a[i]*t%MOD;
}
}
struct poly
{
int *a,n;
inline poly():n(0){}
inline poly(int x):n(x){a=new int[x];memset(a,0,sizeof(int)*n);}
inline poly(int x,int y):n(2){a=new int[2];a[0]=x,a[1]=y;}
inline int& operator [](const int& i){return a[i];}
inline const int& operator [](const int& i)const{return a[i];}
};
inline poly operator *(const poly& a,const poly& b)
{
static int ta[MAXN],tb[MAXN];
poly c(a.n+b.n-1);
for (l=0;(1<<l)<c.n;++l);
init();
for (int i=0;i<lim;i++) ta[i]=tb[i]=0;
for (int i=0;i<a.n;i++) ta[i]=a[i];
for (int i=0;i<b.n;i++) tb[i]=b[i];
ntt(ta,0),ntt(tb,0);
for (int i=0;i<lim;i++) ta[i]=(ll)ta[i]*tb[i]%MOD;
ntt(ta,1);
for (int i=0;i<c.n;i++) c[i]=ta[i];
return c;
}
inline poly operator +(const poly& a,const poly& b)
{
poly c(max(a.n,b.n));
for (int i=0;i<c.n;i++) c[i]=add(i<a.n? a[i]:0,i<b.n? b[i]:0);
return c;
}
vector<int> e[MAXN];
int buf[MAXN],*tp=buf;
int fa[MAXN],son[MAXN],mx[MAXN];
int *lis[MAXN];
inline int* newbuf(int x){int* p=tp;tp+=x;return p;}
void dfs(int u,int f)
{
fa[u]=f;
for (int i=0;i<(int)e[u].size();i++)
if (e[u][i]!=f)
{
dfs(e[u][i],u);
if (mx[e[u][i]]>mx[son[u]]) son[u]=e[u][i];
}
mx[u]=mx[son[u]]+1;
}
void dfs(int u,int* cur)
{
*(lis[u]=cur)=u;
if (son[u]) dfs(son[u],cur+1);
for (int i=0;i<(int)e[u].size();i++)
if (e[u][i]!=fa[u]&&e[u][i]!=son[u])
dfs(e[u][i],newbuf(mx[e[u][i]]));
}
int rval[MAXN],gval[MAXN];
pair<poly,poly> solve(int* L,int* R)
{
if (L==R)
{
int u=*L;
poly tmp;
for (int i=0;i<(int)e[u].size();i++)
if (e[u][i]!=fa[u]&&e[u][i]!=son[u])
tmp=tmp+solve(lis[e[u][i]],lis[e[u][i]]+mx[e[u][i]]-1).second;
if ((int)e[u].size()==(fa[u]>0)) tmp=poly(1),tmp[0]=1;
return make_pair(poly(rval[u],gval[u]),poly(rval[u],gval[u])*tmp);
}
int* mid=L+((R-L)>>1);
pair<poly,poly> lans=solve(L,mid),rans=solve(mid+1,R);
return make_pair(lans.first*rans.first,lans.first*rans.second+lans.second);
}
poly ans;
int main()
{
freopen("slime.in","r",stdin);
freopen("slime.out","w",stdout);
rt[0][23]=qpow(3,119),rt[1][23]=inv(rt[0][23]);
for (int i=22;i>=0;i--)
{
rt[0][i]=(ll)rt[0][i+1]*rt[0][i+1]%MOD;
rt[1][i]=(ll)rt[1][i+1]*rt[1][i+1]%MOD;
}
int n=read();read();
for (int i=1;i<=n;i++) rval[i]=read();
for (int i=1;i<=n;i++) gval[i]=read();
for (int i=1;i<n;i++)
{
int u,v;
u=read(),v=read();
e[u].push_back(v),e[v].push_back(u);
}
dfs(1,0);
dfs(1,newbuf(mx[1]));
ans=solve(lis[1],lis[1]+mx[1]-1).second;
for (int i=0;i<=n;i++) printf("%d\n",(i<ans.n? ans[i]:0));
return 0;
}
1446

被折叠的 条评论
为什么被折叠?



