YBTOJ 树状数组 二进制
qwq
这道题自己没想出来,是看全网首A的神仙的博客学会的,写这篇文章主要是思路梳理吧。
考虑 x x x a n d and and 2 k 2^k 2k 在二进制第 k + 1 k+1 k+1 位的值:当 x ∈ [ 0 , 2 k − 1 ] x\in[0,2^k-1] x∈[0,2k−1] 时为 0 0 0, x ∈ [ 2 k , 2 k + 1 − 1 ] x\in[2^k,2^{k+1}-1] x∈[2k,2k+1−1] 时为 1 1 1, x ∈ [ 2 k + 1 , 2 k + 1 + 2 k − 1 ] x\in[2^{k+1},2^{k+1}+2^k-1] x∈[2k+1,2k+1+2k−1] 时又为 0 0 0……以此类推,得出一般性结论:当且仅当 x m o d 2 k + 1 ∈ [ 2 k , 2 k + 1 − 1 ] x\mod 2^{k+1}\in[2^k,2^{k+1}-1] xmod2k+1∈[2k,2k+1−1]时 x x x a n d and and 2 k 2^k 2k 在二进制第 k + 1 k+1 k+1 位的值为 1 1 1。
那么想要求 ( a [ i ] + x ) (a[i]+x) (a[i]+x) a n d and and y y y 的值,就可以对 y y y 进行二进制拆分, y y y 的第 k + 1 k+1 k+1 位为 1 1 1 时满足 a [ i ] ∈ [ 2 k − x , 2 k + 1 − x − 1 ] a[i]\in[2^k-x,2^{k+1}-x-1] a[i]∈[2k−x,2k+1−x−1] 的 i i i 对答案的贡献为 2 k 2^k 2k。
于是考虑开 20 20 20 个树状数组,编号为 i i i 的树状数组的第 j j j 位表示有多少 x x x 满足 a [ x ] m o d 2 i + 1 a[x]\mod 2^{i+1} a[x]mod2i+1 的值为 j − 1 j-1 j−1,树状数组支持的操作是求前缀和,那么第 k k k 位对答案的贡献变可以转化为 ( q u e r y ( 2 k + 1 − x − 1 ) − q u e r y ( 2 k − x − 1 ) ) × 2 k (query(2^{k+1}-x-1)-query(2^k-x-1))\times 2^k (query(2k+1−x−1)−query(2k−x−1))×2k,再利用循环节的性质处理一下负数的情况即可。
好巧妙的思路qwq
#include<bits/stdc++.h>
#define ll long long
#define ff(i,s,e) for(int i=(s);i<=(e);++i)
using namespace std;
inline int read(){
int x=0,f=1;
char ch=getchar();
while(ch>'9'||ch<'0'){if(ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
const int N=1e5+5,M=(1<<20)+5;
int n,q,a[N];
int len[22],t[22][M];
inline int lowbit(int x){return x&(-x);}
inline void upd(int pos,int x,int val){
++x;
for(int i=x;i<=len[pos];i+=lowbit(i)) t[pos][i]+=val;
}
inline int query(int pos,int x){
int res=0;++x;
for(int i=x;i;i-=lowbit(i)) res+=t[pos][i];
return res;
}
signed main(){
n=read(),q=read();
ff(i,0,19) len[i]=(1<<i+1);
ff(i,1,n){
a[i]=read();
ff(j,0,19) upd(j,a[i]%(1<<j+1),1);
}
int op,x,y;
while(q--){
op=read(),x=read(),y=read();
if(op==1){
ff(j,0,19) upd(j,a[x]%(1<<j+1),-1),upd(j,y%(1<<j+1),1);
a[x]=y;
}
else{
ll ans=0;
ff(j,0,19){
if((y&(1<<j))==0) continue;
int l=(1<<j),r=(1<<j+1)-1;
l=(l-1-x+(1<<20))%(1<<j+1);
r=(r-x+(1<<20))%(1<<j+1);
// cout<<l<<' '<<r<<endl;
if(l<=r) ans+=(1ll*(query(j,r)-query(j,l))<<j);
else ans+=(1ll*(query(j,len[j]-1)+query(j,r)-query(j,l))<<j);
}
printf("%lld\n",ans);
}
}
return 0;
}