Dyd's Blog

He who has a strong enough why can bear almost any how.

点分治

连分治也变得难了起来……

点分治

定义

点分治是树上分治的种常用方法,主要思想是每次在树上选一个点,将整棵树的问题划分为两类(如图):三角形的子树内问题和过了点跨子树的问题

点和子树

然后每个子树也这样划分,这个点每次取重心,可以保证最多划分 $\log n$ 层

模板

题意非常简单:求树上距离不超过 $k$ 的点对数量,点分治的思路也非常简单,每次选重心(记选的节点为 $c$ ),分三类:

  1. 对于两个点都在同一子树内部的情况,递归处理
  2. 对于有一个点恰好是 $c$ 的情况,直接dfs求
  3. 对于跨子树的情况,可以先求出每棵子树内每个点到 $c$ 的距离,然后对于所有距离,记录任选两个距离和小于等于 $k$ 的情况,再删掉同一棵子树内两个点距离和小于等于 $k$ 的情况即可,而求解“一个集合内任取两个数和小于等于 $k$ 的方案数”可以用排序后双指针来解决(也可以排序后二分,麻烦点)

考虑时间复杂度,最多有 $\log n$ 层,每层 $n$ 个点都要排序,一共是 $O(n \log^2 n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include <bits/stdc++.h>
using namespace std;
const int N = 1e4 + 5;
int n, k;
struct Edge
{
int ne, ver, w;
} e[N << 1];
int h[N], idx;
bool del[N];
int p[N], q[N], cp, cq;
void add(int x, int y, int z)
{
e[idx] = (Edge){h[x], y, z}, h[x] = idx++;
}
int get_si(int x, int fa)
{
if (del[x])
return 0;
int res = 1;
for (int i = h[x]; i != -1; i = e[i].ne)
if (e[i].ver != fa)
res += get_si(e[i].ver, x);
return res;
}
int get_wc(int x, int fa, int si, int &wc) // 求重心(其实是一个保证删去后子树大小小于n/2的点,不一定是重心)
{
if (del[x])
return 0;
int sum = 1, mx = 0;
for (int i = h[x], y, t; i != -1; i = e[i].ne)
{
y = e[i].ver;
if (y == fa)
continue;
t = get_wc(y, x, si, wc);
mx = max(mx, t);
sum += t;
}
mx = max(mx, si - sum);
if (mx <= si / 2)
wc = x;
return sum;
}
void get_dis(int x, int fa, int dis)
{
if (del[x])
return;
q[++cq] = dis;
for (int i = h[x]; i != -1; i = e[i].ne)
if (e[i].ver != fa)
get_dis(e[i].ver, x, dis + e[i].w);
}
int work(int a[], int c) //计算集合a中有多少对相加不大于k
{
sort(a + 1, a + 1 + c);
int res = 0;
for (int i = c, j = 0; i >= 1; --i)
{
while (j + 1 < i && a[j + 1] + a[i] <= k)
++j;
j = min(j, i - 1);
res += j + 1;
}
return res;
}
int calc(int x)
{
if (del[x])
return 0;
int res = 0;
get_wc(x, -1, get_si(x, -1), x);
del[x] = true;
cp = 0;
for (int i = h[x], y; i != -1; i = e[i].ne)
{
y = e[i].ver;
cq = 0;
get_dis(y, -1, e[i].w);
res -= work(q, cq);
for (int j = 1; j <= cq; ++j)
{
if (q[j] <= k)
++res;
p[++cp] = q[j];
}
}
res += work(p, cp);
for (int i = h[x]; i != -1; i = e[i].ne)
res += calc(e[i].ver);
return res;
}
int main()
{
while (scanf("%d%d", &n, &k), n || k)
{

for (int i = 1; i <= n; ++i)
h[i] = -1, del[i] = false;;
idx = 0;
for (int i = 1, u, v, w; i < n; ++i)
{
scanf("%d%d%d", &u, &v, &w);
add(u + 1, v + 1, w), add(v + 1, u + 1, w); //输入的下标是从0开始的
}
printf("%d\n", calc(1));
}
return 0;
}

例题

权值

类似于模板,看注释吧,时间复杂度 $O(n \log n)$ :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 5, S = 1e6 + 5, INF = 0x3f3f3f3f;
int n, k, ans = INF;
struct Edge
{
int ne, ver, w;
} e[N << 1];
struct Node
{
int dis, num;
} p[N], q[N];
int h[N], idx;
bool del[N];
int cp, cq;
int b[S]; //开一个桶记录到重心距离为i的点的最小边数
void add(int x, int y, int z)
{
e[idx] = (Edge){h[x], y, z}, h[x] = idx++;
}
int get_si(int x, int fa)
{
if (del[x])
return 0;
int res = 1;
for (int i = h[x]; i != -1; i = e[i].ne)
if (e[i].ver != fa)
res += get_si(e[i].ver, x);
return res;
}
int get_wc(int x, int fa, int si, int &wc)
{
if (del[x])
return 0;
int sum = 1, mx = 0;
for (int i = h[x], y, t; i != -1; i = e[i].ne)
{
y = e[i].ver;
if (y == fa)
continue;
t = get_wc(y, x, si, wc);
mx = max(mx, t);
sum += t;
}
mx = max(mx, si - sum);
if (mx <= si / 2)
wc = x;
return sum;
}
void get_dis(int x, int fa, int dis, int num)
{
if (del[x] || dis > k)
return;
q[++cq] = (Node){dis, num};
for (int i = h[x]; i != -1; i = e[i].ne)
if (e[i].ver != fa)
get_dis(e[i].ver, x, dis + e[i].w, num + 1);
}
void calc(int x)
{
if (del[x])
return ;
get_wc(x, -1, get_si(x, -1), x);
del[x] = true;
cp = 0;
for (int i = h[x], y; i != -1; i = e[i].ne)
{
y = e[i].ver;
cq = 0;
get_dis(y, x, e[i].w, 1);
for (int j = 1; j <= cq; ++j)
{
if (q[j].dis == k)
ans = min(ans, q[j].num);
ans = min(ans, b[k - q[j].dis] + q[j].num);
p[++cp] = q[j];
}
for (int j = 1; j <= cq; ++j)
b[q[j].dis] = min(b[q[j].dis], q[j].num);
}
for (int j = 1; j <= cp; ++j) //将桶清空
b[p[j].dis] = INF;
for (int i = h[x]; i != -1; i = e[i].ne)
calc(e[i].ver);
return ;
}
int main()
{
scanf("%d%d", &n, &k);
for (int i = 1; i <= n; ++i)
h[i] = -1;
idx = 0;
for (int i = 0; i <= k; ++i)
b[i] = INF;
for (int i = 1, u, v, w; i < n; ++i)
{
scanf("%d%d%d", &u, &v, &w);
add(u + 1, v + 1, w), add(v + 1, u + 1, w);
}
calc(1);
if (ans == INF)
ans = -1;
printf("%d\n", ans);
return 0;
}

【模板】点分治1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#include <bits/stdc++.h>
using namespace std;
const int N = 1e4 + 5, K = 1e7 + 5, M = 100 + 5;
int n, m;
struct Edge
{
int ne, ver, w;
} e[N << 1];
int h[N], idx;
bool del[N];
int p[N], q[N], o[N], cp, cq, co;
list<int> Q;
int _Q[M];
bool b[K], ans[K];
void add(int x, int y, int z)
{
e[idx] = (Edge){h[x], y, z}, h[x] = idx++;
}
int get_si(int x, int fa)
{
if (del[x])
return 0;
int res = 1;
for (int i = h[x]; i != -1; i = e[i].ne)
if (e[i].ver != fa)
res += get_si(e[i].ver, x);
return res;
}
int get_wc(int x, int fa, int si, int &wc)
{
if (del[x])
return 0;
int sum = 1, mx = 0;
for (int i = h[x], y, t; i != -1; i = e[i].ne)
{
y = e[i].ver;
if (y == fa)
continue;
t = get_wc(y, x, si, wc);
mx = max(mx, t);
sum += t;
}
mx = max(mx, si - sum);
if (mx <= si / 2)
wc = x;
return sum;
}
void get_dis(int x, int fa, int dis)
{
if (del[x] || dis > K)
return;
q[++cq] = dis;
for (int i = h[x], y; i != -1; i = e[i].ne)
if (e[i].ver != fa)
get_dis(e[i].ver, x, dis + e[i].w);
}
void calc(int x)
{
if (del[x])
return;
get_wc(x, -1, get_si(x, -1), x);
del[x] = true;
cp = 0;
for (int i = h[x], y; i != -1; i = e[i].ne)
{
y = e[i].ver;
cq = co = 0;
get_dis(y, -1, e[i].w);
for (int j = 1; j <= cq; ++j)
{
for (int r : Q)
{
if (q[j] == r)
ans[r] = true, o[++co] = r;
if (r >= q[j] && b[r - q[j]])
ans[r] = true, o[++co] = r;
}
p[++cp] = q[j];
}
for (int j = 1; j <= cq; ++j)
b[q[j]] = true;
for (int j = 1; j <= co; ++j)
Q.remove(o[j]);
}
for (int j = 1; j <= cp; ++j)
b[p[j]] = false;
for (int i = h[x]; i != -1; i = e[i].ne)
calc(e[i].ver);
return;
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i)
h[i] = -1;
idx = 0;
for (int i = 1, u, v, w; i < n; ++i)
{
scanf("%d%d%d", &u, &v, &w);
add(u, v, w), add(v, u, w);
}
for (int i = 1; i <= m; ++i)
{
scanf("%d", &_Q[i]);
Q.push_back(_Q[i]);
}
calc(1);
for (int i = 1; i <= m; ++i)
if (ans[_Q[i]])
printf("AYE\n");
else
printf("NAY\n");
return 0;
}

