传送门
经典的点分治：以重心为根，用哈希判断每个点到重心的路径串是否为（循环）模式串的前缀或后缀，再按重心在模式串中的位置把前缀路径和后缀路径配对计数。
// Centroid decomposition + 64-bit rolling hash.
//
// For each test case, counts ordered pairs of distinct vertices (x, y) such
// that the vertex labels read along the tree path x -> y spell the pattern S
// repeated k >= 1 times. String equality is decided by comparing 64-bit
// polynomial hashes, so the count is exact only up to hash collisions.
//
// Input (stdin):   number of tests; per test: n len, a label token whose u-th
//                  character labels vertex u, n-1 edges "a b" (1-based), and
//                  the pattern token S (its first len characters are used).
// Output (stdout): one count per test case, one per line.
#include <algorithm>
#include <cctype>
#include <climits>
#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

using ll = long long;
using ull = unsigned long long;

namespace {

// Base of the polynomial rolling hash; arithmetic wraps modulo 2^64.
const ull kBase = 19260817;

// Reads a decimal integer from stdin after skipping non-digit characters; a
// '-' seen while skipping makes the result negative. Returns 0 at EOF instead
// of spinning forever (the previous reader never tested for EOF).
int gi() {
    int c = std::getchar();
    bool negative = false;
    while (c != EOF && !std::isdigit(c)) {
        if (c == '-') negative = true;
        c = std::getchar();
    }
    int x = 0;
    while (c != EOF && std::isdigit(c)) {
        x = x * 10 + (c - '0');
        c = std::getchar();
    }
    return negative ? -x : x;
}

// Reads the next whitespace-delimited token from stdin ("" at EOF). Unlike
// scanf("%s") into a fixed buffer, this cannot overflow.
std::string read_token() {
    int c = std::getchar();
    while (c != EOF && std::isspace(c)) c = std::getchar();
    std::string token;
    while (c != EOF && !std::isspace(c)) {
        token.push_back(static_cast<char>(c));
        c = std::getchar();
    }
    return token;
}

// Number of slots needed for a 1-based array using indices 1..k (min. 1).
std::size_t slots(int k) {
    return k > 0 ? static_cast<std::size_t>(k) + 1 : 1;
}

// Solves one test case. All storage is sized from n and len, so a test costs
// O(n + len) memory and setup instead of clearing fixed 1e6-entry arrays.
// Every traversal is iterative, so deep (path-like) trees cannot overflow the
// call stack.
class PathCounter {
public:
    PathCounter(int n, int len)
        : n_(n), len_(len),
          label_(slots(n), '\0'), pattern_(slots(len), '\0'),
          head_(slots(n), 0), removed_(slots(n), 0),
          order_(slots(n)), parent_(slots(n)), size_(slots(n)),
          heavy_(slots(n)), depth_(slots(n)), hash_(slots(n)),
          stack_(slots(n)), prefix_hash_(slots(n)), suffix_hash_(slots(n)),
          f_(slots(len)), g_(slots(len)), sf_(slots(len)), sg_(slots(len)) {
        to_.reserve(2 * slots(n));
        next_.reserve(2 * slots(n));
        to_.push_back(0);    // edge id 0 terminates every adjacency list
        next_.push_back(0);
    }

    // label_[u] = s[u-1] for u in 1..n (missing characters become '\0').
    void set_labels(const std::string& s) { copy_one_based(label_, n_, s); }

    // pattern_[i] = s[i-1] for i in 1..len (missing characters become '\0').
    void set_pattern(const std::string& s) { copy_one_based(pattern_, len_, s); }

    // Adds the directed edge a -> b (callers add both directions).
    void add_edge(int a, int b) {
        to_.push_back(b);
        next_.push_back(head_[a]);
        head_[a] = static_cast<int>(to_.size()) - 1;
    }

    // Runs the centroid decomposition from vertex 1 and returns the count.
    ll count() {
        if (n_ < 1 || len_ < 1) return 0;  // nothing to match; also avoids % 0
        build_hashes();
        ll total = 0;
        // Each entry is one not-yet-processed component, named by any vertex.
        std::vector<int> pending(1, 1);
        while (!pending.empty()) {
            const int start = pending.back();
            pending.pop_back();
            const int c = find_centroid(start);
            total += count_through(c);
            for (int e = head_[c]; e != 0; e = next_[e]) {
                if (!removed_[to_[e]]) pending.push_back(to_[e]);
            }
        }
        return total;
    }

private:
    static void copy_one_based(std::string& dst, int count,
                               const std::string& src) {
        for (int i = 1; i <= count; ++i) {
            const std::size_t k = static_cast<std::size_t>(i - 1);
            dst[i] = k < src.size() ? src[k] : '\0';
        }
    }

    // prefix_hash_[d]: hash of S[1], S[2], ... (cyclic in S) of length d, with
    // S[1] at power 0. suffix_hash_[d]: hash of S[len], S[len-1], ... (cyclic)
    // of length d, with S[len] at power 0.
    void build_hashes() {
        ull power = 1;  // kBase^(i-1)
        prefix_hash_[0] = suffix_hash_[0] = 0;
        for (int i = 1; i <= n_; ++i) {
            const int r = (i - 1) % len_;
            prefix_hash_[i] = prefix_hash_[i - 1] + power * pattern_[r + 1];
            suffix_hash_[i] = suffix_hash_[i - 1] + power * pattern_[len_ - r];
            power *= kBase;
        }
    }

