前言
有线段树合并就应该有线段树分裂。它是线段树合并的逆过程。具体的,你需要以权值线段树中第 k 小的数为分界线,把线段树分成两半。
算法流程
和线段树上二分类似。假设原来的线段树为 u,要分裂出线段树 v
- 记左子树的权值为 val。
- 如果 k>val,那么分界线在右子树,那么左子树归 u,递归右子树,此时 k=k-val。
- 如果 k==val,那么分界线正好就是mid,那么左子树归 u,右子树归 v。
- 如果 k<val,那么分界线在左子树,那么右子树归 v,递归左子树
- 计算 u,v 的权值 tr[v].val=tr[u].val-k; tr[u].val=k;
核心代码
int split(int u,int v,int st,int ed,int k)
{if(u==0) return 0;int mid=st+ed>>1;tr.push_back(seg());v=tr.size()-1;int val=tr[tr[u].ls].val;if(k>val)tr[v].rs=split(tr[u].rs,tr[v].rs,mid+1,ed,k-val);elseswap(tr[u].rs,tr[v].rs);if(k<val)tr[v].ls=split(tr[u].ls,tr[v].ls,st,mid,k);tr[v].val=tr[u].val-k;tr[u].val=k;return v;
}
【模板】线段树分裂
题解
操作0
先把线段树分裂成 ( 1 , x − 1 ) , ( x , n ) (1,x-1),(x,n) (1,x−1),(x,n),再把 ( x , n ) (x,n) (x,n) 线段树分裂成 ( x , y ) , ( y + 1 , n ) (x,y),(y+1,n) (x,y),(y+1,n),最后合并线段树 ( 1 , x − 1 ) , ( y + 1 , n ) (1,x-1),(y+1,n) (1,x−1),(y+1,n)。
操作1
线段树合并
操作2
单点加
操作3
区间查
操作4
线段树上二分
代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e6+7,inf=1e18;
int n,m;
struct seg
{int ls,rs,val;seg():ls(0),rs(0),val(0){}seg(int a,int b,int c):ls(a),rs(b),val(c){}
};
vector<seg> tr(2);
void update(int u)
{tr[u].val=tr[tr[u].ls].val+tr[tr[u].rs].val;
}
void insert(int u,int st,int ed,int x,int t)
{if(st==ed){tr[u].val+=t;return;}int mid=st+ed>>1;if(x<=mid){if(!tr[u].ls){tr.push_back(seg());tr[u].ls=tr.size()-1;}insert(tr[u].ls,st,mid,x,t);}else{if(!tr[u].rs){tr.push_back(seg());tr[u].rs=tr.size()-1;}insert(tr[u].rs,mid+1,ed,x,t);}update(u);
}
int query(int u,int st,int ed,int l,int r)
{if(l<=st&&ed<=r){return tr[u].val;}int mid=st+ed>>1,res=0;if(l<=mid){if(tr[u].ls)res+=query(tr[u].ls,st,mid,l,r);}if(mid<r){if(tr[u].rs)res+=query(tr[u].rs,mid+1,ed,l,r);}return res;
}
void merge(int u,int v,int st,int ed)
{if(st==ed){tr[u].val+=tr[v].val;return;}int mid=st+ed>>1;if(tr[u].ls&&tr[v].ls)merge(tr[u].ls,tr[v].ls,st,mid);else if(tr[v].ls)tr[u].ls=tr[v].ls;if(tr[u].rs&&tr[v].rs)merge(tr[u].rs,tr[v].rs,mid+1,ed);else if(tr[v].rs)tr[u].rs=tr[v].rs;update(u);
}
int split(int u,int v,int st,int ed,int k)
{if(u==0) return 0;int mid=st+ed>>1;tr.push_back(seg());v=tr.size()-1;int val=tr[tr[u].ls].val;if(k>val)tr[v].rs=split(tr[u].rs,tr[v].rs,mid+1,ed,k-val);elseswap(tr[u].rs,tr[v].rs);if(k<val)tr[v].ls=split(tr[u].ls,tr[v].ls,st,mid,k);tr[v].val=tr[u].val-k;tr[u].val=k;return v;
}
int find(int u,int st,int ed,int k)
{if(k>tr[u].val||st>ed||k==0)return -1;if(st==ed){return st;}int mid=st+ed>>1;int val=tr[tr[u].ls].val;if(k>val){return find(tr[u].rs,mid+1,ed,k-val);}else{return find(tr[u].ls,st,mid,k);}
}
vector<int> rt(1);
void O_o()
{cin>>n>>m;rt.push_back(1);for(int i=1; i<=n; i++){int x;cin>>x;insert(rt[1],1,n,i,x);}while(m--){int op,id;cin>>op>>id;if(op==0){int x,y;cin>>x>>y;rt.push_back(0);int now=rt.size()-1;int v1=query(rt[id],1,n,1,x-1),v2=query(rt[id],1,n,x,y);rt[now]=split(rt[id],rt[now],1,n,v1);int t=split(rt[now],0,1,n,v2);merge(rt[id],t,1,n);}else if(op==1){int t;cin>>t;merge(rt[id],rt[t],1,n);}else if(op==2){int x,q;cin>>x>>q;insert(rt[id],1,n,q,x);}else if(op==3){int l,r;cin>>l>>r;cout<<query(rt[id],1,n,l,r)<<"\n";}else if(op==4){int k;cin>>k;cout<<find(rt[id],1,n,k)<<"\n";}else assert(0);}
}
signed main()
{ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);cout<<fixed<<setprecision(12);int T=1;
// cin>>T;while(T--){O_o();}
}