Dyd's Blog

He who has a strong enough why can bear almost any how.

luoguP8292 [省选联考 2022] 卡牌

巧妙的根号分治

卡牌

思路

看到 $s_i \le 2000$ ,感觉很小,但打个表有 $300$ 多个质数,爆炸

这里有一个很套路的办法:考虑到每个数大于 $\sqrt{s}$ 的质因数最多只有一个,我们可以单独记录那个质因数(不妨叫它大质数),而不超过 $\sqrt{2000}$ 的小质数只有 $14$ 个;那么我们就把每个 $s_i$ 变成了 $\le 14$ 个小质数 + 至多一个大质数

(到了这里好像就可以大力容斥,毕竟 $2000$ 以内每个数只有 $\le 4$ 个不同的质因数,但我容斥辣鸡的一批)

对于每次询问,依次考虑每一个大质数,把所有含有它的 $s_i$ 的小质数压缩成二进制 $k$ ,然后加到 $g[k]$ 中去,对 $g$ 做 FWT ,把所有 $g$ 乘起来即可,注意没有大质数的数也要乘,如果把 $g$ 的 FWT 放在外面预处理,时间为 $O(2^{14} (m + \sum c_i))$ ,我们发现这样常数巨大,不好卡(但好像确实可以卡过)

考虑换一种想法,显然我们只需要考虑所有询问了的质因数,平均一下每次询问也就十来个,不妨每次给 $s_i$ 重新分解,分解的范围也就仅限这十来个,我们给这些质数重新标 $id$ ,每次重新算 $g$ (因为质数不多,所以 FWT 很快),实测飞快

另外,对于过大(乘 $2$ 都超出范围)的质数,可以直接统计答案

代码

实现的时候有些细节:

  1. 每个质因数只保留一次即可(即只需保留 $s_i$ 的无平方因子部分)
  2. 注意真正的贡献是 $2^{g(k)} - 1$(对于没有大质数的数,就是 $2^{g(k)}$),由于 $g$ 出现在指数上,所以 FWT 时要对 $\varphi(P) = P - 1$ 取模
  3. 清空!尤其是清空 vector 时还要单独写函数