动态点分治(点分树)

还是先看模板题[HNOI2015]开店

既然叫点分树了,当然是要建一棵树,而这棵树要保证每一棵子树的根节点就是该子树的重心

建好树后,考虑如何解决询问,不难发现,一个节点 $u$ 最多属于 $\log n$ 棵子树,不妨设当前子树根节点为 $r$ :

  1. 若 $r \ne u$ ,考虑形如 $u \rightarrow r \rightarrow v$ 的路径有多少个,计入答案,然后进入 $u$ 所在子树递归
  2. 若 $u = r$ ,遍历当前子树所有点,计入答案,然后停止递归

由于每个点的度不大于3,直接在每个重心上开三个vector,记录每个子树的所有年龄和它到重心的距离,排好序后前缀和+二分即可,总的空间复杂度为 $O(n \log n)$ 时间复杂度为 $O(m \log^2 n)$ , $m$ 是询问的数量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#include <bits/stdc++.h>
#define LL long long
#define VS vector<Son>
using namespace std;
const int N = 1.5e5 + 5;
int n, m, A;
int h[N], idx;
struct Edge
{
int ne, ver, w;
} e[N << 1];
struct Father
{
int x, id;
LL dis;
};
struct Son
{
int age;
LL dis;
bool operator<(const Son &t) const
{
return age < t.age;
}
};
vector<Father> f[N];
VS s[N][3];
bool del[N];
int age[N];
void add(int x, int y, int z)
{
e[idx] = (Edge){h[x], y, z}, h[x] = idx++;
}
int get_si(int x, int fa)
{
if (del[x])
return 0;
int res = 1;
for (int i = h[x]; i != -1; i = e[i].ne)
if (e[i].ver != fa)
res += get_si(e[i].ver, x);
return res;
}
int get_wc(int x, int fa, int si, int &wc)
{
if (del[x])
return 0;
int sum = 1, mx = 0;
for (int i = h[x], t; i != -1; i = e[i].ne)
{
if (e[i].ver == fa)
continue;
t = get_wc(e[i].ver, x, si, wc);
mx = max(mx, t);
sum += t;
}
mx = max(mx, si - sum);
if (mx <= si / 2)
wc = x;
return sum;
}
void get_dis(int x, int fa, LL dis, int wc, int k, VS &p)
{
if (del[x])
return;
f[x].push_back((Father){wc, k, dis});
p.push_back((Son){age[x], dis});
for (int i = h[x], t; i != -1; i = e[i].ne)
if (e[i].ver != fa)
get_dis(e[i].ver, x, dis + e[i].w, wc, k, p);
}
void calc(int x)
{
if (del[x])
return;
get_wc(x, -1, get_si(x, -1), x);
del[x] = true;
for (int i = h[x], y, k = 0; i != -1; i = e[i].ne)
{
y = e[i].ver;
if (del[y])
continue;
VS &p = s[x][k];
p.push_back((Son){-1, 0}), p.push_back((Son){A + 1, 0}); //哨兵
get_dis(y, -1, e[i].w, x, k, p);
sort(p.begin(), p.end());
for (int i = 1; i < p.size(); ++i)
p[i].dis += p[i - 1].dis;
++k;
}
for (int i = h[x]; i != -1; i = e[i].ne)
calc(e[i].ver);
}
LL ask(int x, int l, int r)
{
LL res = 0;
for (Father &i : f[x])
{
int g = age[i.x];
if (g >= l && g <= r)
res += i.dis;
for (int j = 0; j < 3; ++j)
{
if (j == i.id)
continue;
VS &p = s[i.x][j];
if (p.empty())
continue;
int a = lower_bound(p.begin(), p.end(), (Son){l, -1}) - p.begin();
int b = lower_bound(p.begin(), p.end(), (Son){r + 1, -1}) - p.begin();
res += i.dis * (b - a) + p[b - 1].dis - p[a - 1].dis;
}
}
for (int i = 0; i < 3; ++i)
{
VS &p = s[x][i];
if (p.empty())
continue;
int a = lower_bound(p.begin(), p.end(), (Son){l, -1}) - p.begin();
int b = lower_bound(p.begin(), p.end(), (Son){r + 1, -1}) - p.begin();
res += p[b - 1].dis - p[a - 1].dis;
}
return res;
}
int main()
{
scanf("%d%d%d", &n, &m, &A);
for (int i = 1; i <= n; ++i)
h[i] = -1;
idx = 0;
for (int i = 1; i <= n; ++i)
scanf("%d", &age[i]);
for (int i = 1, u, v, w; i < n; ++i)
{
scanf("%d%d%d", &u, &v, &w);
add(u, v, w), add(v, u, w);
}
calc(1);
LL ans = 0;
int u, a, b, l, r;
while (m--)
{
scanf("%d%d%d", &u, &a, &b);
l = (a + ans) % A, r = (b + ans) % A;
if (l > r)
swap(l, r);
ans = ask(u, l, r);
printf("%lld\n", ans);
}
return 0;
}