整体二分和CDQ分治

VictorWonder posted @ 2014年11月22日 14:47 in Algorithms with tags 整体二分 CDQ分治 , 10644 阅读
我想要学习整体二分和CDQ已经很久了。
虽然Claris表示,这东西考试的时候估计是用不上的,现在很多题目,尤其是高级数据结构题都会强制在线。
但我觉得还是有必要学习一下,坑了很久,终于差不多把基本的用法给掌握了。
(我会告诉你们其实我很多时间都坑在斜率和凸壳上了么……感觉斜率和凸壳各种搞不懂……)
 
整体二分的资料好像不是很多,我在网上找到了一篇不错的资料:http://www.cnblogs.com/zig-zag/archive/2013/04/18/3027707.html
 
来看一道例题:
有n个国家和m个空间站,每个空间站都属于一个国家,一个国家可以有多个空间站,所有空间站按照顺序形成一个环,也就是说,m号空间站和1号空间站相邻。
现在,将会有k场流星雨降临,每一场流星雨都会给区间[li,ri]内的每个空间站带来ai单位的陨石,每个国家都有一个收集陨石的目标pi,即第i个国家需要收集pi单位的陨石。
询问:每个国家最早完成陨石收集目标是在第几场流星雨过后。
数据范围:1<=n,m,k<=300000
 
对于单个查询(假设为第i个国家),我们可以二分k,每次对于一个区间[l,r],手动模拟一下在第mid场流星雨过后,第i个国家一共收集到了多少单位的陨石,如果比pi大,那么答案在[l,mid]范围内,否则答案在[mid+1,r]范围内。
对于多组查询,我们也可以这么做。首先,我们需要用一个列表id[]记录所有查询的编号,刚开始的时候,id[]自然是递增的.同时,我们用一个数组cur[i]记录下,第i个国家在l-1场流星雨过后,收集到的陨石的数目。
主过程为void solve(int head,int tail,int l,int r),表示对于id[head]到id[tail]的所有询问,在[l,r]范围内查询答案,通过上一层的操作,我们保证id[head]到id[tail]的所有询问的答案都在[l,r]范围内。
首先,我们先模拟[l,mid]这么多次操作(在询问重新划分之后,必须要再次模拟,将数组清空),用树状数组或者是线段树计算出在[l,mid]场流星雨之后,每个空间站收集到的陨石的数目。
然后我们查询,每个国家收集到的陨石的数目,要注意的是,我们需要用链表储存每个国家对应的空间站,并且一一枚举,用tmp[id[i]]表示国家id[i]收集到的陨石的数目。
那么从[1,mid]这么多次操作之后,国家id[i]收集到的陨石数目就是tmp[id[i]]+cur[id[i]],如果tmp[id[i]]+cur[id[i]]>p[id[i]],那么表明对于国家id[i],其答案在[l,mid]这个范围内,否则其答案在[mid+1,r]范围内,并将tmp[id[i]]累加到cur[id[i]]上。
还有一个坑点是,tmp[id[i]]可能很大,会爆掉long long,所以如果枚举一个国家的所有空间站的时候,发现tmp[id[i]]已经大于p[id[i]]了,那么就break好了,不然会出错。
因为可能会出现怎么也无法满足的情况,所以我们需要多增加一场流星雨,这场流星雨的数量为infi,保证能够让所有国家都满足要求,那么最后,对于所有答案为k+1的询问,输出NIE就行了。
 
总的来说,整体二分就是将所有询问一起二分,然后获得每个询问的答案。
 
CDQ相比整体二分略有不同,整体二分是对答案进行二分,而CDQ分治则是对于所有操作进行二分。
一开始的时候,一般都需要先排序一下,将所有的操作按照一定的性质排起来(Cash一题是按-a[i]/b[i]从大到小排序,Mokia一题是按照x坐标从小到大排序),否则的话和直接暴力没有区别。
接着就是对所有操作进行二分,对于一个区间[l,r],我们首先要对[l,mid]范围内的操作进行查询,所以需要将[l,r]范围内的操作重排,保证[l,mid]范围内的操作的id都<=mid,然后再对[l,mid]范围内的操作进行递归求解。
递归完之后,[l,mid]范围内的所有操作都已经按照先后的顺序排序好了,但是[mid+1,r]范围内的操作还是按照某个关键字排序。我们根据[l,mid]范围内的操作更新[mid+1,r]范围内的操作(在Cash一题中是更新它们的f[]数组,在Mokia中则要和整体二分差不多,先用[l,mid]范围内的操作模拟,再更新[mid+1,r]范围内所有查询的答案)。
更新完之后,我们再递归求解[mid+1,r]。
因为大神CDQ的论文已经讲得很详细了,网上关于CDQ分治的资料也挺多的,就不再讲了,感觉这些神奇的算法还是挺好用的,尤其是在自己懒的时候(许多高级数据结构题的代码量真是醉了……)。
 
