并查集专题整理

kuangbin专题

模板

int find(int x) {
    return fa[x] == x ? x : fa[x] = find(fa[x]);
}
void join(int x, int y) {
    int fx = find(x), fy = find(y);
    if (fx != fy) fa[fx] = fy;
}

关于并查集的一点心得

大家都说带权并查集的起点是食物链( POJ - 1182 ),但是可能我比较笨,没有从这道题里领会到很深的东西,一直到做到虫子的一生( POJ - 2492 )才有点领悟到,也多亏了在做Navigation Nightmare( POJ - 1984 )时进巨的点拨。

并查集

并查集就是一棵树,与之相关的最小生成树的Kruskal,从一个环图中找一棵由最短的边组成的树,只要一棵,也就是说每个点都必须至少有一条路可达,这就是最小生成树。
把关联的结点用父子关系链接,这样,有关的结点的根节点就会是同一个,可以判断出是否在一棵树上。同时可以判环,如果两个点已知相连,并且有同一个根节点,那就是有环。
并查集的两个常用操作,一个是初始化,所有的点的父节点都是自己,一个是查找当前结点的根节点,一个是合并两棵树。查找根节点一般用递归,找到最上方的结点,其父节点就是自身。为了提高时间效率,查找时加上压缩路径,即每个点的父节点都直接改变成根节点。合并树操作在带权并查集中发挥很大作用。

并查集详解这篇文章讲得浅显易懂,很不错。

带权并查集

而带权并查集,其中的权需要转化成子节点与父节点之间的联系。这样向上查找时就能发现父节点和子节点之间的关系,以此来进行计算。
带权并查集的压缩路径方法是,在递归向上查找的同时,因为递归是直接到达最深处然后向上回溯的,所以只需要对每个点都做一次累加,这样回到原来的位置时就是全部的累加。合并树操作各有不同,主要是创建父节点的操作。
举个例子,虫子的一生( POJ - 2492 )中用权数组表示两个虫子的性别关系,更新时就只要考虑一下同性还是异性即可。

E - 食物链

POJ - 1182

题解

带权并查集。
有三种动物,ABC,A吃B,B吃C,C吃A。现在给出一些动物之间的关系,1表示同类,2表示X吃Y。判断有多少关系是错误的。
这道题算是带权并查集的入门题。从中可以领悟到带权并查集的套路,但是我没有……果然我太蠢了?
开两个数组,一个并查集,一个权。权数组储存和父结点之间的关系。我们选择0代表同类,1代表子被父吃,2代表子吃父。
这样,在查找根节点时压缩路径,只要让权值加上父结点的权值然后模3即可。经过简单的枚举计算,合并操作只要v[fy] = (v[x] - v[y] + d + 2) % 3;即可。当然,更新xy都是可以的。
有趣的是,我第一遍是看着题解做的,第二遍自己做,两遍的代码几乎一模一样。

代码

#include <cstdio>
#include <iostream>
#include <cstring>
using namespace std;
const int maxn = 50005;
// 0 代表同类,1 代表被吃,2代表吃
int fa[maxn], v[maxn];
void init(int n){
    for (int i = 0; i <= n; ++i)
        fa[i] = i;
    memset(v, 0, sizeof(v));
}

int find(int x) {
    if (x != fa[x]){
        int fx = find(fa[x]);
        v[x] = (v[x] + v[fa[x]]) % 3;
        fa[x] = fx;
    }
    return fa[x];
}

bool join(int x, int y, int d){
    int fx = find(x), fy = find(y);
    if (fx == fy){
        if ((v[y] - v[x] + 3) % 3 != d - 1)
            return true;
        return false;
    }
    fa[fy] = fx;
    v[fy] = (v[x] - v[y] + d + 2) % 3;
    return false;
}

int main(){
    int n, k, cnt = 0;
    scanf("%d%d", &n, &k);
    init(n);
    while(k--){
        int d, x, y;
        scanf("%d%d%d", &d, &x, &y);
        if (x > n || y > n || (x == y && d == 2) || join(x, y, d))
            ++cnt;
    }
    printf("%d\n", cnt);
    return 0;
}

G - Supermarket

POJ - 1456

题解

