Dyd's Blog

He who has a strong enough why can bear almost any how.

luoguP7735 [NOI2021] 轻重边

第一步的转化非常巧妙

轻重边

思路

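最开始想用 LCT，但发现"把与路径相邻的边变回轻边"这个操作不好维护。这里有一个巧妙的转化：把每次修改当成一次染色，每次都染一种全新的颜色；一条边是重边，当且仅当它的两个端点颜色相同（初始时所有点都视为互不相同的颜色——代码里用 0 表示"未染色"，两个 0 不算同色）。这样，修改操作就是把路径上的点染成新颜色：路径上的边两端同色，变成重边；与路径相邻的边一端是新颜色、另一端不是，自然变成轻边。询问就是统计路径上相邻同色点对的个数，用树链剖分 + 线段树维护区间两端颜色和区间内相邻同色对数即可。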

代码

debug 时踩到的几个坑：

  1. 一段链上的边比点少一个，所以整段染色后 ct = len - 1
  2. 多组数据，每组都要重新初始化 num（以及 idx、h）
  3. up 时除了合并 ct，还要记得更新左右端点颜色 lcol / rcol

最后，ask 里我是先求出 lca，再让两端分别往上跳到 lca 所在的重链；其实也可以不求 lca，只要在交换 $u, v$ 的同时也交换 $lastu, lastv$ 即可。

#include <iostream>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <array>
#include <functional>
// Upper bound on vertex count (with slack).
const int N = 100000 + 100;
int n, m, idx, num;
// Per-vertex HLD data: adjacency head, position in HLD order, depth,
// subtree size, chain head, parent, heavy child.
int h[N], id[N], dep[N], si[N], top[N], fa[N], son[N];
// One directed half of an undirected tree edge, stored in a singly linked
// adjacency list threaded through e[] (h[x] is the first edge out of x).
struct Edge
{
int ver; // endpoint this edge leads to
int ne;  // next edge out of the same vertex; lists are terminated by -1 (h is memset to -1 before use)
};
Edge e[N << 1];
// Prepend the directed edge x -> y to x's adjacency list.
void add(int x, int y)
{
e[idx].ver = y;
e[idx].ne = h[x];
h[x] = idx;
++idx;
}
namespace LT
{
// Segment tree over the HLD order 1..num. Each leaf holds the colour of one
// vertex (0 = never painted). For a node's range: cnt is the number of
// adjacent positions whose colours are equal and non-zero, lcol/rcol are the
// colours at its left/right ends, len is its length, and tag is a pending
// "paint the whole range" colour (0 = nothing pending).
struct Node{ int tag, lcol, rcol, cnt, len; } tr[N << 2];

// Returns 1 when a range ending in colour a, immediately followed by a range
// starting in colour b, forms one more equal non-zero pair across the seam.
int seam(int a, int b)
{
return (a != 0 && a == b) ? 1 : 0;
}

// Build node u covering [l, r]: all unpainted, no pending tag.
void bd(int u, int l, int r)
{
Node &nd = tr[u];
nd.tag = 0;
nd.lcol = 0;
nd.rcol = 0;
nd.cnt = 0;
nd.len = r - l + 1;
if (l >= r) return;
const int m = (l + r) >> 1;
bd(u * 2, l, m);
bd(u * 2 + 1, m + 1, r);
}

// Paint node u's whole range with colour d: every adjacent pair now matches.
void adt(int u, int d)
{
Node &nd = tr[u];
nd.tag = nd.lcol = nd.rcol = d;
nd.cnt = nd.len - 1;
}

// Push a pending paint down to both children.
void dw(int u)
{
const int d = tr[u].tag;
if (d == 0) return;
adt(u * 2, d);
adt(u * 2 + 1, d);
tr[u].tag = 0;
}

// Recompute node u from its children.
void up(int u)
{
const Node &L = tr[u * 2];
const Node &R = tr[u * 2 + 1];
Node &P = tr[u];
P.cnt = L.cnt + R.cnt + seam(L.rcol, R.lcol);
P.lcol = L.lcol;
P.rcol = R.rcol;
}

// Paint positions [ql, qr] with colour d.
void mdf(int ql, int qr, int d, int u = 1, int l = 1, int r = num)
{
if (ql <= l && r <= qr)
{
adt(u, d);
return;
}
dw(u);
const int m = (l + r) >> 1;
if (ql <= m) mdf(ql, qr, d, u * 2, l, m);
if (m < qr) mdf(ql, qr, d, u * 2 + 1, m + 1, r);
up(u);
}

// Query [ql, qr]; returns {equal-pair count, colour at ql, colour at qr}.
std::array<int, 3> ask(int ql, int qr, int u = 1, int l = 1, int r = num)
{
if (ql <= l && r <= qr) return {tr[u].cnt, tr[u].lcol, tr[u].rcol};
dw(u);
const int m = (l + r) >> 1;
if (ql > m) return ask(ql, qr, u * 2 + 1, m + 1, r);
if (qr <= m) return ask(ql, qr, u * 2, l, m);
const std::array<int, 3> a = ask(ql, qr, u * 2, l, m);
const std::array<int, 3> b = ask(ql, qr, u * 2 + 1, m + 1, r);
return {a[0] + b[0] + seam(a[2], b[1]), a[1], b[2]};
}
}
// First HLD pass over the subtree of x (parent f, depth d): fills dep, fa, si
// and picks the heavy child son[x] — the child with the largest subtree, the
// first one in adjacency order on ties. son[x] == 0 means x is a leaf.
void dfs1(int x, int f, int d)
{
dep[x] = d;
fa[x] = f;
si[x] = 1;
son[x] = 0;
for (int i = h[x]; i != -1; i = e[i].ne)
{
const int y = e[i].ver;
if (y == f) continue;
dfs1(y, x, d + 1);
si[x] += si[y];
if (si[y] > si[son[x]]) son[x] = y;
}
}
// Second HLD pass: assigns consecutive positions id[] visiting the heavy child
// first, so each heavy chain occupies a contiguous id range; top[] is the
// head of the chain containing the vertex (t for x's chain).
void dfs2(int x, int t)
{
top[x] = t;
id[x] = ++num;
const int heavy = son[x];
if (heavy == 0) return;
dfs2(heavy, t);
for (int i = h[x]; i != -1; i = e[i].ne)
{
const int y = e[i].ver;
if (y != fa[x] && y != heavy) dfs2(y, y);
}
}
// Lowest common ancestor via heavy chains: repeatedly lift whichever endpoint
// sits on the chain with the deeper head, then return the shallower of the two
// once they share a chain.
int lca(int u, int v)
{
for (; top[u] != top[v]; )
{
int &deeper = (dep[top[u]] >= dep[top[v]]) ? u : v;
deeper = fa[top[deeper]];
}
return (dep[u] < dep[v]) ? u : v;
}
// Paint every vertex on the tree path u..v with colour d, one heavy-chain
// segment at a time.
void mdf(int u, int v, int d)
{
for (; top[u] != top[v]; u = fa[top[u]])
{
if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
LT::mdf(id[top[u]], id[u], d);
}
// Same chain now: the shallower vertex has the smaller id.
const int lo = std::min(id[u], id[v]);
const int hi = std::max(id[u], id[v]);
LT::mdf(lo, hi, d);
}
int ask(int u, int v)
{
int res = 0, lastu = 0, lastv = 0, w = lca(u, v);
std::function<void(int&, int&)> jump = [&](int &x, int &last)
{
while (top[x] != top[w])
{
auto t = LT::ask(id[top[x]], id[x]);
res += t[0] + (t[2] == last && last != 0);
last = t[1];
x = fa[top[x]];
}
} ;
jump(u, lastu), jump(v, lastv);
if (dep[u] < dep[v]) std::swap(u, v), std::swap(lastu, lastv);
auto t = LT::ask(id[v], id[u]);
res += t[0] + (t[1] == lastv && lastv != 0) + (t[2] == lastu && lastu != 0);
return res;
}
// Reads T test cases. Each case has a tree of n vertices (n - 1 edges), then
// m operations: "1 a b" paints the path a..b with a fresh colour,
// "2 a b" prints the number of heavy edges on the path a..b.
int main()
{
int T;
scanf("%d", &T);
while (T--)
{
scanf("%d %d", &n, &m);
// Reset per-case state: adjacency lists, edge counter, HLD position counter.
memset(h, -1, sizeof h);
idx = 0;
num = 0;
for (int i = 1; i < n; ++i)
{
int u, v;
scanf("%d %d", &u, &v);
add(u, v);
add(v, u);
}
dfs1(1, 0, 1);
dfs2(1, 1);
LT::bd(1, 1, num);
int colour = 0; // every painting uses a brand-new colour 1, 2, 3, ...
while (m--)
{
int op, a, b;
scanf("%d %d %d", &op, &a, &b);
if (op == 1) mdf(a, b, ++colour);
else printf("%d\n", ask(a, b));
}
}
return 0;
}