Dyd's Blog

He who has a strong enough why can bear almost any how.

CF1530F Bingo

最讨厌期望了

Bingo

题意

给 $n \times n$ 的矩阵,每个点有 $p_{i, j} \times 10^{-4}$ 的概率为 $1$ ,否则为 $0$ ,求存在某一行或者一列或者一条对角线全为 $1$ 的概率, $n \le 21$

做题

先想暴力,设集合 $l_1, l_2, …, l_n$ 表示每一行的点集, $l_{n + 1}, …, l_{2n}$ 表示每一列的点集, $l_{2n + 1}, l_{2n + 2}$ 表示对角线,令 $P(l_i)$ 表示点集 $l_i$ 全为 $1$ 的概率, $P(\overline{l_i})$ 表示点集 $l_i$ 不全为 $1$ 的概率

明显可以用容斥 $O(2^{2n + 2})$ 暴力统计答案,但当然 TLE

先来看几条正确性显然的性质:

  1. $P(l_i) + P(\overline{l_i}) = 1$
  2. $P(A \wedge B) = P(B) P(A \mid B)$ ,其中 $A \mid B$ 表示在事件 $B$ 发生的条件下发生事件 $A$
  3. $P(\overline{A} \wedge B) + P (A \wedge B) = P(B)$

发现我们要求的就是 $P(l_1 \vee l_2 \vee … \vee l_{2n + 2}) = 1 - P(\overline{l_1} \wedge \overline{l_2} \wedge … \wedge \overline{l_{2n + 2}})$ ,然后变形:
$$
\begin{aligned}
& P(\overline{l_1} \wedge \overline{l_2} \wedge … \wedge \overline{l_{2n + 2}}) + P(l_1 \wedge \overline{l_2} \wedge … \wedge \overline{l_{2n + 2}}) = P(\overline{l_2} \wedge … \wedge \overline{l_{2n + 2}}) & (1) \\
& P(l_1 \wedge \overline{l_2} \wedge … \wedge \overline{l_{2n + 2}}) = P(l_1) P(\overline{l_2} \wedge … \wedge \overline{l_{2n + 2}}) & (2) \\
& 由 (1) (2) 得: \\
& P(\overline{l_1} \wedge \overline{l_2} \wedge … \wedge \overline{l_{2n + 2}}) = P(\overline{l_2} \wedge … \wedge \overline{l_{2n + 2}}) - P(l_1) P(\overline{l_2} \wedge … \wedge \overline{l_{2n + 2}}) & (3) \\
\end{aligned}
$$
可以容斥来解释

考虑 dp ,定义 $d(i, S)$ 表示 $P(\overline{l_i} \wedge … \wedge \overline{l_{2n + 2}} \mid l_{s_1} \wedge … \wedge l_{s_k})$ ,即直线 $s \in S$ 全为 $1$ 而直线 $l_{i \sim 2n + 2}$ 不全为 $1$ 的概率,那么 $(3)$ 式可化为:
$$
\begin{aligned}
& d(i, S) = d(i + 1, S) - d(i + 1, S \vee \{i\}) P(l_i \mid l_{s_1} \wedge … \wedge l_{s_k}) & (4) \\
\end{aligned}
$$
观察这个式子, $P(l_i \mid l_{s_1} \wedge … \wedge l_{s_k})$ 可以 $O(n)$ 求,初始化为 $d(2n + 3, S) = 1$ ,答案就是 $1 - d(1, 0)$ ,时间为 $O(n \times (2n + 2) \times 2^{2n + 2})$ ,这不爆炸,还不如暴力

