Dyd's Blog

He who has a strong enough why can bear almost any how.

luoguP5824 十二重计数法

组合数学大杂烩

十二重计数法

尝试提高自己没救的数学水平

思路

分讨

I

就是每个球有 $m$ 种选择,所以 $m^n$

II

考虑选 $n$ 个盒子来装,由于球和盒子带标号,要乘一个 $n!$ (没装球的盒子显然不排列),所以 $\binom{m}{n}n!$

III

想不出组合意义了,搞 EGF 太麻烦,考虑容斥,枚举多少个盒子一定是空的,那么就是 $\sum_{i = 0}^{m} (-1)^{i} \binom{m}{i} (m - i)^n$ (我容斥太拉了想半天)

IV

想了半天想不出,瞟一眼题解:第二类斯特林数 · 行(粗口),斯特林数完全超纲好吧,贺一个斯特林数,答案是 $\sum_{i = 0}^m {n \brace i}$

V

就是把球放进去,答案是 $[n \le m]$

VI

就是斯特林数 · 行(啊啊啊又超纲,虽然超到一个地方去了),答案就是 ${n \brace m}$

VII

球一样了,直接插板法, $\binom{n + m - 1}{m - 1}$

VIII

选 $n$ 个盒子来装,就是 $\binom{m}{n}$

IX

由于球相同,先给每个盒子放一个球就转化为了 VII

X

没法了,搞 GF ,先考虑递推,设 $f(n, m)$ 是“ $n$ 个球 $m$ 个盒的方案数”,那么 $f(n, m) = f(n, m - 1) + f(n - m, m)$ 即:要么直接加一个空盒,要么所有盒子球个数一起加一

构造一个 OGF $F_i(x) = \sum_{j = 0} f(j, i) x^i$ ,考虑化简,即带入递推式:
$$
\begin{aligned}
& F_i(x) = F_{i - 1}(x) + x^i F_i(x) \\
& F_i(x) = \frac{F_{i - 1}(x)}{1 - x^i} \\
& F_i(x) = \prod_{j = 1}^i \frac{1}{1 - x^j}
\end{aligned}
$$
这玩意到封闭不封闭的,看到连乘考虑 $\ln$ :
$$
\begin{aligned}
& F_i(x) = \prod_{j = 0}^i \frac{1}{1 - x^j} \\
& \ln(F_i(x)) = \sum_{j = 1}^i -\ln(1 - x^j)
\end{aligned}
$$
现在我们要快速求 $\ln(1 - x^k)$ ,考虑求导:
$$
\begin{aligned}
& \ln(1 - x^k)’ \\
= & \frac{1}{1 - x^k} * (-kx^{k - 1}) \\
= & (-k x^{k - 1}) \sum_{i = 0} x^{ik} \\
= & \sum_{i = 0} -k x^{ik + k - 1} \\
= & \sum_{i = 1} -k x^{ik - 1}
\end{aligned}
$$
现在积分回去得:
$$
\ln(1 - x^k) = \sum_{i = 1} -\frac{1}{i} x^{ik}
$$
发现 $\ln$ 前面其实也有个负号,可以消:
$$
\begin{aligned}
\ln(F_i(x)) = & \sum_{j = 1}^i -\ln(1 - x^j) \\
= & \sum_{j = 1}^i \sum_{k = 1} \frac{1}{k} x^{kj}
\end{aligned}
$$

我们要求的是 $[x^n] F_m(x)$ ,在模 $x^{n + 1}$ 意义下完成多项式构造和多项式 $\exp$ 即可,总时间 $O(n \log n)$

XI

还是 $[n \le m]$

XII

