「是男人就过8题——Pony.ai」IntervalTree

「是男人就过8题——Pony.ai」IntervalTree

定义区间树为线段树的拓展,即每次断开的位置可以不是线段的中心。

给定一个 $[1, n]$ 的区间树和 $q$ 次询问,每次询问包含一个正整数 $k$, 你需要求出有多少区间的时间复杂度恰好等于 $k$。

$n, q\le 10^5,\ k\le 10^9$。

题解

在线回答询问无意义,考虑利用生成函数处理出所有询问的答案。

询问 $[l;r]$ 选中的线段($ql=l \land qr=r$ 的线段,而非经过的线段),LCA 往两侧深度单调减(且中间平的一段的长度至多为 $2$)。

求出往两侧单调的生成函数合并,类似:

其中 $L,R$ 分别表示左/右端点和当前线段的左/右端点相同的线段(不包括完全相同的情况)的生成函数,$u$ 是当前节点,$l$ 是左儿子,$r$ 是右儿子。

处理一些平凡情况,在断点计算贡献:

就能做到 $\text{polylog} \times \sum \small{\text{线段长度}}$ 复杂度。

进一步优化复杂度,考虑边分治:假设当前处理子树 $u$,边分的子树 $v$。递归处理出 $u \leftrightarrow v$ 的路径上的 $S,L,R$。下面考虑 $L_v,R_v$ 对路径上点的 $S$ 和 $u$ 的 $L,R$ 的贡献。

前者可以分别考虑 $L_v,R_v$ 的贡献,通过两次卷积得到。后者则是路径上的 $L,R$ 通过一定位移得到。

