Dyd's Blog

He who has a strong enough why can bear almost any how.

快速沃尔什变换

不要记混就好

快速沃尔什变换

即所谓的FWT

问题

从函数卷积的角度看,多项式乘法就是一个卷积:
$$
(f * g)(n) = \sum_{i + j = n} f(i) * g(j)
$$
而当这种卷积拓展到二进制下(大多用于集合运算)时,我们考虑这个问题:
$$
(f * g)(n) = \sum_{i \circ j = n} f(i) * g(j)
$$
其中 $\circ $ 代表或( $\mid$ )、与( $\&$ )以及异或( $\oplus$ )

此时问题变得麻烦了起来,直接做当然是 $O(n^2)$ 的,其中 $n$ 是多项式长度;于是,类似于 FFT ,我们考虑用分治加速,可以得到 $O(n \log n)$ 的复杂度

思想

FFT 做到了多项式系数表示与点值表示的快速变换,使得我们可以利用点值表示的优势很快得到答案,FWT 也是这么个思想:用一种可逆线性变换(这个“线性”不是指时间,而是指满足线性运算)将原多项式变成某种好计算的表示方法,然后快速得答案;而好消息是, FWT 不必要用复数单位根这种麻烦的东西,只要分治就好了

在开始之前,我们定义 $merge(f, g)$ 表示将两个多项式像字符串拼接一样拼在一起, $f + g$ 表示将两个多项式对于位置相加

或运算

我们从最好理解的或开始,即求 $h(n) = (f * g)(n) = \sum_{i \mid j = n} f(i) * g(j)$

考虑构造 $f’(i) = \sum_{i = i \mid j} f(j)$ ,那么显然 $h’(i) = f’(i) * g’(i)$ ,不妨让 $f$ 的长度为 $n = 2^{bit}$ ,考虑将 $f$ 分成前、后两段 $f_0, f_1$ ,长度都为 $2^{bit - 1}$ ,我们递归构造 $f_0, f_1$ 的 FWT 记作 $f_0’, f_1’$ ,则有
$$
f’ =
\begin{cases}
f & bit = 0 \\
merge(f_0’, f_1’ + f_0’) & bit > 0
\end{cases}
$$
解释一下:显然,对于所有 $i < 2^{bit - 1}$ (即二进制下第 $bit - 1$ 位为 $0$ ), $f_1’$ 是不会对它这个位置产生影响的,反而它会影响 $f_1’$ ,而影响的位置就是恰好二进制下只在 $bit - 1$ 位和它不一样(即为 $1$ 的位置),又因为 $f_1’$ 的下标是 $0 \sim 2^{bit - 1} - 1$ ,所有下标就是对应位置

当然,不可能真的递归去做,具体实现是时和 FFT 一样化成迭代

对于逆变换,就是 $merge(f_0’, f_1’ - f_0’)$

与运算

求 $h(n) = (f * g)(n) = \sum_{i \& j = n} f(i) * g(j)$

$h(n) = (f * g)(n) = \sum_{i & j = n} f(i) * g(j)$

直接给式子了:
$$
f’ =
\begin{cases}
f & bit = 0 \\
merge(f_0’ + f_1’, f_1’) & bit > 0
\end{cases}
$$
逆变换就是 $merge(f_0’ - f_1’, f_1’)$

异或运算

求 $h(n) = (f * g)(n) = \sum_{i \oplus j = n} f(i) * g(j)$

式子:
$$
f’ =
\begin{cases}
f & bit = 0 \\
merge(f_0’ + f_1’, f_0’ - f_1’) & bit > 0
\end{cases}
$$
逆变换有点不一样了,是 $merge(\frac{f_0’ + f_1’}{2}, \frac{f_0’ - f_1’}{2})$

其实,在 $\oplus$ 的操作下, $f’(i) = \sum_{\mid i \& j \mid \% 2 = 0} f(j) - \sum_{\mid i \& j \mid \% 2 = 1} f(j)$

$f’(i) = \sum_{\mid i & j \mid % 2 = 0} f(j) - \sum_{\mid i & j \mid % 2 = 1} f(j) = \sum (-1)^{\mid i & j \mid} f(j)$

另一种理解

