persistent segment tree (永続segtree)

永続データ構造っていうのは"クエリとかが来て状態が変更された後も,変更される前の構造にアクセスできる"みたいな感じのやつ.
永続segtreeの例としては,まず普通のsegtreeとしてRMQがあるわけだけど,そのクエリとして,
1:"クエリxの直前の状態での"[l,r)のmin
2:a[i]をxに
みたいなのが来ると永続の出番になる.(この場合クエリ先読みできれば永続いらないけど)
勿論全ての状態のsegtreeを持っておくと(クエリQ個,範囲Nに対して)QNlogNとかになって論外(ここではQ,N<150000くらいを想定している).これをうまくやるアイデアは単純で,クエリ2でsegtreeの一部が変更されるわけだが,変更されるのは高々O(logn)個だということ.
つまり,これをf:id:sigma425:20141230044150j:plain
こうじゃf:id:sigma425:20141230044149j:plain
こうすることで全ノードの個数はO(NlogN+QlogN)とかになり大勝利.
またその区間に関する計算(のやりなおし)もO(logN)でできる(下から再帰的にやれば良い)
どこかを変更すれば一番上のrootは必ず新しい別のやつになってるので,あるバージョンにアクセスしたいときはそのrootを覚えておいて使えば良い.(例:root[0]=0,root[1]=15)
idは適当(ノード作った順番とか)にするので,id*2+1,id*2+2みたいなきれいな形では持てない.lchとrch(l,r-child)を番号で持てばよい.(例:lch[1]=3,rch[1]=4,lch[16]=3,rch[16]=17)

簡単じゃん,となるが,実装(やその実装によって起こる細かい注意点)がちょっとめんどい.
nodeを番号で持っておくと,val[lch[x]]とかが多発してわかりにくかったりするため(またカプセル化のため),

まず struct nodeを定義し,そのポインタを返す関数として色んなものが実装されていることが多い.
- node
nodeの中には,
node *lch,*rch
int mn
のように,左右の子へのポインタとsegtreeに必要な情報が入っている.(こうすれば前述のval[lch[x]]はx->lch->mnになって直感的にわかりやすい?)
- news
またポインタにはするけれど,毎回newをするよりも連続した配列をとっておくほうが良いので,node pool[2000000]とかをとっておき,インクリメントさせながら使っていく(これは配列の時と同じだが,違いは,実装内でpoolの添字が明示的に現れることはないこと).これがnews関数(一般的な名称ではない.my_newが普通?).指定されたノードを新しく作り,そのポインタを返す関数.

自分の実装では3種類書くことになる
まず node pool[]の宣言時は何もしなくても良いので,空のコンストラクタを呼ぶことにする.
newsの宣言時に必要なのはnodeを作る時に必要な情報で,node*型のlch,rch,int型のmnを渡せば新しいノードを返すようなコンストラクタを書けば良い.(情報をすべて渡す)
もう一つ,子供の情報がわかっていてnode *lch,*rchを渡すだけでそのノードを新しく作成する(下から再帰していく時に使う)奴も必要.それに対応するnewsも作っておく.
- nil
葉のやつのダミーの子として使うnode*型の値.名前は赤黒木のいつもの図から取ってきた.nilのかわりにNULLをつかってもいいのだが,そうすると場合分けが発生してしまう,というのも,nilの情報にアクセスすることがあるから.そのときNULLをつかっているとsegmentation faultになり場合分けが必要になる.segtreeの情報としてmnくらいしか持ってないなら別に楽だからいいんだけど,いっぱい情報を持っていてその情報を得たくなる毎に場合分けをするのはさすがに面倒なので,nilにその場合わけの情報をまとめる(例えば木のサイズは0,mnはinfを持っている).また,nilの子供にアクセスすることもある(再帰時)ので,nilを適当に宣言した後,nil->lchとnil->rchもnilに変えなければいけない(アドレスの恩恵).
- fix
fix(node *x,int pos,int l,int r)で,"xのところ([l,r))以下で必要なところには新しいノードを作る,そして自分の(新しい)ノードのポインタを返す"という関数.さっきも言ったように,r-l==1なら新しく作りたいノードのポインタをかえせばよい.それ以外なら,左右に分けて,そのうち変更がないものは元のまま(つまりxの子(例えばx->lch)をlchとして持つ),新しいのは再帰で作ったやつをもつ,ような新しいノードのポインタを返せば良い.前述のようにnewsを作っておけばここできれいなコードになる.

  • query

