Dyd's Blog

He who has a strong enough why can bear almost any how.

luoguP6730 [WC2020] 猜数游戏

bsgs + gcd 建图,缩点后统计

猜数游戏

思路

考虑建出有向图:若 $a_j$ 是 $a_i$ 的某个幂(即已知 $a_i$ 就能确定 $a_j$),则连边 $i \to j$。每次只考虑询问的点集,发现选了一个点后,它可达的所有点都能确定,因此缩点后,每个无入度的 SCC 中各选一个点即可。

但枚举询问点集的复杂度过高。按套路拆贡献:计算每个 SCC 会被选多少次。预处理出能到达该 SCC 的点的个数 $c$(由于“是某个幂”这一关系具有传递性,能到达它的点恰好就是向它直接连边的点),这些点都不能选;而该 SCC 中至少要选一个点,其它点任意。设该 SCC 大小为 $s$,其贡献为 $2^{n-c-s}(2^s-1)$,直接计算即可。

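现在考虑如何建图。设 $p = q^k$。对于与 $p$ 互质的 $a_i$,先计算出 $p$ 的原根 $g$,用 BSGS 把 $a_i$ 写成 $g^{k_i}$($O(n \sqrt{p} \cdot t)$),则 $i \to j$ 当且仅当 $k_i x \equiv k_j \pmod{\varphi(p)}$ 有解,即 $\gcd(k_i, \varphi(p)) \mid k_j$($O(n^2 \log \varphi(p))$)。对于 $q \mid a_i$ 的 $a_i$,它们和与 $p$ 互质的数之间不会有边,直接从 $a_i^2$ 开始枚举各次幂直到变成 $0$,用哈希表查找对应的点连边即可。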

代码

卡常大赛

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/hash_policy.hpp>
// Shorthand used throughout the file.
#define eb emplace_back
#define IL inline
using LL = long long;
// N: array capacity for n (with slack); P: modulus of the printed answer;
// inf: bsgs returns -inf when it finds no logarithm.
const int N = 5000 + 100, P = 998244353, inf = 0x3f3f3f3f;
// n, p: read from input; a[i]: input values (entries not divisible by q are
// overwritten in main with their discrete log base g); g: primitive root of p
// found in prev(); phi: Euler's totient of p; q: smallest divisor >= 3 of p
// (presumably the prime with p = q^k); pw[i] = 2^i mod P; ans: accumulated answer.
int n, p, a[N], g, phi, q, pw[N], ans;
// Forward-star list heads and edge counters: h/idx for the original graph (e),
// hv/idv for the reversed, deduplicated graph between SCCs (ev) built in main.
int h[N], idx, hv[N], idv;
// dfn: Tarjan discovery time (0 = unvisited); cnt: number of SCCs found;
// c[v]: SCC id of vertex v; scc[s]: number of vertices in SCC s.
int dfn[N], cnt, c[N], scc[N];
// Edge pools; ne = next edge in the same list (-1 terminates), ver = target vertex.
// NOTE(review): each pool is N * N * sizeof(Edge) bytes of static storage
// (about 208 MB apiece with 4-byte int), sized for a fully dense graph.
struct Edge{ int ne, ver; } e[N * N], ev[N * N];
// Pushes the directed edge x -> y onto the front of x's list in the original graph (e/h).
IL void add(int x, int y)
{
    e[idx].ne = h[x];
    e[idx].ver = y;
    h[x] = idx++;
}
// Pushes the directed edge x -> y onto the front of x's list in the SCC graph (ev/hv).
IL void adv(int x, int y)
{
    ev[idv].ne = hv[x];
    ev[idv].ver = y;
    hv[x] = idv++;
}
// Adds mod to x once if x is negative; non-negative x is left unchanged.
// Used to bring a value from [-mod, 0) back into [0, mod).
IL void adj(int &x, int mod)
{
    if (x < 0) x += mod;
}
// Returns x^y mod `mod` by square-and-multiply over the bits of y.
IL int qpow(int x, int y, int mod)
{
    int result = 1;
    while (y != 0)
    {
        if (y & 1) result = LL(result) * x % mod;
        x = LL(x) * x % mod;
        y >>= 1;
    }
    return result;
}
// One-time precomputation from the input p (assumed to be an odd prime power q^k):
//   q     - smallest divisor of p that is >= 3 (the prime q when p = q^k),
//   pw[i] - 2^i mod P for 0 <= i <= n,
//   phi   - Euler's totient of p,
//   g     - smallest primitive root of p.
IL void prev()
{
    for (q = 3; p % q; ++q);
    pw[0] = 1;
    // Doubling mod P: subtract P, then adj() adds it back if the result went negative.
    for (int i = 1; i <= n; ++i) adj(pw[i] = (pw[i - 1] << 1) - P, P);
    std::vector<int> pri;
    // Totient of p by trial division (g is used as scratch space here).
    g = phi = p;
    for (int i = 2; LL(i) * i <= g; ++i) if (g % i == 0)
    {
        phi = phi / i * (i - 1);
        do g /= i;
        while (g % i == 0);
    }
    if (g ^ 1) phi = phi / g * (g - 1);
    // Distinct prime factors of phi, needed for the primitive-root test.
    g = phi;
    for (int i = 2; LL(i) * i <= g; ++i) if (g % i == 0)
    {
        pri.eb(i);
        do g /= i;
        while (g % i == 0);
    }
    if (g ^ 1) pri.eb(g);
    // g is a primitive root iff it is coprime to p and g^(phi/r) != 1 for every prime r | phi.
    // Multiples of q are skipped explicitly. Since q divides p, every power of such a g
    // is divisible by q mod p and therefore never equals 1, so the order test alone would
    // wrongly accept it.
    for (g = 2; ; ++g)
    {
        if (g % q == 0) continue;
        bool primitive = true;
        for (int r : pri)
            if (qpow(g, phi / r, p) == 1) { primitive = false; break; }
        if (primitive) break;
    }
}
// Extended Euclid: returns d = gcd(a, b) and sets x, y such that a*x + b*y == d.
IL int exgcd(int a, int b, int &x, int &y)
{
    if (b == 0)
    {
        x = 1;
        y = 0;
        return a;
    }
    // Solve the (b, a mod b) subproblem with x and y swapped, then correct y.
    const int common = exgcd(b, a % b, y, x);
    y -= a / b * x;
    return common;
}
// Greatest common divisor by the Euclidean algorithm, written as a loop.
IL int gcd(int a, int b)
{
    while (b != 0)
    {
        const int rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// Baby-step giant-step: returns some e with a^e == b (mod p), or -inf if none is found.
// Requires gcd(b, p) == 1, because b is inverted modulo p below.
// The exponent is not reduced: it lies in [1, k*k] with k = floor(sqrt(p)) + 1,
// or is 0 when b == 1 or p == 1.
IL int bsgs(int a, int b, int p)
{
    if (b == 1 || p == 1) return 0;
    // The baby-step table {a^i mod p -> i : 0 <= i < k} is cached between calls.
    // It is rebuilt whenever (a, p) differs from the pair it was built for, so a call
    // with another base or modulus can no longer silently reuse a stale table.
    static int k, tableA = -1, tableP = -1;
    static __gnu_pbds::gp_hash_table<int, int> ha;
    if (a != tableA || p != tableP)
    {
        ha.clear();
        tableA = a, tableP = p;
        k = std::sqrt(p) + 1;
        for (int i = 0, j = 1; i < k; j = LL(j) * a % p, ++i) ha[j] = i;
    }
    int ak = qpow(a, k, p), x, y;
    // x = b^{-1} mod p, shifted into [0, p).
    exgcd(b, p, x, y), adj(x, p);
    // Giant steps: find a^(i*k) * b^{-1} == a^r, which means a^(i*k - r) == b.
    for (int i = 1, j = LL(ak) * x % p; i <= k; j = LL(j) * ak % p, ++i)
    {
        auto it = ha.find(j);
        if (it != ha.end()) return i * k - it->second;
    }
    return -inf;
}
// Recursive Tarjan SCC from vertex x over the original graph (h/e).
// Assigns c[v] = SCC id (1..cnt) and counts each SCC's size in scc[].
IL void tar(int x)
{
    static int low[N], stamp, stk[N], top;
    static bool onStack[N];
    ++stamp;
    dfn[x] = stamp;
    low[x] = stamp;
    stk[++top] = x;
    onStack[x] = true;
    for (int i = h[x]; i != -1; i = e[i].ne)
    {
        const int y = e[i].ver;
        if (!dfn[y])
        {
            tar(y);
            if (low[y] < low[x]) low[x] = low[y];
        }
        else if (onStack[y] && dfn[y] < low[x])
        {
            low[x] = dfn[y];
        }
    }
    // x is the root of an SCC only when no back/cross edge reached above it.
    if (low[x] != dfn[x]) return;
    ++cnt;
    int y;
    do
    {
        y = stk[top--];
        onStack[y] = false;
        c[y] = cnt;
        ++scc[cnt];
    } while (y != x);
}
// Reads n, p and a[1..n], then builds the digraph where i -> j means "a_j is a power of a_i".
// It condenses that graph into SCCs and prints, modulo P, the sum over all subsets of the
// number of elements that must be revealed.
int main()
{
    // Empty adjacency lists are marked by -1 (tested with `~i` when iterating).
    std::memset(h, -1, sizeof h), idx = 0, std::memset(hv, -1, sizeof hv), idv = 0;
    scanf("%d %d", &n, &p), prev();
    {
        // ha: raw value -> index, only for values divisible by q.
        static __gnu_pbds::gp_hash_table<int, int> ha;
        // is[i]: a[i] is divisible by q and therefore keeps its raw value.
        static bool is[N];
        for (int i = 1; i <= n; ++i)
        {
            scanf("%d", a + i);
            // Values coprime to q are replaced by their discrete logarithm base g.
            if (a[i] % q) a[i] = bsgs(g, a[i], p);
            else ha[a[i]] = i, is[i] = true;
        }
        // Multiples of q: walk a[i]^2, a[i]^3, ... (mod p) until the power vanishes,
        // and link i to every power that is also among the inputs.
        for (int i = 1; i <= n; ++i) if (is[i])
            for (int cur = LL(a[i]) * a[i] % p; cur; cur = LL(cur) * a[i] % p)
            {
                auto it = ha.find(cur);
                if (it != ha.end()) add(i, it->second);
            }
        // Units: g^{k_j} is a power of g^{k_i} iff gcd(k_i, phi) divides k_j.
        // Multiples of q are skipped up front instead of being re-tested for every j.
        for (int i = 1; i <= n; ++i)
        {
            if (is[i]) continue;
            const int t = gcd(a[i], phi);
            for (int j = 1; j <= n; ++j) if (!is[j] && i != j && a[j] % t == 0) add(i, j);
        }
    }
    for (int i = 1; i <= n; ++i) if (!dfn[i]) tar(i);
    {
        // vis[u][v]: reversed SCC edge u -> v has already been added.
        // bool rather than int: the N x N table needs a quarter of the memory.
        static bool vis[N][N];
        for (int x = 1; x <= n; ++x)
            for (int i = h[x], y; ~i; i = e[i].ne)
                if (c[y = e[i].ver] != c[x] && !vis[c[y]][c[x]]) vis[c[y]][c[x]] = true, adv(c[y], c[x]);
    }
    // SCC x costs one query exactly when some of its own vertices are chosen and none of
    // the vertices with an edge into it are. The "is a power of" relation is transitive,
    // so the edges built above already connect every ancestor directly to x.
    for (int x = 1, ban; x <= cnt; ++x)
    {
        ban = scc[x];
        for (int i = hv[x]; ~i; i = ev[i].ne) ban += scc[ev[i].ver];
        adj(ans += pw[n - ban] * LL(pw[scc[x]] - 1) % P - P, P);
    }
    printf("%d\n", ans);
    return 0;
}