其实还可以把 FWT 理解为特殊的 FFT ,我们考虑构造 $f’[i] = \sum_{j = 0}^{2^{bit} - 1} a_{i, j} f[j]$ ,即给定一个系数,那么,以异或卷积为例,设 $f * g = h$ :
$$
\begin{aligned}
h’(i) &= f’(i) * g’(i) \\
\sum_{i = 0}^{2^{bit} - 1} \sum_{j = 0}^{2^{bit} - 1} a_{k, i \oplus j}f[i]g[j] &= \sum_{i = 0}^{2^{bit} - 1} \sum_{j = 0}^{2^{bit} - 1} a_{k, i}f[i]a_{k, j}g[j] \\
a_{k, i \oplus j} &= a_{k, i} a_{k, j}
\end{aligned}
$$
考虑到异或运算对于每一位独立要有:
$$
\begin{aligned}
\begin{cases}
a_{i, 0} a_{i, 0} = a_{i, 0} \\
a_{i, 1} a_{i, 1} = a_{i, 0} \\
a_{i, 1} a_{i, 0} = a_{i, 1} \\
a_{i, 0} a_{i, 1} = a_{i, 1}
\end{cases} \\
并且要保证 a 有逆 \\
\end{aligned}
$$
有两组解:
$$
a =
\begin{bmatrix}
1 & 1 \\
1 & -1
\end{bmatrix}

\begin{bmatrix}
1 & -1 \\
1 & 1
\end{bmatrix}
$$
习惯上我们取前者,它的逆为:
$$
a^{-1} =
\begin{bmatrix}
\frac{1}{2} & \frac{1}{2} \\
\frac{1}{2} & -\frac{1}{2}
\end{bmatrix}
$$
这就是 FWT 的异或形式

板子

luoguP4717

要注意如果长度不为 $2$ 的整次幂,要补足

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <bits/stdc++.h>
using LL = long long;
const int N = (1 << 17) + 100, P = 998244353;
void adj(int &x){ x += (x >> 31) & P; }
namespace FWT
{
const int iv2 = 499122177; //2的逆元
void _or(int x[], int len)
{
for (int i = 2; i <= len; i <<= 1) //当前区间长度
for (int mid = i >> 1, j = 0; j < len; j += i) //j当前区间前端
for (int k = j; k < j + mid; ++k) adj(x[k + mid] += x[k] - P);
}
void u_or(int x[], int len)
{
for (int i = 2; i <= len; i <<= 1)
for (int mid = i >> 1, j = 0; j < len; j += i)
for (int k = j; k < j + mid; ++k) adj(x[k + mid] -= x[k]);
}
void _and(int x[], int len)
{
for (int i = 2; i <= len; i <<= 1)
for (int mid = i >> 1, j = 0; j < len; j += i)
for (int k = j; k < j + mid; ++k) adj(x[k] += x[k + mid] - P);
}
void u_and(int x[], int len)
{
for (int i = 2; i <= len; i <<= 1)
for (int mid = i >> 1, j = 0; j < len; j += i)
for (int k = j; k < j + mid; ++k) adj(x[k] -= x[k + mid]);
}
void _xor(int x[], int len)
{
for (int i = 2, t1, t2; i <= len; i <<= 1)
for (int mid = i >> 1, j = 0; j < len; j += i)
for (int k = j; k < j + mid; ++k)
{
t1 = x[k], t2 = x[k + mid];
adj(x[k] = t1 + t2 - P), adj(x[k + mid] = t1 - t2);
}
}
void u_xor(int x[], int len)
{
for (int i = 2, t1, t2; i <= len; i <<= 1)
for (int mid = i >> 1, j = 0; j < len; j += i)
for (int k = j; k < j + mid; ++k)
{
t1 = x[k], t2 = x[k + mid];
adj(x[k] = t1 + t2 - P), adj(x[k + mid] = t1 - t2);
x[k] = LL(iv2) * x[k] % P, x[k + mid] = LL(iv2) * x[k + mid] % P;
}
}
}
int a[N], b[N], aa[N], bb[N], n, bit;
int main()
{
scanf("%d", &bit), n = (1 << bit);
for (int i = 0; i < n; ++i) scanf("%d", &a[i]);
for (int i = 0; i < n; ++i) scanf("%d", &b[i]);
for (int i = 0; i < n; ++i) aa[i] = a[i], bb[i] = b[i];
FWT::_or(aa, n), FWT::_or(bb, n);
for (int i = 0; i < n; ++i) aa[i] = LL(aa[i]) * bb[i] % P;
FWT::u_or(aa, n);
for (int i = 0; i < n; ++i) printf("%d ", aa[i]);
puts("");
for (int i = 0; i < n; ++i) aa[i] = a[i], bb[i] = b[i];
FWT::_and(aa, n), FWT::_and(bb, n);
for (int i = 0; i < n; ++i) aa[i] = LL(aa[i]) * bb[i] % P;
FWT::u_and(aa, n);
for (int i = 0; i < n; ++i) printf("%d ", aa[i]);
puts("");
for (int i = 0; i < n; ++i) aa[i] = a[i], bb[i] = b[i];
FWT::_xor(aa, n), FWT::_xor(bb, n);
for (int i = 0; i < n; ++i) aa[i] = LL(aa[i]) * bb[i] % P;
FWT::u_xor(aa, n);
for (int i = 0; i < n; ++i) printf("%d ", aa[i]);
puts("");
return 0;
}