やるだけ,いつもどおり.segmentのidのかわりにポインタを持ち,左右に割り振る.ポインタがnilの時はちゃんと別に処理する.

もう一つ注意点があって,nilだからといってr-l==1とは限らないということ.big range query(N<=10^9くらいで必要なとこだけ作るsegtree)を知っている人ならこれは理解できると思う.例えば初期状態ではroot[0]をnilに設定する(まだ何もしてないので).

つまりさっきの図の例としてよりよいものはこんな感じになる
これを
f:id:sigma425:20141230165303j:plain
こうじゃ
f:id:sigma425:20141230165259j:plain
(nilはpool[0]に対応していて,下から再帰的にnodeを作るのでこんな感じになる(実装には関係ないけど)

また,書いてないけど区間更新が来る奴も同様にできます.(普通のsegtreeで出来るようなのじゃないとダメだけど)

永続segtreeの例題を2つあげてコードも置いておきます.

NPCA Judge
[l,r)の個数をカウント,a[i]=x+永続

Problem - E - Codeforces
考察すると適当なsegtreeが必要なことがわかる,それの永続化が必要.

1問目:

//https://judge.npca.jp/problems/view/97
#include <iostream>
#include <cstdio>
#define rep(i,n) for(int i=0;i<n;i++)
using namespace std;
const int MX=200000*19;
struct node{
	node *lch,*rch;
	int num;
	node(node *l,node *r,int n){
		lch=l,rch=r,num=n;
	}
	node(node *l,node *r){
		lch=l,rch=r,num=lch->num+rch->num;
	}
	node(){}
};
node pool[MX];
node *root[200001];
node *nil;
int sit;
node *news(node *lch,node *rch,int num){
	return &(pool[sit++]=node(lch,rch,num));
}
node *news(node *lch,node *rch){
	return &(pool[sit++]=node(lch,rch));
}
node *fix(node *x,int pos,int l,int r){
//	printf("pos,l,r=(%d,%d,%d)\n",pos,l,r);
	if(r-l==1){
		return news(nil,nil,x->num+1);
	}
	int m=(l+r)/2;
	if(pos<m){
		return news(fix(x->lch,pos,l,m),x->rch);
	}else{
		return news(x->lch,fix(x->rch,pos,m,r));
	}
}
int getnum(int a,int b,int l,int r,node *x){
	if(x==nil) return 0;
	if(b<=l||r<=a) return 0;
	if(a<=l&&r<=b) return x->num;
	int m=(l+r)/2;
	return getnum(a,b,l,m,x->lch)+getnum(a,b,m,r,x->rch);
}
void showtree(node *x){
	if(x==nil) return;
	cout<<"(";
	showtree(x->lch);
	cout<<x->num;
	showtree(x->rch);
	cout<<")";
}
int N,l[200000],r[200000];
int main(){
	nil=news(0,0,0);
	nil->lch=nil->rch=nil;
	cin>>N;
	rep(i,N) cin>>l[i];
	rep(i,N) cin>>r[i];
	int x=0;
	root[0]=nil;
	rep(i,N){
//		showtree(root[i]);puts("");
		root[i+1]=fix(root[i],x,0,N);
		int cnt=getnum(l[i],r[i],0,N,root[x+1]);
		x=((long long)cnt*l[i]+r[i])%(i+2);
	}
	cout<<x<<endl;
}

2問目:

//http://codeforces.com/contest/484/problem/E
#include <iostream>
#include <cstdio>
#include <vector>
#include <set>
#include <map>
#include <queue>
#include <deque>
#include <stack>
#include <algorithm>
#include <cstring>
#include <functional>
#include <cmath>
using namespace std;
#define rep(i,n) for(int i=0;i<(n);++i)
#define rep1(i,n) for(int i=1;i<=(n);++i)
#define all(c) (c).begin(),(c).end()
#define fs first
#define sc second
#define pb push_back
#define show(x) cout << #x << " " << x << endl
#define chmax(x,y) x=max(x,y)
const int MX=2000000;
struct node{
	node *lch,*rch;
	int l,r,mx;
	bool all;
	node(node *lch_,node *rch_,int l_,int r_,int mx_,bool all_){
		lch=lch_,rch=rch_,l=l_,r=r_,mx=mx_,all=all_;
	}
	node(node *lch_,node *rch_){
		lch=lch_,rch=rch_;
		all=lch->all & rch->all;
		if(lch->all) l=lch->mx+rch->l;
		else l=lch->l;
		if(rch->all) r=rch->mx+lch->r;
		else r=rch->r;
		mx=max(lch->mx,rch->mx);
		chmax(mx,lch->r+rch->l);
	}
	node(){}
};
node pool[MX];
node *root[100001];
node *nil;
int sit;
node *news(node *lch,node *rch,int l,int r,int mx,bool all){
	return &(pool[sit++]=node(lch,rch,l,r,mx,all));
}
node *news(node *lch,node *rch){
	return &(pool[sit++]=node(lch,rch));
}
node *fix(node *x,int pos,int l,int r){		//0->1
//	printf("pos,l,r=(%d,%d,%d)\n",pos,l,r);
	if(r-l==1){
		return news(nil,nil,1,1,1,1);
	}
	int m=(l+r)/2;
	if(pos<m){
		return news(fix(x->lch,pos,l,m),x->rch);
	}else{
		return news(x->lch,fix(x->rch,pos,m,r));
	}
}
typedef pair<int,int> P;
int mx,cont;
void query(int a,int b,int l,int r,node *x){
	if(x==nil){
		cont=0;
		return;
	}
	if(b<=l||r<=a) return;
	if(a<=l&&r<=b){
		chmax(mx,x->mx);
		chmax(mx,cont+x->l);
		if(x->all) cont+=r-l;
		else cont=x->r;
		return;
	}
	int m=(l+r)/2;
	query(a,b,l,m,x->lch);
	query(a,b,m,r,x->rch);
}
void showtree(node *x){
	if(x==nil) return;
	cout<<"(";
	showtree(x->lch);
	printf("<%d,%d,%d>",x->l,x->mx,x->r);
	showtree(x->rch);
	cout<<")";
}
int n,q,x[100000];
vector<int> ash;
vector<int> val[100000];
int main(){
	nil=news(0,0,0,0,0,0);
	nil->lch=nil->rch=nil;
	root[0]=nil;
	cin>>n;
	rep(i,n) cin>>x[i];
	rep(i,n) ash.pb(x[i]);
	sort(all(ash));
	ash.erase(unique(all(ash)),ash.end());
	int m=ash.size();
	rep(i,n) x[i]=lower_bound(all(ash),x[i])-ash.begin();
	rep(i,n) val[x[i]].pb(i);
	rep(i,m){
		root[i+1]=root[i];
		for(int v:val[m-1-i]) root[i+1]=fix(root[i+1],v,0,n);
	}
/*	rep(i,m+1){
		printf("ver i:%d\n",i);
		showtree(root[i]);puts("");
	}*/
	cin>>q;
	rep(i,q){
		int l,r,w;
		cin>>l>>r>>w;
		l--;
		int ub=m,lb=0;
		while(ub-lb>1){
			mx=cont=0;
			int mid=(ub+lb)/2;
			query(l,r,0,n,root[mid]);
//			printf("ver,mx=(%d,%d)\n",mid,mx);
			if(w<=mx) ub=mid;
			else lb=mid;
		}
		cout<<ash[m-ub]<<endl;
	}
}