贪心。
给出一堆商品的收益和必须要卖出的deadline,求能获得的最大收益。
WA了好几次,本来以为自己不会出这种小bug了,看起来还是编码能力有待提高。也不是很懂为什么一道简单贪心放在并查集上。
先把所有商品的两个属性封装,然后排序,从头开始找最大的,放在能放的最后面的位置。

代码

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn = 10005;

struct pord{
    int p, d;
}mp[maxn];
int n, vis[maxn];

bool cmp(pord a, pord b){
    return a.p > b.p;
}

int main(){
    while (~scanf("%d", &n)){
        for (int i = 1; i <= n; ++i)
            scanf("%d%d", &mp[i].p, &mp[i].d);
        sort(mp+1, mp+n+1, cmp);
        memset(vis, 0, sizeof(vis));
        int ans = 0;
        for (int i = 1; i <= n; ++i)
            for (int j = mp[i].d; j > 0; --j)
                if (!vis[j]){
                    vis[j] = 1;
                    ans += mp[i].p;
                    break;
                }
        printf("%d\n", ans);
    }
    return 0;
}

H - Parity game

POJ - 1733

题解

并查集+离散化。
有一个01串,1e9位。给出一些子串,和他们中有奇数个还是偶数个1,求这些中有几个是对的。
因为长度太长所以开不下这么多数组,而子串只有五千个,所以需要哈希一下。只要存最多一万个数字就可以。为了避免不在一起的相邻所以多存一位,就是两万位。然后用一个数组存当前下标之前有奇数还是偶数个1,最后用并查集求解。

代码

#include <cstdio>
#include <algorithm>
using namespace std;
const int maxn = 100005;

int fa[maxn], val[maxn], n, m;
int hashSet[maxn];

struct {
    int u, v, w;
}node[maxn];

int find(int n){
    int k = fa[n];
    if(fa[n] != n){
        fa[n] = find(fa[n]);
        val[n] = (val[n] + val[k])%2;
    }
    return fa[n];
}

void init(){
    for (int i = 0; i <= n; ++i)
        val[i] = 0, fa[i] = i;
}

int main(){
    while (~scanf("%d", &n)){
        int i, k = 0;

        //init放在这里会RE

        scanf("%d", &m);
        for (i = 0; i < m; ++i){
            char s[5];
            scanf("%d%d%s", &node[i].u, &node[i].v, s);
            node[i].w = s[0] == 'e'? 0:1;

            hashSet[k++] = node[i].u - 1;
            hashSet[k++] = node[i].u;
            hashSet[k++] = node[i].v - 1;
            hashSet[k++] = node[i].v;
        }
        hashSet[k++] = n;
        hashSet[k++] = n - 1;

        sort(hashSet, hashSet+k);
        n = (int)(unique(hashSet, hashSet+k) - hashSet);

        init();
        //init放这里就AC

        for (i = 0; i < m; ++i){
            int u = (int)(lower_bound(hashSet, hashSet+n, node[i].u-1) - hashSet);
            int v = (int)(lower_bound(hashSet, hashSet+n, node[i].v) - hashSet);

            int fu = find(u), fv = find(v);

            if (fu == fv && (val[u] + node[i].w)%2 != val[v])
                break;
            if (fu < fv){
                fa[fv] = fu;
                val[fv] = (val[u] + node[i].w - val[v] + 2) % 2;
            }
            if (fu > fv){
                fa[fu] = fv;
                val[fu] = (val[v] - node[i].w - val[u] + 2) % 2;
            }
        }
        printf("%d\n", i);
    }
    return 0;
}

I - Navigation Nightmare

POJ - 1984

题解

带权并查集。
题目给出一些路,路有起点、终点、长度、方向四个属性。要求回答提问的两个点之间的曼哈顿距离(x轴距离+y轴距离),注意提问给出了提问的时间,此时的地图不一定完整(不是所有的路都能用)。
卡了很久。都是TLE,因为一个地方没有把二层循环改为O(nlogn)的写法。
col数组存储当前点相对于父节点的偏移量,ed数组存储边的四个属性,q数组存储问题。之所以用数组是为了避免因为问题的顺序是乱的,用cmp函数排个序。
find查找函数中有路径压缩,把树上的所有点的偏移量直接改为到根节点的偏移量。
join合并函数根据路的方向和长度来更新偏移量。
最后搜索,如果不在同一棵树上就不成立,否则维护偏移量。

代码

#include <cstdio>
#include <algorithm>
using namespace std;
const int maxn = 40005;

