Dyd's Blog

He who has a strong enough why can bear almost any how.

UOJ310黎明前的巧克力

FWT + 解方程

黎明前的巧克力

思路

设集合 $S$ 中的数异或和为 $0$ ,那么显然集合 $S$ 的任意一种划分方案都是答案,即答案为 $2^{\mid S \mid}$ ,考虑 FWT ,对于每个数 $a_i$ ,构造幂级数 $f_i(x) = (1 + 2x^{a_i})$ (不选为 $1$ ,选了集合大小 $+1$ ,在指数上就是 $\times 2$ ),那么定义乘法为异或卷积,记 $A = \max(a_i), F = \prod_{i = 1}^n f_i$ , $F[0]$ 显然就是答案(其实要 $-1$ 因为不能划分成空集),但直接暴力 FWT 时间为 $O(nA \log A)$ 无法接受

这里引入一个技巧,考虑到 $f_i$ 其实只要两项,它 FWT 变换后的数组 $f’_i$ 其实可以手玩,即 $f’i = \sum{s = 0}^A (1 + (-1)^{\mid a_i \& s \mid} \times 2)x^s$ $f’i(x) = \sum{s = 0}^A (1 + (-1)^{\mid a_i & s \mid} \times 2)x^s$,不难发现 $f’i$ 的每一位其实只有 $-1$ 和 $3$ 两种取值,把它们连乘,得到 $F’(x) = \sum{s = 0}^A (-1)^{n - t_s} \times 3^t_s x^s$ ,那么我们只要得到每一个 $t_s$ 即可得到 $F’$

考虑幂级数 $G = \sum_{i = 1}^n f_i$ ,由于 FWT 的线性,它的 FWT 变换为 $G’(x) = \sum_{i = 1}^n f’i(x) = \sum{s = 0}^A (-1 \times (n - t_s) + 3 \times t_s) x^s$ ,那么只要求得 $G’$ 即可解出 $t_s$ ,得到 $F’$ 后再 IFWT 回去即得 $F$

代码

注意由于 $tot$ 是 $2$ 的次幂,数组要开两倍

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include <bits/stdc++.h>
using LL = long long;
const int N = 2e6 + 100, P = 998244353, iv2 = 499122177, iv4 = 748683265;
const int MOD = 998244353;
int n, a[N], mx, tot;
int f[N], pw3[N];
void adj(int &x){ x += (x >> 31) & P; }
void fwtor(int x[], int len)
{
for (int i = 2, t1, t2; i <= len; i <<= 1)
for (int mid = i >> 1, j = 0; j < len; j += i)
for (int k = j; k < j + mid; ++k)
{
t1 = x[k], t2 = x[k + mid];
adj(x[k] = t1 + t2 - P), adj(x[k + mid] = t1 - t2);
}
}
void ifwtor(int x[], int len)
{
for (int i = 2, t1, t2; i <= len; i <<= 1)
for (int mid = i >> 1, j = 0; j < len; j += i)
for (int k = j; k < j + mid; ++k)
{
t1 = x[k], t2 = x[k + mid];
adj(x[k] = t1 + t2 - P), adj(x[k + mid] = t1 - t2);
x[k] = LL(x[k]) * iv2 % P, x[k + mid] = LL(x[k + mid]) * iv2 % P;
}
}
int main()
{
scanf("%d", &n);
pw3[0] = 1;
for (int i = 1; i <= n; ++i) pw3[i] = LL(pw3[i - 1]) * 3 % P;
for (int i = 1; i <= n; ++i)
{
scanf("%d", &a[i]);
f[0] += 1, f[a[i]] += 2;
mx = std::max(mx, a[i]);
}
tot = 1;
while (tot <= mx) tot <<= 1;
fwtor(f, tot);
for (int i = 0, t; i < tot; ++i)
{
t = LL(f[i] + n) * iv4 % P;
adj(f[i] = pw3[t] * ((n - t) & 1 ? -1 : 1));
}
ifwtor(f, tot);
adj(f[0] -= 1);
printf("%d\n", f[0]);
return 0;
}