先给每个盒子分配一个球就是 X 了

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#include <bits/stdc++.h>
#define si size()
#define rs(x) resize(x)
using LL = long long;
using poly = std::vector<int>;
const int P = 998244353, N = 2e5 + 100, G = 3, L = 2e5 + 1;
int fac[N << 1], ifac[N << 1], iv[1 << 20 | 100], rev[1 << 20 | 100];
poly string, fm;
int qpow(int x, int y = P - 2)
{
int res = 1;
for (; y; y >>= 1, x = LL(x) * x % P) if (y & 1) res = LL(res) * x % P;
return res;
}
void adj(int &x){ x += (x >> 31) & P; }
void rdy(int &bit, int &tot, int len){ for (bit = 0, tot = 1; tot < len + 1; tot <<= 1, ++bit); }
void get_r(int bit, int tot){ for (int i = 1; i < tot; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1)); }
void ntt(int *x, int tot, int op)
{
static int gn[1 << 20 | 100];
for (int i = 1; i < tot; ++i) if (i < rev[i]) std::swap(x[i], x[rev[i]]);
for (int mid = 1; mid < tot; mid <<= 1)
{
int g0 = qpow(G, (P - 1) / (mid << 1));
if (op == -1) g0 = qpow(g0);
gn[0] = 1;
for (int i = 1; i < mid; ++i) gn[i] = LL(g0) * gn[i - 1] % P;
for (int i = 0; i < tot; i += (mid << 1))
for (int *x1 = x + i, *x2 = x + i + mid, *ed = x2, *g = gn; x1 != ed; ++x1, ++x2, ++g)
{
int p = *x1, q = LL(*x2) * *g % P;
adj(*x1 = p + q - P), adj(*x2 = p - q);
}
}
if (op == 1) return ;
int t = qpow(tot);
for (int i = 0; i < tot; ++i) x[i] = LL(x[i]) * t % P;
}
poly operator * (poly x, poly y)
{
if (x.empty() || y.empty()) return {};
int n = x.si, m = y.si, tot, bit;
rdy(bit, tot, n + m), get_r(bit, tot);
x.rs(tot), y.rs(tot);
if (x != y) ntt(x.data(), tot, 1), ntt(y.data(), tot, 1);
else ntt(x.data(), tot, 1), y = x;
for (int i = 0; i < tot; ++i) x[i] = LL(x[i]) * y[i] % P;
ntt(x.data(), tot, -1);
return x.rs(n + m - 1), x;
}
poly operator - (poly x, poly y)
{
if (x.si < y.si) x.rs(y.si);
for (int i = y.si - 1; ~i; --i) adj(x[i] -= y[i]);
return x;
}
poly inv(poly x)
{
int n = x.si;
if (n == 1) return {qpow(x[0])};
poly y = inv(poly(x.begin(), x.begin() + ((n + 1) >> 1))), z = y * y * x, res(n);
y.rs(n), z.rs(n);
for (int i = 0; i < n; ++i) adj(res[i] = 2 * y[i] - z[i]);
return res;
}
poly dif(poly x)
{
poly res(x.si - 1);
for (int i = x.si - 1; i; --i) res[i - 1] = LL(i) * x[i] % P;
return res;
}
poly inte(poly x)
{
poly res(x.si + 1);
for (int i = x.si - 1; ~i; --i) res[i + 1] = LL(iv[i + 1]) * x[i] % P;
return res;
}
poly ln(poly x)
{
poly res(inv(x) * dif(x));
res.rs(x.si), res = inte(res);
return res.rs(x.si), res;
}
poly exp(poly x)
{
int n = x.si;
if (n == 1) return {1};
poly y = exp(poly(x.begin(), x.begin() + ((n + 1) >> 1)));
y.rs(n);
poly z = ln(y);
x = x - z, ++x[0], x = x * y;
return x.rs(n), x;
}
void prev(int n, int m)
{
int mx = std::max(n, m);
fac[0] = fac[1] = ifac[0] = ifac[1] = 1;
for (int i = 2; i <= (mx << 1); ++i) fac[i] = LL(i) * fac[i - 1] % P;
ifac[mx << 1] = qpow(fac[mx << 1]);
for (int i = (mx << 1) - 1; i > 1; --i) ifac[i] = LL(i + 1) * ifac[i + 1] % P;
iv[1] = 1;
for (int i = 2; i <= mx; ++i) iv[i] = LL(P - P / i) * iv[P % i] % P;
poly f(m + 1), g(m + 1);
for (int i = 0; i <= m; ++i)
{
g[i] = (i & 1) ? P - ifac[i] : ifac[i];
f[i] = qpow(i, n) * LL(ifac[i]) % P;
}
string = f * g;
f.rs(n + 1);
for (int i = 0; i <= n; ++i) f[i] = 0;
for (int i = 1; i <= m; ++i)
for (int j = 1; i * j <= n; ++j) adj(f[i * j] += iv[j] - P);
fm = exp(f);
}
int C(int x, int y)
{
if (x < y) return 0;
return LL(fac[x]) * ifac[y] % P * ifac[x - y] % P;
}
int calc1(int n, int m){ return qpow(m, n); }
int calc2(int n, int m)
{
if (n > m) return 0;
return LL(fac[n]) * C(m, n) % P;
}
int calc3(int n, int m)
{
if (n < m) return 0;
int res = 0;
for (int i = 0; i <= m; ++i)
{
int t = LL(qpow(m - i, n)) * C(m, i) % P;
if (i & 1) adj(res -= t);
else adj(res += t - P);
}
return res;
}
int calc4(int n, int m)
{
int res = 0;
for (int i = 0; i <= m; ++i) adj(res += string[i] - P);
return res;
}
int calc5(int n, int m){ return n > m ? 0 : 1; }
int calc6(int n, int m){ return string[m]; }
int calc7(int n, int m){ return C(n + m - 1, m - 1); }
int calc8(int n, int m){ return C(m, n); }
int calc9(int n, int m){ return n < m ? 0 : calc7(n - m, m); }
int calc10(int n, int m){ return fm[n]; }
int calc11(int n, int m){ return n > m ? 0 : 1; }
int calc12(int n, int m){ return n < m ? 0 : calc10(n - m, m); }
int main()
{

int n, m;
scanf("%d %d", &n, &m);
prev(n, m);
printf("%d\n", calc1(n, m));
printf("%d\n", calc2(n, m));
printf("%d\n", calc3(n, m));
printf("%d\n", calc4(n, m));
printf("%d\n", calc5(n, m));
printf("%d\n", calc6(n, m));
printf("%d\n", calc7(n, m));
printf("%d\n", calc8(n, m));
printf("%d\n", calc9(n, m));
printf("%d\n", calc10(n, m));
printf("%d\n", calc11(n, m));
printf("%d\n", calc12(n, m));
return 0;
}