Dyd's Blog

He who has a strong enough why can bear almost any how.

任意模数多项式乘法

四次 FFT

任意模数多项式乘法

好像叫 MTT

思路

拆系数,设 $M$ 为一个 $\sqrt{P}$ 级别的数,设:
$$
F(x) = M * F1(x) + F0(x) \\
G(x) = M * G1(x) + G0(x)
$$
其中 $[x^n]F1(x) = \lfloor \frac{[x^n]F(x)}{M} \rfloor$ , $[x^n]F0(x) = [x^n]F(x) \bmod M$ ,同理求 $G1, G0$ ,于是拆分后每个系数都小于 $M$ ,两两相乘后的系数被降到 $10^5 M^2$ 级别,考虑它们的乘法:
$$
\begin{aligned}
F(x)G(x) = M^2 * F1(x) * G1(x) + M * (F1(x) * G0(x) + F0(x) * G1(x)) + F0(x) * G0(x)
\end{aligned}
$$
发现直接算的话调用 fft 的次数过多,考虑合并

先看 DFT ,正常来说我们要做 $4$ 次,但我们发现多项式的虚部全为 $0$ ,可以把两个合并,如对于 $F1, F0$ ,构造 $A(x) = F1(x) + i F0(x), B(x) = F1(x) - iF0(x)$ ,那么 $A(\omega^{k})$ 与 $B(\omega^{-k})$ 互为共轭,所以只要对 $A$ 做 fft , $B$ 的点值也就知道了,再由 $A, B$ 的点值即可解出 $F1, F0$ 的点值;于是,我们用两次 DFT 完成了求 $F1, F0, G1, G0$ 的点值

现在考虑 IDFT :我们有四个多项式要做:$F1G1, F1G0, F0G1, F0G0$ ,此时它们的虚部已经不为 $0$ 了,但我们知道这四个多项式卷起来后的系数表示中虚部一定为 $0$ ,所以构造 $A(x) = F1(x)G1(x) + iF1(x)G0(x), B(x) = F0(x)G1(x) + iF0(x)G0(x)$ , 那么 IDFT 后 $A$ 的实和虚部就是 $F1G1, F1G0$ ,同理 $B$ 的实部和虚部就是 $F0G1, F0G0$

综上,我们只要 $4$ 次 fft 即可

代码

以下代码使用了 std::complex<typename> ,它包含在 #include <complex> 中,常数略大

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <bits/stdc++.h>
// Shorthand used for the large function-local static work buffers.
#define STC static
using LL = long long;
using DB = double;
// Complex sample type used by every transform in this file.
using CP = std::complex<double>;
// Base capacity; every working array below is declared with N << 2 entries.
const int N = 1e5 + 100;
// The imaginary unit, used to pack two sequences into one complex sequence.
const CP I(0, 1);
// tot : current transform length (set by polymul to a power of two)
// bit : log2(tot)
// r   : bit-reversal permutation for length tot (filled by polymul, read by fft)
// P   : the modulus read from input
// M   : the base used to split each coefficient into a high and a low part
int tot, bit, r[N << 2], P, M;
/**
 * In-place iterative radix-2 FFT over x[0 .. tot).
 *
 * Relies on the file-level globals `tot` (transform length) and `r`
 * (bit-reversal permutation), which must already be set up by the caller.
 *
 * @param x   array of at least `tot` complex samples, transformed in place
 * @param op  1 performs the forward transform; the other value used in this
 *            file, -1, conjugates the twiddle factors and additionally scales
 *            the result by 1 / tot (inverse transform)
 */