完整代码如下:
#include <bits/stdc++.h>
#define pb push_back
// 64-bit integer used for intermediate products before reduction modulo P.
using LL = long long;
// S: bound for arrays indexed by a card value or by a query-prime position.
// P: the prime modulus for the answer.
// V: number of subset masks supported by the transform arrays (2^14).
// C: capacity of the per-query prime buffer p[].
const int S = 2000 + 100, P = 998244353, V = (1 << 14), C = 18000 + 100;
// n       : number of cards while reading input; reused as the (deduplicated) prime count of a query.
// m       : number of queries still to process.
// pri, ct : primes found by the sieve (1-based) and how many there are.
// si[x]   : number of cards whose square-free kernel num[] equals x.
// mxv, B  : largest kernel seen, and floor(sqrt(mxv)) — the small/large prime threshold.
// num[x]  : product of the distinct prime factors of x (filled by prev()).
// vis[x]  : smallest prime factor of x during the sieve; zeroed in main() and reused as scratch flags.
// lg[x]   : floor(log2(x)).
// p[]     : primes of the current query (1-based, sorted, deduplicated).
// id[q]   : 1-based position of prime q in p[] for the current query, 0 if q is not queried.
// f[], g[]: subset-indexed work arrays for the transforms in main().
// id2[j]  : temporary 1-based renumbering of small query primes while handling one large prime.
// bit[s]  : mask s over f's bit positions translated into g's bit positions.
int n, m, pri[S], ct, si[S], mxv = 0, B, num[S], vis[S], lg[V + 100], p[C], id[S], f[V + 100], g[V + 100], id2[S], bit[V + 100];
// dv[x]  : distinct prime factors of kernel x.
// v[i]   : kernels that contain the large query prime p[i].
// dv2[x] : the prime factors of kernel x that appear in the current query.
std::vector<int> dv[S], v[S], dv2[S];
// Empties x and hands its buffer to a temporary that is destroyed immediately,
// so the storage is released rather than merely marked unused.
void clear(std::vector<int> &x)
{
    std::vector<int>().swap(x);
}
// Zeroes the first 2^bit entries of x; entries beyond that are left untouched.
void clear(int x[], int bit)
{
    std::fill_n(x, 1 << bit, 0);
}
int qpow(int x, int y)
{
int res = 1;
for (; y; y >>= 1, x = LL(x) * x % P) if (y & 1) res = LL(res) * x % P;
return res;
}
// Lifts a negative x back into range by adding the modulus _p once;
// non-negative values are left unchanged.
void adj(int &x, int _p = P)
{
    if (x < 0) x += _p;
}
// Precomputes the number-theoretic tables used by main():
//   - a linear sieve over [2, S) that fills pri[]/ct with the primes,
//     vis[x] with the smallest prime factor of x, and num[x] with the
//     product of the distinct prime factors of x (num[1] = 1);
//   - lg[x] = floor(log2(x)) for 1 <= x <= V.
void prev()
{
    num[1] = 1;
    for (int x = 2; x < S; ++x)
    {
        if (vis[x] == 0)
        {
            // x has not been crossed out, so it is prime.
            pri[++ct] = x;
            num[x] = x;
            vis[x] = x;
        }
        for (int k = 1; k <= ct; ++k)
        {
            const int q = pri[k];
            const int prod = q * x;
            if (prod >= S) break;
            vis[prod] = q;
            if (x % q == 0)
            {
                // q already divides x: the set of distinct primes is unchanged,
                // and stopping here keeps every composite visited exactly once.
                num[prod] = num[x];
                break;
            }
            num[prod] = num[x] * q;
        }
    }
    lg[1] = 0;
    for (int x = 2; x <= V; ++x) lg[x] = lg[x >> 1] + 1;
}
// In-place subset-sum (OR) transform on x[0..len): each butterfly adds the
// entry without a given bit into the entry with that bit.
// The sum is kept in range by subtracting P and letting adj() add it back
// when the result goes negative. Intended for len a power of two.
void fwt(int x[], int len)
{
    for (int half = 1; half * 2 <= len; half *= 2)
    {
        const int span = half * 2;
        for (int base = 0; base < len; base += span)
        {
            for (int lo = base; lo < base + half; ++lo)
            {
                int &hi = x[lo + half];
                hi += x[lo] - P;
                adj(hi);
            }
        }
    }
}
// In-place inverse of the subset-sum (OR) transform on x[0..len): each
// butterfly subtracts the entry without a given bit from the entry with it,
// and adj() adds P back when the difference goes negative.
// Intended for len a power of two.
void ifwt(int x[], int len)
{
    for (int half = 1; half * 2 <= len; half *= 2)
    {
        const int span = half * 2;
        for (int base = 0; base < len; base += span)
        {
            for (int lo = base; lo < base + half; ++lo)
            {
                int &hi = x[lo + half];
                hi -= x[lo];
                adj(hi);
            }
        }
    }
}
// Reads the cards and the queries from stdin and prints one answer per query.
//
// Per query, the queried primes are split into "small" (<= B) and "large" (> B).
// Small primes become bit positions of a mask; f[] is built over those masks
// from the cards that contain no large query prime. Each large prime is then
// folded into f[] through its own table g[], and an inverse transform of f[]
// extracts the entry for the full small-prime mask.
int main()
{
    prev();

    // Only the square-free kernel num[t] of each card value matters.
    scanf("%d", &n);
    for (int i = 1, t; i <= n; ++i)
    {
        scanf("%d", &t);
        ++si[num[t]];
        mxv = std::max(num[t], mxv);
    }
    B = static_cast<int>(std::sqrt(mxv));

    // Factor every kernel that occurs, using the smallest-prime-factor table in vis[].
    for (int i = 2; i <= mxv; ++i)
    {
        if (!si[i]) continue;
        for (int j = i, t; j > 1; )
        {
            do
            {
                t = vis[j];
                dv[i].pb(t);
                j /= t;
            } while (!(j % t));
        }
    }

    // From here on vis[] is scratch space indexed by query-prime position.
    memset(vis, 0, sizeof vis);

    scanf("%d", &m);
    while (m--)
    {
        int tot = 0;  // number of small query primes (they occupy positions 1..tot of p[])
        int sol = 1;  // becomes 0 once some large query prime divides no card
        int ans = 1;  // product of the factors contributed by "too large" primes
        int tt2 = 0;  // number of small primes relevant to the current large prime

        scanf("%d", &n);
        for (int i = 1; i <= n; ++i) scanf("%d", &p[i]);
        std::sort(p + 1, p + n + 1);
        n = static_cast<int>(std::unique(p + 1, p + n + 1) - p - 1);
        for (int i = 1; i <= n; ++i)
        {
            id[p[i]] = i;
            clear(v[i]);
            tot += (p[i] <= B);
        }
        clear(f, tot);

        // Classify every kernel by the query primes it contains.
        for (int i = 1; i <= mxv; ++i)
        {
            if (!si[i]) continue;
            int sm = 0;  // mask of small query primes dividing i
            int bg = 0;  // position of the large query prime dividing i, 0 if none
            clear(dv2[i]);
            for (int j : dv[i])
            {
                const int t = id[j];
                if (!t) continue;
                dv2[i].pb(j);
                if (p[t] <= B) sm |= (1 << (t - 1));
                else bg = t;
            }
            // Counts end up as exponents of 2, hence they are kept modulo P - 1.
            if (!bg) adj(f[sm] += si[i] - (P - 1), P - 1);
            else v[bg].pb(i);
        }

        fwt(f, 1 << tot);
        for (int s = 0; s < (1 << tot); ++s) f[s] = qpow(2, f[s]);

        for (int i = 1; i <= n && sol; ++i)
        {
            if (p[i] <= B) continue;

            if (v[i].empty())
            {
                // No card contains this prime: the query cannot be satisfied.
                sol = 0;
                continue;
            }
            if (p[i] * 2 > mxv)
            {
                // The only kernel containing p[i] is p[i] itself; take any non-empty subset of those cards.
                ans = static_cast<int>(LL(qpow(2, si[p[i]]) - 1) * ans % P);
                continue;
            }

            // Mark the query primes that co-occur with p[i] in some kernel.
            for (int j : v[i])
                for (int k : dv2[j]) vis[id[k]] = 1;
            // The marking above also flags p[i]'s own slot, which the scan below never
            // reaches (p[i] * p[i] > mxv). Reset it so it cannot leak into a later query.
            vis[i] = 0;

            // Renumber the marked small primes densely as 1..tt2.
            tt2 = 0;
            for (int j = 1; j <= n && p[i] * p[j] <= mxv; ++j)
            {
                if (!vis[j]) continue;
                vis[j] = 0;
                id2[j] = ++tt2;
            }

            clear(g, tt2);
            for (int j : v[i])
            {
                int sm = 0;
                for (int k : dv2[j])
                {
                    const int t = id2[id[k]];
                    if (t) sm |= (1 << (t - 1));
                }
                adj(g[sm] += si[j] - (P - 1), P - 1);
            }
            fwt(g, 1 << tt2);
            // At least one card containing p[i] must be chosen, hence the "- 1".
            for (int s = 0; s < (1 << tt2); ++s) g[s] = qpow(2, g[s]) - 1;

            // Multiply g into f, translating each f-mask into g's bit numbering.
            f[0] = static_cast<int>(LL(f[0]) * g[0] % P);
            for (int s = 1; s < (1 << tot); ++s)
            {
                const int t = id2[lg[s & -s] + 1];
                bit[s] = bit[s & (s - 1)] | (t ? (1 << (t - 1)) : 0);
                f[s] = static_cast<int>(LL(f[s]) * g[bit[s]] % P);
            }

            for (int j = 1; j <= n && p[i] * p[j] <= mxv; ++j) id2[j] = 0;
        }

        if (sol)
        {
            ifwt(f, 1 << tot);
            // Cast to int: passing a long long to "%d" is undefined behaviour.
            printf("%d\n", static_cast<int>(LL(ans) * f[(1 << tot) - 1] % P));
        }
        else puts("0");

        // Reset the per-query prime index so the next query starts clean.
        for (int i = 1; i <= n; ++i) id[p[i]] = 0;
    }
    return 0;
}