#include <cstdio>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <utility>
#include <bitset>
#include <vector>
#include <string>
#include <stack>
#include <queue>
#include <ctime>
#include <cmath>
#include <list>
#include <map>
#include <set>
using namespace std;
typedef long long ll;
typedef double D;
typedef pair<int,int> pr;
const int infi=1000000010;
const int N=500010;
const int M=2000100;
struct node{int x,y,z;}p[N];
int n,m,k,c[N],id[N],ans[N];
int g[N],to[N],nxt[N],tot;
int tol[N],tor[N];
ll h[N],tmp[N],cur[N];
void read(int &a) {
    char ch; while (!((ch=getchar())>='0'&&ch<='9'));
    a=ch-'0'; while ((ch=getchar())>='0'&&ch<='9') (a*=10)+=ch-'0';
}
void add(int pos,int x) {while (pos<=m) h[pos]+=x,pos+=(pos&-pos);}
void adddt(int x,int y,int z) {add(x,z); add(y+1,-z);}
ll sum(int pos) {ll t=0; while (pos>0) t+=h[pos],pos-=(pos&-pos);return t;}
void addop(int x,int y,int z,int i) {p[i].x=x; p[i].y=y; p[i].z=z;}
void addpt(int x,int y) {to[++tot]=y; nxt[tot]=g[x]; g[x]=tot;}
void solve(int head,int tail,int l,int r) {
    if (head>tail) return;
    int i,k,x,mid=(l+r)>>1,lnum=0,rnum=0;
    if (l==r) {
        for (i=head;i<=tail;i++) ans[id[i]]=l;
        return;
    }
    for (i=l;i<=mid;i++) {
        if (p[i].x<=p[i].y) adddt(p[i].x,p[i].y,p[i].z);
        else adddt(p[i].x,m,p[i].z),adddt(1,p[i].y,p[i].z);
    }
    for (i=head;i<=tail;i++) {
        tmp[id[i]]=0;
        for (k=g[id[i]];k;k=nxt[k]) {
            tmp[id[i]]+=sum(to[k]);
            if (tmp[id[i]]+cur[id[i]]>c[id[i]]) break;
        }
        if (cur[id[i]]+tmp[id[i]]>=c[id[i]]) tol[++lnum]=id[i];
        else tor[++rnum]=id[i],cur[id[i]]+=tmp[id[i]];
    }
    for (i=l;i<=mid;i++) {
        if (p[i].x<=p[i].y) adddt(p[i].x,p[i].y,-p[i].z);
        else adddt(p[i].x,m,-p[i].z),adddt(1,p[i].y,-p[i].z);
    }
    for (i=0;i<lnum;i++) id[head+i]=tol[i+1];
    for (i=0;i<rnum;i++) id[head+lnum+i]=tor[i+1];
    solve(head,head+lnum-1,l,mid);
    solve(head+lnum,tail,mid+1,r);
}
int main() {
    int i,x,y,z;
    scanf("%d%d",&n,&m);
    for (i=1;i<=m;i++) {
        scanf("%d",&x);
        addpt(x,i);
    }
    for (i=1;i<=n;i++) {
        scanf("%d",&c[i]);
        id[i]=i;
    }
    scanf("%d",&k);
    for (i=1;i<=k;i++) {
        scanf("%d%d%d",&x,&y,&z);
        addop(x,y,z,i);
    }
    addop(1,m,infi,++k);
    solve(1,n,1,k);
    for (i=1;i<=n;i++) if (ans[i]!=k) printf("%d\n",ans[i]);
    else puts("NIE");
    return 0;
}
#include <cstdio>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <utility>
#include <bitset>
#include <vector>
#include <string>
#include <stack>
#include <queue>
#include <ctime>
#include <cmath>
#include <list>
#include <map>
#include <set>
using namespace std;
typedef long long ll;
typedef double D;
typedef pair<int,int> pr;
const int infi=2147483647;
const int N=100010;
struct node{D a,b,r,k;}p[N];
struct func{D x,y;}f[N],v[N];
D ans[N],S;
int n,id[N],s[N];
bool operator < (func a,func b) {return a.x<b.x||a.x==b.x&&a.y<b.y;}
bool cmp(const int &a,const int &b) {return p[a].k>p[b].k;}
D cross(func a,func b,func c) {return (b.x-a.x)*(c.y-b.y)-(b.y-a.y)*(c.x-b.x);}
D calc(func t,int i) {return t.x*p[i].a+t.y*p[i].b;}
void solve(int l,int r,D maxnum) {
    if (l==r) {
        ans[l]=max(ans[l],maxnum);
        f[l].y=ans[l]/(p[l].a*p[l].r+p[l].b);
        f[l].x=f[l].y*p[l].r;
        return;
    }
    int i,mid=(l+r)>>1,t1=l,t2=mid+1,t=0,h=0;
    for (i=l;i<=r;i++) if (id[i]<=mid) s[t1++]=id[i];
    else s[t2++]=id[i];
    memcpy(id+l,s+l,sizeof(int)*(r-l+1));
    solve(l,mid,maxnum);
    for (i=l;i<=mid;v[t++]=f[i++]) while (t&&cross(v[t-2],v[t-1],f[i])>=0) t--;
    for (i=mid+1;i<=r;i++) {
        while (h<t-1&&calc(v[h],id[i])<calc(v[h+1],id[i])) h++;
        ans[id[i]]=max(ans[id[i]],calc(v[h],id[i]));
    }
    solve(mid+1,r,ans[mid]);
    merge(f+l,f+mid+1,f+mid+1,f+r+1,v);
    memcpy(f+l,v,sizeof(func)*(r-l+1));
}
int main() {
    int i;
    scanf("%d%lf",&n,&S);
    for (i=0;i<n;i++) {
        scanf("%lf%lf%lf",&p[i].a,&p[i].b,&p[i].r);
        p[i].k=-p[i].a/p[i].b;
        id[i]=i;
    }
    sort(id,id+n,cmp);
    solve(0,n-1,S);
    printf("%.3lf",ans[n-1]);
    return 0;
}
#include <cstdio>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <utility>
#include <bitset>
#include <vector>
#include <string>
#include <stack>
#include <queue>
#include <ctime>
#include <cmath>
#include <list>
#include <map>
#include <set>
using namespace std;
typedef long long ll;
typedef double D;
typedef pair<int,int> pr;
const int infi=2147483647;
const int N=500010;
const int M=2000100;
struct node{int op,x,y,z,t,id;}p[N],s[N];
int S,n,num,tot,id[N];
int h[M],ans[N];
bool cmp(const node &a,const node &b) {return a.x<b.x;}
void addquery(int i,int op,int x,int y,int z,int id) {
    p[i].id=i; p[i].op=op; p[i].x=x; p[i].y=y; p[i].z=z; p[i].t=id;
}
void add(int pos,int x) {while (pos<=n) h[pos]+=x,pos+=(pos&(-pos));}
ll sum(int pos) {ll t=0; while (pos>0) t+=h[pos],pos-=(pos&(-pos));return t;}
void solve(int l,int r) {
    if (l==r) return;
    int mid=(l+r)>>1,i,t1=l-1,t2=mid,t=l;
    for (i=l;i<=r;i++) if (p[i].id<=mid) s[++t1]=p[i];
    else s[++t2]=p[i];
    memcpy(p+l,s+l,sizeof(node)*(r-l+1));
    for (i=mid+1;i<=r;i++) if (p[i].op==2) {
        for (;t<=mid&&p[t].x<=p[i].x;t++) if (p[t].op==1) add(p[t].y,p[t].z);
        ans[p[i].t]+=sum(p[i].y)*p[i].z;
    }
    for (i=l;i<t;i++) if (p[i].op==1) add(p[i].y,-p[i].z);
    solve(l,mid); solve(mid+1,r);
}
int main() {
    int i,k,x,y,z,w;
    scanf("%d%d",&S,&n);
    while (~scanf("%d",&k)) {
        if (k==3) break;
        if (k==1) {
            scanf("%d%d%d",&x,&y,&z);
            addquery(++tot,1,x,y,z,0);
        } else {
            scanf("%d%d%d%d",&x,&y,&z,&w); num++;
            addquery(++tot,2,z,w,1,num);
            addquery(++tot,2,x-1,y-1,1,num);
            addquery(++tot,2,x-1,w,-1,num);
            addquery(++tot,2,z,y-1,-1,num);
        }
    }
    sort(p+1,p+1+tot,cmp);
    solve(1,tot);
    for (i=1;i<=num;i++) printf("%d\n",ans[i]);
    return 0;
}

登录 *


loading captcha image...
(输入验证码)
or Ctrl+Enter