    // Returns a vertex of the live component containing `start` that minimises
    // the size of the largest piece left after removing it.
    int find_centroid(int start) {
        int cnt = 0;
        order_[cnt++] = start;
        parent_[start] = 0;
        for (int i = 0; i < cnt; ++i) {  // BFS over non-removed vertices
            const int u = order_[i];
            size_[u] = 1;
            heavy_[u] = 0;
            for (int e = head_[u]; e != 0; e = next_[e]) {
                const int w = to_[e];
                if (w != parent_[u] && !removed_[w]) {
                    parent_[w] = u;
                    order_[cnt++] = w;
                }
            }
        }
        int best = start;
        int best_weight = INT_MAX;
        // Reverse BFS order visits children before their parents, so size_ and
        // heavy_ of u are final when u is evaluated.
        for (int i = cnt - 1; i >= 0; --i) {
            const int u = order_[i];
            const int weight = std::max(heavy_[u], cnt - size_[u]);
            if (weight < best_weight) {
                best_weight = weight;
                best = u;
            }
            const int p = parent_[u];
            if (p != 0) {
                size_[p] += size_[u];
                heavy_[p] = std::max(heavy_[p], size_[u]);
            }
        }
        return best;
    }

    // Counts the pairs whose path passes through centroid c, then removes c.
    // Index p of f_/sf_ means "the centroid sits at pattern position p" on an
    // x -> c path that starts with S[1]; index p of g_/sg_ means the centroid
    // sits at position len+1-p on a c -> y path that ends with S[len].
    // f_/g_ hold the current child subtree; sf_/sg_ hold earlier subtrees plus
    // the centroid alone (index 1), so each ordered pair is counted once.
    ll count_through(int c) {
        removed_[c] = 1;
        sf_[1] = sg_[1] = 1;
        ll found = 0;
        int used = 1;
        for (int e = head_[c]; e != 0; e = next_[e]) {
            const int child = to_[e];
            if (removed_[child]) continue;
            // Only indices 1..min(len, deepest depth) can have been touched.
            const int reach = std::min(len_, scan_subtree(child, c, found));
            used = std::max(used, reach);
            for (int k = 1; k <= reach; ++k) {
                sf_[k] += f_[k];
                sg_[k] += g_[k];
                f_[k] = g_[k] = 0;
            }
        }
        std::fill(sf_.begin() + 1, sf_.begin() + used + 1, 0LL);
        std::fill(sg_.begin() + 1, sg_.begin() + used + 1, 0LL);
        return found;
    }

    // Iterative DFS of the child subtree rooted at `root` (below centroid c).
    // A vertex at depth d (c has depth 1) carries the hash of the labels from
    // it up to c, itself at power 0. Matches are tallied into f_/g_ and paired
    // against sg_/sf_ into `found`. Returns the deepest depth reached.
    int scan_subtree(int root, int c, ll& found) {
        parent_[root] = c;
        depth_[root] = 2;
        hash_[root] = static_cast<ull>(label_[c]) * kBase + label_[root];
        int top = 0;
        stack_[top++] = root;
        int deepest = 0;
        while (top > 0) {
            const int u = stack_[--top];
            const int d = depth_[u];
            const ull h = hash_[u];
            deepest = std::max(deepest, d);
            const int r = (d - 1) % len_;
            if (h == prefix_hash_[d]) {
                ++f_[r + 1];
                found += sg_[len_ - r];
            }
            if (h == suffix_hash_[d]) {
                ++g_[r + 1];
                found += sf_[len_ - r];
            }
            for (int e = head_[u]; e != 0; e = next_[e]) {
                const int w = to_[e];
                if (w == parent_[u] || removed_[w]) continue;
                parent_[w] = u;
                depth_[w] = d + 1;
                hash_[w] = h * kBase + label_[w];
                stack_[top++] = w;
            }
        }
        return deepest;
    }

    int n_;
    int len_;
    std::string label_;    // 1-based vertex labels
    std::string pattern_;  // 1-based pattern S
    std::vector<int> head_;
    std::vector<char> removed_;  // vertex already used as a centroid
    std::vector<int> order_, parent_, size_, heavy_, depth_;
    std::vector<ull> hash_;
    std::vector<int> stack_;
    std::vector<ull> prefix_hash_, suffix_hash_;
    std::vector<ll> f_, g_, sf_, sg_;
    std::vector<int> to_, next_;  // linked adjacency lists (edge 0 = end)
};

}  // namespace

int main() {
#ifndef ONLINE_JUDGE
    // Local convenience: read from 1a.in only when it exists. An unconditional
    // freopen closed stdin when the file was missing.
    if (std::FILE* probe = std::fopen("1a.in", "r")) {
        std::fclose(probe);
        if (!std::freopen("1a.in", "r", stdin)) return 1;
    }
#endif
    for (int tests = gi(); tests > 0; --tests) {
        const int n = gi();
        const int len = gi();
        PathCounter counter(n, len);
        counter.set_labels(read_token());
        for (int i = 2; i <= n; ++i) {
            const int a = gi();
            const int b = gi();
            counter.add_edge(a, b);
            counter.add_edge(b, a);
        }
        counter.set_pattern(read_token());
        std::printf("%lld\n", counter.count());
    }
    return 0;
}