看来还得加点容斥,我们暴力枚举行的情况,有 $2^n$ 种(要么 $l_i$ 要么 $\overline{l_i}$ ),然后考虑列(下面的 $i$ 就只代表列了), $(4)$ 式可化为:
$$
\begin{aligned}
& d(i, S) = d(i + 1, S) - d(i + 1, S) P(l_i \mid l_{s_1} \wedge … \wedge l_{s_k}) \\
\end{aligned}
$$
这是因为我们现在只考虑列,把当前列加入 $S$ 对后面的列没有影响,所有可以干脆不加,则 $S$ 中始终只有我们枚举出的行(要注意此时求出的 $d$ 已经不一样了,因为行我们是枚举的,其概率还没加到 $d$ 中),这样,在枚举了行后,列可以 $O(n^2)$ 计算,这里有一个 $n$ 是求 $P(l_i \mid l_{s_1} \wedge … \wedge l_{s_k})$ 的,这可以预处理成 $O(1)$ ,具体地,设 $mul(i, S)$ 表示“对于第 $i$ 列,行选则情况为 $S$ 的时候,这些行与第 $i$ 列的相交的格子的乘积”,且设 $U$ 为行集合的全集,则有:
$$
\begin{aligned}
& d(i, S) = d(i + 1, S) - d(i + 1, S) \times mul(i, U - S) \\
\end{aligned}
$$
但上面说了,时求出的 $d$ 已经不一样了,为了加上行选择的情况的概率,还要乘一个 $mul(i, S)$ ,同时,发现第一维完全可以省略,即:
$$
\begin{aligned}
& d(S) = (d(S) - d(S) \times mul(i, U - S)) \times nul(i, S) \\
& d(S) = (1 - mul(i, U - S)) \times mul(i, S) \times d(S) & (5) \\
\end{aligned}
$$
这样,列的 dp 就是 $O(n)$ 的了

然后我们再明确一下行的影响,行是枚举的,计算时要用容斥,就是: $1 - d(至少 1 行全为 1) + d(至少 2 行全为 1) - …$ , dp 贡献时加个系数即可

最后考虑一下对角线,由于只有两条,可以和行一样枚举它们的情况,把它当成一种特殊的行(具体见代码)

考虑时间,枚举容斥是 $O(2^{n + 2})$ , dp 是 $O(n)$ 的,最后就是 $O(n 2^{n + 2})$

ps:模数 $31607$ 是质数, $10^4$ 在其意义下的逆元为 $3973$

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <bits/stdc++.h>
#define STC static
#define LL long long
#define lowbit(x) ((x) & -(x))
using namespace std;
const int N = 21 + 5, P = 31607, INV = 3973, NN = (1 << 21) + 5;
int n, nn, p[N][N], ans = 0;
int lg2[NN], mul[N][NN];
int d[NN];
void prev()
{
lg2[1] = 0;
for (int i = 2; i <= nn; ++i) lg2[i] = lg2[i >> 1] + 1;
for (int i = 1, s, j; i <= n; ++i)
for (mul[i][0] = 1, s = 1; s < nn; ++s) mul[i][s] = (LL)mul[i][s ^ (j = lowbit(s))] * p[lg2[j] + 1][i]% P;
}
int cntbit(int x)
{
int res;
for (res = 1; x; x &= (x - 1), ++res);
return res;
}
int main()
{
scanf("%d", &n), nn = 1 << n;
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= n; ++j) scanf("%d", &p[i][j]), p[i][j] = (LL)p[i][j] * INV % P;
prev();
for (int t = 0; t <= 3; ++t) //t枚举对角线的情况
{
for (int s = 0; s < nn; ++s) d[s] = (cntbit(s) + cntbit(t)) & 1 ? -1 : 1; //容斥系数
for (int i = 1; i <= n; ++i)
for (int s = 0, _s; s < nn; ++s)
{
//加上对角线
_s = s;
if (t & 1) _s |= 1 << (i - 1);
if (t & 2) _s |= 1 << (n - i);
d[s] = (LL)d[s] * mul[i][_s] % P * (1 - mul[i][(nn - 1) ^ _s] + P) % P;
}
for (int s = 0; s < nn; ++s) ans = (ans + d[s]) % P;
}
printf("%d\n", (1 - ans + P) % P);
return 0;
}