为了方便才用lct,没想到最后要加读入优化才能过...
有一个结论就是在一条链上,如果能找到一个点使得这个点划分链左右两边的树节点权值和最相近,那么这个点就是答案
用lct维护,每个splay节点存树节点权值$v_x$,树边权值$w_x$,splay中最左节点权值$lv_x$,最右节点权值$rv_x$,树节点权值和$sv_x$,树边权值和$sw_x$,这棵子树向左贡献的答案$pl_x$,这棵子树向右贡献的答案$pr_x$
对于每个询问,先把对应的链提取出来,然后在这棵splay上二分找到分两边点权和最平均的点(二分过程用$sv$和$lv$判断),找到点之后就可以直接输出答案了
修改直接修改
#include<stdio.h>
#define NUM(x) (48<=x&&x<=57)
char c[21000000]={0};
int ns=0;
inline int rd(){
while(!NUM(c[ns]))ns++;
int q=0;
while(NUM(c[ns]))q=(q<<3)+(q<<1)+c[ns++]-48;
return q;
}
#define ll long long
int fa[320010],ch[320010][2],r[320010];
ll v[320010],lv[320010],rv[320010],w[320010],sv[320010],sw[320010],pl[320010],pr[320010];
#define ls ch[x][0]
#define rs ch[x][1]
void pushup(int x){
lv[x]=ls?lv[ls]:v[x];
rv[x]=rs?rv[rs]:v[x];
sv[x]=sv[ls]+sv[rs]+v[x];
sw[x]=sw[ls]+sw[rs]+w[x];
pl[x]=pl[ls]+pl[rs]+v[x]*sw[ls]+(sw[ls]+w[x])*sv[rs];
pr[x]=pr[ls]+pr[rs]+v[x]*sw[rs]+(sw[rs]+w[x])*sv[ls];
}
templatevoid swap(C&a,C&b){
C c=a;
a=b;
b=c;
}
void rev(int x){
r[x]^=1;
swap(lv[x],rv[x]);
swap(pl[x],pr[x]);
swap(ls,rs);
}
void pushdown(int x){
if(r[x]){
if(ls)rev(ls);
if(rs)rev(rs);
r[x]=0;
}
}
void rot(int x){
int y,z,f,b;
y=fa[x];
z=fa[y];
f=ch[y][0]==x;
b=ch[x][f];
fa[x]=z;
fa[y]=x;
if(b)fa[b]=y;
ch[x][f]=y;
ch[y][f^1]=b;
if(ch[z][0]==y)ch[z][0]=x;
if(ch[z][1]==y)ch[z][1]=x;
pushup(y);
pushup(x);
}
bool isrt(int x){return ch[fa[x]][0]!=x&&ch[fa[x]][1]!=x;}
void gao(int x){
if(!isrt(x))gao(fa[x]);
pushdown(x);
}
void splay(int x){
int y,z;
gao(x);
while(!isrt(x)){
y=fa[x];
z=fa[y];
if(!isrt(y))rot((ch[z][0]==y&&ch[y][0]==x)||(ch[z][1]==y&&ch[y][1]==x)?y:x);
rot(x);
}
}
void access(int x){
int y=0;
while(x){
splay(x);
rs=y;
pushup(x);
y=x;
x=fa[x];
}
}
void makert(int x){
access(x);
splay(x);
rev(x);
}
void link(int x,int y){
makert(x);
fa[x]=y;
}
int find(int x,ll d){
pushdown(x);
if(sv[ls]+v[x]>d)return find(ls,d);
d-=sv[ls]+v[x];
if(rs&&lv[rs]<=d)return find(rs,d);
return x;
}
ll query(int x,int y){
makert(x);
access(y);
splay(x);
if(lv[x]<=sv[x]>>1){
x=find(x,sv[x]>>1);
splay(x);
x=rs;
}
pushdown(x);
while(ls){
x=ls;
pushdown(x);
}
splay(x);
return pr[ls]+pl[rs];
}
int main(){
int len=fread(c,1,21000000,stdin);
c[len]=0;
int n,q,i,x,y;
n=rd();
for(i=1;i<=n;i++){
v[i]=rd();
pushup(i);
}
for(i=1;i<n;i++){
x=rd();
y=rd();
w[n+i]=rd();
pushup(n+i);
link(x,n+i);
link(n+i,y);
}
q=rd();
while(q--){
i=rd();
x=rd();
y=rd();
if(i==1)
printf("%lld\n",query(x,y));
else{
splay(x);
v[x]=y;
pushup(x);
}
}
}