void fft(CP *x, int op)
{
    // Scratch table of twiddle factors for the current butterfly width.
    STC CP wn[N << 2];

    // Reorder the input into bit-reversed order; swap each pair only once.
    for (int k = 1; k < tot; ++k)
        if (k < r[k]) std::swap(x[k], x[r[k]]);

    for (int half = 1; half < tot; half <<= 1)
    {
        // Primitive root for butterflies of width 2 * half; `op` picks the sign.
        const CP step(std::cos(M_PI / half), std::sin(M_PI / half * op));

        // Derive this level's table from the previous level's table: even
        // slots copy entry k / 2 of the old table, odd slots multiply by `step`.
        // Walking downwards guarantees entry k / 2 has not been overwritten yet.
        // The condition must be `k >= 0` rather than `~k`, because k moves in
        // steps of two and therefore passes from 0 to -2, never hitting -1.
        wn[0] = {1, 0};
        for (int k = half - 2; k >= 0; k -= 2)
        {
            wn[k] = wn[k >> 1];
            wn[k + 1] = wn[k] * step;
        }

        for (int base = 0; base < tot; base += (half << 1))
            for (int j = 0; j < half; ++j)
            {
                const CP even = x[base + j];
                const CP odd = x[base + j + half] * wn[j];
                x[base + j] = even + odd;
                x[base + j + half] = even - odd;
            }
    }

    if (op == 1) return;

    // Inverse transform: normalise by the length.
    const DB scale = DB(1) / tot;
    for (int k = 0; k < tot; ++k) x[k] = x[k] * scale;
}
/**
 * Forward-transforms two sequences with a single call to fft().
 *
 * The two inputs are packed as x + i*y and transformed once. The individual
 * spectra are then separated using the conjugate of the packed spectrum at
 * the mirrored index; this separation is only valid when both inputs are
 * real-valued, which is how polymul() calls it.
 *
 * @param x  first sequence (length `tot`); replaced by its spectrum
 * @param y  second sequence (length `tot`); replaced by its spectrum
 */
void dfft(CP *x, CP *y)
{
    // Pack both sequences into one complex sequence.
    for (int k = 0; k < tot; ++k) x[k] = x[k] + I * y[k];

    fft(x, 1);

    // y temporarily holds conj(X[-k]), with index 0 mirroring onto itself.
    for (int k = 0; k < tot; ++k)
    {
        const int mirror = k ? tot - k : 0;
        y[k] = std::conj(x[mirror]);
    }

    // Separate the two spectra from the packed value and its mirrored conjugate.
    for (int k = 0; k < tot; ++k)
    {
        const CP packed = x[k];
        const CP mirrored = y[k];
        x[k] = (packed + mirrored) * DB(0.5);
        y[k] = (mirrored - packed) * DB(0.5) * I;
    }
}
/**
 * Rounds the real part of x to the nearest integer (halves away from zero)
 * and returns that integer's remainder modulo the global P.
 *
 * The remainder keeps the sign of the rounded value, as C++ `%` does.
 */
LL rod(CP x)
{
    const DB value = x.real();
    // Shift by half a unit away from zero, then truncate.
    const LL nearest = (value < 0) ? LL(value - 0.5) : LL(value + 0.5);
    return nearest % P;
}
void polymul(int *f, int *g, int lf, int lg)
{
STC CP p[N << 2], q[N << 2], f0[N << 2], f1[N << 2], g0[N << 2], g1[N << 2];
int i;
tot = 1, bit = 0;
while (tot < lf + lg + 1) tot <<= 1, ++bit;
for (i = 1; i < tot; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bit - 1));
for (i = 0; i <= lf; ++i) f0[i] = f[i] / M, f1[i] = f[i] % M;
for (i = 0; i <= lg; ++i) g0[i] = g[i] / M, g1[i] = g[i] % M;
dfft(f0, f1), dfft(g0, g1);
for (i = 0; i < tot; ++i)
{
p[i] = f0[i] * g0[i] + I * f1[i] * g0[i];
q[i] = f0[i] * g1[i] + I * f1[i] * g1[i];
}
fft(p, -1), fft(q, -1);
for (i = lf + lg; ~i; --i) f[i] = (M * M * rod(p[i].real()) % P + M * (rod(p[i].imag()) + rod(q[i].real())) % P + rod(q[i].imag())) % P;
}
int f[N << 2], g[N << 2], n, m;
int main()
{
scanf("%d %d %d", &n, &m, &P);
M = std::sqrt(P) + 1;
for (int i = 0; i <= n; ++i) scanf("%d", &f[i]), f[i] %= P;
for (int i = 0; i <= m; ++i) scanf("%d", &g[i]), g[i] %= P;
polymul(f, g, n, m);
for (int i = 0; i <= n + m; ++i) printf("%d ", f[i]);
puts("");
return 0;
}