点一成零

并查集 组合计数

Posted by BY 水蓝 on February 20, 2021

链接:https://ac.nowcoder.com/acm/contest/9981/D

题目

牛牛拿到了一个n*n的方阵,每个格子上面有一个数字:0或1 行和列的编号都是从0到n-1 现在牛牛每次操作可以点击一个写着1的格子,将这个格子所在的1连通块全部变成0。 牛牛想知道,自己有多少种不同的方案,可以把全部格子的1都变成0? 所谓连通块,是指方阵中的两个正方形共用一条边,即(x,y)和以下4个坐标的数是连通的:(x-1,y)、(x+1,y)、(x,y-1)、(x,y+1) 这个问题对于牛牛来说可能太简单了。于是他将这个问题变得更加复杂: 他会选择一个格子,将这个格子上的数字修改成1(如果本来就是1,那么不进行任何改变),再去考虑“点一成零”的方案数。 牛牛想知道,每次“将某个格子修改成1”之后,“把全部格子的1都变成0”的方案数量。 ps:请注意,每次“将某个格子修改成1”之后,状态会保留到接下来的询问。具体请参考样例描述。 由于方案数可能过大,请对1e9+7取模

思路

假设一共用 m 个 1 的连通块,每个块的大小为siz[i]

每个连通块的点击顺序随意,所以有n!中方案,每个连通块需要选择其中一个进行点击,所以总方案数为 考虑把 0 变成 1 的操作,分为四种情况:

  1. 如果本来就是 1 那么无事发生

  2. 单独的一个 1,就相当于多加了一个连通块

  3. 某个连通块变大

  4. 两个连通块连在了一起,即连通块总数先减二,再增加一个大的连通块

具体做法

用并查集进行统计,设初始结果为ans

情况 2 : ans = ans * (m+1) % mod

情况 4 : 假设是 x 和 y 两个连通块合并了 ans = (siz[x] + siz[y]) * ans / (m * siz[x] * siz[y]) (这个地方要用到费马小定理,如果模数mod是质数,那么一个数a的0乘法逆元就是a的(mod-2)次方)

情况 3 : 我的做法是,只要不是情况 1,就按情况 2 处理,然后扫描周边,检查是否会出现情况 4,在这个过程中,就会处理好这三种情况

代码

#include <cstdio>
#include <iostream>
using namespace std;
#define ll long long

const int maxn = 505, mod = 1e9 + 7;

char mp[maxn][maxn];
int n, k, fa[maxn*maxn], siz[maxn*maxn];
int dx[5] = {0, 0, 1, -1}, dy[5] = {1, -1, 0, 0};

int Find(int x)
{
    return fa[x] == x ? x : fa[x] = Find(fa[x]);
}

inline void merge(int x, int y)
{
    int fx = Find(x), fy = Find(y);
    if(fx != fy) {
        siz[fx] += siz[fy];
        fa[fy] = fx;
    }
    return ;
}

ll Pow(ll bas,ll pw)
{
    ll tmp = 1;
    while(pw) {
        if(pw & 1)
            tmp = tmp * bas % mod;
        bas = bas * bas % mod;
        pw >>= 1;
    }
    return tmp;
}

int main()
{
    scanf("%d", &n);
    for(int i = 1; i <= n; ++i)
        scanf("%s" ,mp[i] + 1);
    for(int i = 0; i <= (n+1) * (n+1); ++i)
        fa[i] = i, siz[i] = 1;
    for(int i = 1; i <= n; ++i) {
        for (int j = 1; j <= n; ++j) {
            if (mp[i][j] == '1') {
                for (int k = 0; k <= 3; ++k) {
                    int nx = i + dx[k], ny = j + dy[k];
                    if (mp[nx][ny] == '1')
                        merge(j * (n + 1) + i, ny * (n + 1) + nx);
                }
            }
        }
    }
    int blockNum = 0;
    ll ans = 1;
    for(int i = 1; i <= n; ++i) {
        for (int j = 1; j <= n; ++j) {
            if (mp[i][j] == '1' && fa[j*(n+1)+i] == j*(n+1)+i) {
                ++blockNum;
                ans = (ans * siz[j*(n+1) + i]) % mod;
            }
        }
    }
    for(int i = 1; i <= blockNum; ++i)
        ans = ans * i % mod;
    scanf("%d", &k);
    while(k--) {
        int x , y;
        scanf("%d%d", &x, &y);
        ++x; ++y;
        if(mp[x][y] == '1') {
            printf("%lld\n", ans);
            continue;
        }
        mp[x][y] = '1';
        ++blockNum;
        ans = ans * blockNum % mod;
        for(int i = 0; i <= 3; ++i) {
            int nx = x + dx[i], ny = y + dy[i];
            if(mp[nx][ny] == '1') {
                int fx = Find(y*(n+1) + x), fnx = Find(ny*(n+1) + nx);
                if(fx != fnx) {
                    ans = ans * Pow(blockNum, mod - 2) % mod;
                    ans = ans * Pow(siz[fx], mod - 2) % mod;
                    ans = ans * Pow(siz[fnx], mod - 2) % mod;
                    ans = ans * (siz[fnx] + siz[fx]) % mod;
                    --blockNum;
                    merge(fx, fnx);
                }
            }
        }
        printf("%lld\n", ans);
    }
    return 0;
}