至于处理可以分别在两侧继续边分。复杂度 $O(n \log^2 n)$。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#include<bits/stdc++.h>
template<class T> inline void read(T &x){
x=0; register char c=getchar(); register bool f=0;
while(!isdigit(c))f^=c=='-',c=getchar();
while(isdigit(c))x=x*10+c-'0',c=getchar(); if(f)x=-x;
}
template<class T> inline void print(T x){
if(x<0)putchar('-'),x=-x;
if(x>9)print(x/10); putchar(x%10+'0');
}
template<class T> inline void print(T x,char c){print(x),putchar(c);}
const int N=2e5+10,mod=998244353;
int _,n,m,cnt,siz[N],vis[N],mid[N],ch[N][2],fa[N],l[N],r[N],dep[N],rev[N<<2];
struct z {
int x;
z(int x=0):x(x){}
friend inline z operator*(z a,z b){return (long long)a.x*b.x%mod;}
friend inline z operator-(z a,z b){return (a.x-=b.x)<0?a.x+mod:a.x;}
friend inline z operator+(z a,z b){return (a.x+=b.x)>=mod?a.x-mod:a.x;}
}w[N<<2];
std::vector<z> ans;
inline z fpow(z a,int b){z s=1;for(;b;b>>=1,a=a*a)if(b&1)s=s*a;return s;}
inline void print(const std::vector<z> &a){for(int i=0;i<a.size();i++)printf("%d ",a[i].x); printf("\n");}
void dft(std::vector<z> &a,int lim){
a.resize(lim);
for(int i=0;i<lim;i++)if(i<rev[i])std::swap(a[i],a[rev[i]]);
for(int len=1;len<lim;len<<=1)
for(int i=0;i<lim;i+=(len<<1))
for(int j=0;j<len;j++){
z x=a[i+j],y=a[i+j+len]*w[j+len];
a[i+j]=x+y,a[i+j+len]=x-y;
}
}
std::vector<z> operator+(std::vector<z> a,const std::vector<z> &b){
a.resize(std::max(a.size(),b.size()));
for(int i=0;i<b.size();i++)a[i]=a[i]+b[i]; return a;
}
std::vector<z> operator-(std::vector<z> a,const std::vector<z> &b){
a.resize(std::max(a.size(),b.size()));
for(int i=0;i<b.size();i++)a[i]=a[i]-b[i]; return a;
}
std::vector<z> operator*(std::vector<z> a,std::vector<z> b){
int len=a.size()+b.size()-1,lim=1,k=0;
while(lim<len)lim<<=1,++k;
for(int i=0;i<lim;i++)rev[i]=rev[i>>1]>>1|((i&1)<<(k-1));
dft(a,lim),dft(b,lim);
for(int i=0;i<lim;i++)a[i]=a[i]*b[i];
dft(a,lim),std::reverse(&a[1],&a[lim]);
z inv=fpow(lim,mod-2);
for(int i=0;i<lim;i++)a[i]=a[i]*inv;
return a.resize(len),a;
}
void shift(z x,std::vector<z> &dst,size_t dta){
dst.resize(std::max(dst.size(),dta+1));
dst[dta]=dst[dta]+x;
}
void shift(const std::vector<z> &src,std::vector<z> &dst,size_t dta){
dst.resize(std::max(dst.size(),src.size()+dta));
for(int i=0;i<src.size();i++)dst[i+dta]=dst[i+dta]+src[i];
}
void dfsInit(int &u,int l,int r,int dep){
if(l==r){
u=n-1+l;
}else{
u=++cnt;
dfsInit(ch[u][0],l,mid[u],dep+1),fa[ch[u][0]]=u;
dfsInit(ch[u][1],mid[u]+1,r,dep+1),fa[ch[u][1]]=u;
}
::l[u]=l,::r[u]=r,::dep[u]=dep;
}
int calcSize(int u){
if(!u||vis[u])return 0;
return siz[u]=1+calcSize(ch[u][0])+calcSize(ch[u][1]);
}
std::pair<int,int> findSubTree(int u,int lim){
if(!u||vis[u])return {-1,-1};
if(siz[u]<lim)return {u,siz[u]};
std::pair<int,int> x=ch[u][0]?findSubTree(ch[u][0],lim):std::make_pair(-1,-1);
std::pair<int,int> y=ch[u][1]?findSubTree(ch[u][1],lim):std::make_pair(-1,-1);
return x.second>y.second?x:y;
}
void calc(bool fl,int u,int mov,std::vector<z> &f){
if(vis[u]||l[u]==r[u])return shift(1,f,mov+1);
shift(1,f,mov+1);
shift(mod-1,f,mov+3);
calc(fl,ch[u][0],mov+(fl?2:1),f);
calc(fl,ch[u][1],mov+(fl?1:2),f);
}
std::pair<std::vector<z>,std::vector<z>> fuck(int u){
if(vis[u]||l[u]==r[u])return {{0,1},{0,1}};
std::vector<z> Ll,Lr,Rl,Rr,Lu,Ru;
Lu=Ru=std::vector<z>{0,1,0,mod-1};
std::tie(Ll,Rl)=fuck(ch[u][0]);
std::tie(Lr,Rr)=fuck(ch[u][1]);
shift(Ll,Lu,1),shift(Lr,Lu,2);
shift(Rr,Ru,1),shift(Rl,Ru,2);
shift(Rl*Lr,ans,dep[u]);
return {Lu,Ru};
}
std::pair<std::vector<z>,std::vector<z>> solve(int u){
if(vis[u]||l[u]==r[u])return {{0,1},{0,1}};
int siz=calcSize(u);
int v=findSubTree(u,(siz*2)/4).first;
if(v==-1)return fuck(u);
std::vector<z> Lu,Ru,Lv,Rv,Lt,Rt,T;
std::tie(Lv,Rv)=solve(v);
vis[v]=1;
std::tie(Lu,Ru)=solve(u);
vis[v]=0;
int lmov=0,rmov=0;
for(int p=v;p!=u;lmov+=ch[fa[p]][0]==p?1:2,rmov+=ch[fa[p]][0]==p?2:1,p=fa[p]){
int f=fa[p],q=ch[f][0]==p?ch[f][1]:ch[f][0];
if(ch[f][0]==p){
T.clear(),calc(0,q,0,T),shift(T,Lt,rmov+dep[f]-dep[u]);
}else{
T.clear(),calc(1,q,0,T),shift(T,Rt,lmov+dep[f]-dep[u]);
}
}
shift(mod-1,Lv,1);
shift(mod-1,Rv,1);
shift(Lv*Rt,ans,dep[u]);
shift(Rv*Lt,ans,dep[u]);
shift(Lv,Lu,lmov);
shift(Rv,Ru,rmov);
return {Lu,Ru};
}
void solution(){
for(int i=1;i<n;i++)read(mid[i]);
dfsInit(_,1,n,1);
solve(1);
for(int i=1;i<(n<<1);i++)shift(1,ans,dep[i]);
for(int i=1;i<n;i++)shift(mod-1,ans,dep[i]+2);
for(int i=1,q;i<=m;i++){
read(q);
print(q<ans.size()?ans[q].x:0,'\n');
}
}
void recycle(){
cnt=0;
ans.clear();
memset(ch,0,sizeof(ch));
memset(fa,0,sizeof(fa));
}
int main(){
for(int len=1;len<(N<<1);len<<=1){
z wn=fpow(3,(mod-1)/(len<<1)); w[len]=1;
for(int i=1;i<len;i++)w[i+len]=w[i+len-1]*wn;
}
while(~scanf("%d%d",&n,&m))solution(),recycle();
}