int fa[maxn], ans[maxn];
int n, m;

struct{
    int x, y;
}col[maxn];

struct {
    int x, y, d;
    char s[2];
}ed[maxn];

struct que{
    int a, b, idx, n;
}q[maxn];

bool cmp(que a, que b) {
    return a.idx < b.idx;
}

int find(int n) {
    if (n != fa[n]) {
        int k = fa[n];
        fa[n] = find(fa[n]);
        col[n].x += col[k].x;
        col[n].y += col[k].y;
    }
    return fa[n];
}

void join(int u, int v, int i) {
    int fu = find(u);
    int fv = find(v);
    if (fu == fv) return;
    fa[fv] = fu;

    int rx = col[v].x;
    int ry = col[v].y;
    switch (ed[i].s[0]){
        case 'E': 
            col[fv].x = col[u].x + ed[i].d - col[v].x;
            col[fv].y = col[u].y - col[v].y;
            break;
        case 'W': 
            col[fv].x = col[u].x - ed[i].d - col[v].x;
            col[fv].y = col[u].y - col[v].y;
            break;
        case 'S':
            col[fv].x = col[u].x - col[v].x;
            col[fv].y = col[u].y - ed[i].d - col[v].y;
            break;
        case 'N':
            col[fv].x = col[u].x - col[v].x;
            col[fv].y = col[u].y + ed[i].d - col[v].y;
            break;
    }
}

int main() {
    while (~scanf("%d%d", &n, &m)) {
        for (int i = 0; i <= n; ++i) fa[i] = i;

        for (int i = 1; i <= m; ++i)
            scanf("%d%d%d%s", &ed[i].x, &ed[i].y, &ed[i].d, ed[i].s);

        int k;
        scanf("%d", &k);
        for (int i = 1; i <= k; ++i) {
            scanf("%d%d%d", &q[i].a, &q[i].b, &q[i].idx); // input questions
            q[i].n = i;
        }

        sort(q + 1, q + k + 1, cmp);
        for (int i = 1; i <= k; ++i) {
            int cur = i == 1 ? 1 : q[i - 1].idx + 1;
            for (int j = cur; j <= q[i].idx; ++j) {
                join(ed[j].x, ed[j].y, j);
            }

            if (find(q[i].a) != find(q[i].b))
                ans[q[i].n] = -1;
            else
                ans[q[i].n] = abs(col[q[i].a].x - col[q[i].b].x) + abs(col[q[i].a].y - col[q[i].b].y);
        }

        for (int i = 1; i <= k; ++i)
            printf("%d\n", ans[i]);
    }
    return 0;
}

J - A Bug's Life

POJ - 2492

题解

带权并查集。
给出n条虫子的m条交配记录。假设虫子有两个性别,都是异性恋。根据记录判断假设是否正确。
简单的带权并查集应用。难点主要在于如何选择权数组的形式。
这里用权数组re来表示当前虫子和上一级虫子的关系,0为同性,1为异性。要注意权数组意义,以及如何在合并树操作时处理权数组。

代码

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn = 2005;

int fa[maxn], re[maxn];
int n, m;
bool flag;

void init(){
    for (int i = 0; i <= n; ++i)
        fa[i] = i;
    memset(re, 0, sizeof(re));
    flag = true;
}

int find(int x){
    if (x == fa[x])
        return x;
    int t = find(fa[x]);
    re[x] = (re[x] + re[fa[x]]) % 2;
    fa[x] = t;
    return fa[x];
}

bool join(int x, int y){
    int fx = find(x), fy = find(y);
    if (fx == fy){
        if (re[x] == re[y])
            return false;
        return true;
    }
    fa[fy] = fx;
    re[fy] = (re[y] - re[x] + 1) %2;
    return true;
}

bool solve(){
    int a, b;
    bool f = true;
    scanf("%d%d", &n, &m);
    init();
    for (int i = 1; i <= m; ++i){
        scanf("%d%d", &a, &b);
        if (!join(a, b))
            f = false;
    }
    return f;
}

int main(){
    int t;
    scanf("%d", &t);
    for (int i = 1; i <= t; ++i){
        if (solve())
            printf("Scenario #%d:\nNo suspicious bugs found!\n\n", i);
        else
            printf("Scenario #%d:\nSuspicious bugs found!\n\n", i);
    }
    return 0;
}