Dyd's Blog

He who has a strong enough why can bear almost any how.

树套树

禁止套娃!

基本思想

顾名思义,树套树就是用一个树套进另一个树,即一个外层的树、一个内层的树。

简单版

思考如下问题:
维护一个长度为 $n$ 的序列,数列中的位置从左到右依次标号为 $1∼n$ ,其中需要提供以下操作:

  1. $1\ pos\ x$ ,将 $pos$ 位置的数修改为 $x$ 。

  2. $1\ a\ b\ x$ ,查询整数 $x$ 在区间 $[a,b]$ 内的前驱(前驱定义为小于 $x$ ,且最大的数)。

这个问题发现,第二个操作可以用 $set$ 完成,但 $set$ 仅支持查询整个区间,无法完成在区间 $[a,b]$ 上查询,所以我们考虑用一个线段树套在 $set$ 外层。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include <bits/stdc++.h>
using namespace std;
const int N = 50000 + 5, M = N * 4; //线段树空间开4倍
const int INF = 0x3f3f3f3f;
int n, m;
struct Tree
{
int l, r;
//multiset是<set>库中一个类型,可以看成一个序列,
//插入删除数都能够在O(logn)的时间内完成,
//而且能保证序列中的数是有序的,而且序列中可以存在重复的数。
multiset<int> s;
void inint(int _l, int _r)
{
l = _l;
r = _r;
}
} tr[M];
int w[N];
void build(int u, int l, int r)
{
tr[u].inint(l, r);
tr[u].s.insert(-INF), tr[u].s.insert(INF); //防止越界
//将线段树该节点(即该段)内的所有点插入该节点下套的set中
for (int i = l; i <= r; ++i)
tr[u].s.insert(w[i]);
if (l == r)
return;
int mid = (l + r) >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
void change(int u, int p, int x)
{
//删除w[p],注意不能写erase(w[p]),否则会把所有值为w[p]的都删除
tr[u].s.erase(tr[u].s.find(w[p]));
tr[u].s.insert(x);
if (tr[u].l == tr[u].r)
return;
int mid = (tr[u].l + tr[u].r) >> 1;
if (p <= mid)
change(u << 1, p, x);
else
change(u << 1 | 1, p, x);
}
int query(int u, int a, int b, int x)
{
//若tr[u]在完全在区间[a,b]内
if (tr[u].l >= a && tr[u].r <= b)
{
//注意it的类型,it是地址,类型得记
std::multiset<int>::iterator it = tr[u].s.lower_bound(x);
//it是"大于等于x的最小数",减一后就是"小于x的最大数"
--it;
return *it;
}
int mid = (tr[u].l + tr[u].r) >> 1, res = -INF;
if (a <= mid)
res = max(res, query(u << 1, a, b, x));
if (b > mid)
res = max(res, query(u << 1 | 1, a, b, x));
return res;
}
int main()
{
scanf("%d%d", &n, &m); //数列长度以及操作次数
for (int i = 1; i <= n; ++i)
scanf("%d", &w[i]);
build(1, 1, n);
while (m--)
{
int op, a, b, c, ans;
scanf("%d", &op);
if (op == 1)
{
scanf("%d%d", &a, &b);
change(1, a, b);
w[a] = b;
}
else
{
scanf("%d%d%d", &a, &b, &c);
ans = query(1, a, b, c);
printf("%d\n", (ans > -INF && ans < INF) ? ans : -1);
}
}
return 0;
}

正常版

问题:

问题加强如下(也可见P3380):
维护一个长度为 $n$ 的序列,数列中的位置从左到右依次标号为 $1∼n$ ,其中需要提供以下操作:

  1. $1\ l\ r\ x$ ,查询整数 $x$ 在区间 $[l,r]$ 内的排名。

  2. $2\ l\ r\ k$ ,查询区间 $[l,r]$ 内排名为 $k$ 的值。

  3. $3\ pos\ x$ ,将 $pos$ 位置的数修改为 $x$ 。

  4. $4\ l\ r\ x$ ,查询整数 $x$ 在区间 $[l,r]$ 内的前驱(前驱定义为小于 $x$ ,且最大的数)。

  5. $5\ l\ r\ x$ ,查询整数 $x$ 在区间 $[l,r]$ 内的后继(后继定义为大于 $x$ ,且最小的数)。

分析:

本题的四、五操作与简单版类似,可以用 $set$ 完成,但对于操作一、二, $set$ 却不支持排名,因为它不记录子树大小,so——我们只好手打平衡树。

而手打的平衡树可以完成操作一,但对于操作二,因为线段树会把区间 $[l,r]$ 分成最多 $\log N$ 个区间,每个区间内的平衡树互不联通,且通过简单的计算每个区间的答案并不能得到最终答案,所以我们思考其他方法——二分。
对于每个 $mid$ 用操作一找出它的排名,若排名小于等于 $k$ ,则答案小于等于 $mid$ ,反之则大于等于 $mid$ 。

由此,操作一、三、四、五的时间复杂度为 $\log^2 N$ ,操作二的时间复杂度为 $\log^3 N$ ,总时间复杂度为 $O(m\log^3 n)$ ,空间复杂度为 $O(N* 4* 2+N\log N)$ ,故空间约需 $50000* 4* 2+50000* 18 \le 1500000 $ 。

另外,一定要注意,有几个函数传的是实参!( $update$ 、 $splay$ 和 $insert$)。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
#include <bits/stdc++.h>
using namespace std;
const int N = 15e5 + 5, INF = 2147483647;
struct Splay
{
int s[2], fa, size, v;
void inint(int _v, int _fa)
{
v = _v;
fa = _fa;
size = 1;
}
} tr[N];
struct Seg_Tree
{
int l, r;
int rt; //root,对应的Splay的根
} ttr[N]; //ttr表示外层树
int n, m, idx = 0; //idx:记录Splay的节点
int w[N];
/*------------------------Splay------------------------*/
void push_up(int x)
{
tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + 1;
}
void rotate(int x)
{
int y = tr[x].fa;
int z = tr[y].fa;
int k = tr[y].s[1] == x;
tr[z].s[tr[z].s[1] == y] = x, tr[x].fa = z;
tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].fa = y;
tr[x].s[k ^ 1] = y, tr[y].fa = x;
push_up(y), push_up(x);
}
void splay(int &root, int x, int k) //一定注意这里传实参
{
while (tr[x].fa != k)
{
int y = tr[x].fa;
int z = tr[y].fa;
if (z != k)
(tr[y].s[1] == x) ^ (tr[z].s[1] == y) ? rotate(x) : rotate(y);
rotate(x);
}
if (!k)
root = x;
}
//在splay中插入v
void insert(int &root, int v) //实参!
{
int u = root, p = 0;
while (u)
p = u, u = tr[u].s[v > tr[u].v];
u = ++idx;
if (p)
tr[p].s[v > tr[p].v] = u;
tr[u].inint(v, p);
splay(root, u, 0);
}
//在该splay中找到比v小的数的个数
int get_k(int root, int v)
{
int u = root, res = 0;
while (u)
{
if (tr[u].v < v)
res += tr[tr[u].s[0]].size + 1, u = tr[u].s[1];
else
u = tr[u].s[0];
}
return res;
}
//在该splay中更新x为y
void update(int &root, int x, int y) //实参!
{
//删除x(忘了去看Splay)
int u = root;
while (u)
{
if (tr[u].v == x)
break;
if (tr[u].v < x)
u = tr[u].s[1];
else
u = tr[u].s[0];
}
splay(root, u, 0);
int l = tr[u].s[0], r = tr[u].s[1];
while (tr[l].s[1])
l = tr[l].s[1];
while (tr[r].s[0])
r = tr[r].s[0];
splay(root, l, 0), splay(root, r, l);
tr[r].s[0] = 0;
push_up(r), push_up(l);
//插入y
insert(root, y);
}
//在该splay中找到比v小的数中最大的
int get_pre(int root, int v)
{
int u = root, res = -INF;
while (u)
{
if (tr[u].v < v)
res = max(res, tr[u].v), u = tr[u].s[1];
else
u = tr[u].s[0];
}
return res;
}
//在该splay中找到比v大的数中最小的
int get_suc(int root, int v)
{
int u = root, res = INF;
while (u)
{
if (tr[u].v > v)
res = min(res, tr[u].v), u = tr[u].s[0];
else
u = tr[u].s[1];
}
return res;
}
/*------------------------Segment tree------------------------*/
void build(int u, int l, int r)
{
ttr[u].l = l, ttr[u].r = r;
insert(ttr[u].rt, -INF), insert(ttr[u].rt, INF); //防止越界
for (int i = l; i <= r; ++i)
insert(ttr[u].rt, w[i]);
if (l == r)
return;
int mid = (l + r) >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}
//找到比x小的数的个数
int query(int u, int l, int r, int x)
{
if (ttr[u].l >= l && ttr[u].r <= r)
return get_k(ttr[u].rt, x) - 1; //因为多插入了-INF
int mid = (ttr[u].l + ttr[u].r) >> 1, res = 0;
if (l <= mid)
res += query(u << 1, l, r, x);
if (r > mid)
res += query(u << 1 | 1, l, r, x);
return res;
}
//将w[p]变为x
void change(int u, int p, int x)
{
update(ttr[u].rt, w[p], x);
if (ttr[u].l == ttr[u].r)
return;
int mid = (ttr[u].l + ttr[u].r) >> 1;
if (p <= mid)
change(u << 1, p, x);
else
change(u << 1 | 1, p, x);
}
//找到比x小的数中最大的
int query_pre(int u, int l, int r, int x)
{
if (ttr[u].l >= l && ttr[u].r <= r)
return get_pre(ttr[u].rt, x);
int mid = (ttr[u].l + ttr[u].r) >> 1, res = -INF;
if (l <= mid)
res = max(res, query_pre(u << 1, l, r, x));
if (r > mid)
res = max(res, query_pre(u << 1 | 1, l, r, x));
return res;
}
//找到比v大的数中最小的
int query_suc(int u, int l, int r, int x)
{
if (ttr[u].l >= l && ttr[u].r <= r)
return get_suc(ttr[u].rt, x);
int mid = (ttr[u].l + ttr[u].r) >> 1, res = INF;
if (l <= mid)
res = min(res, query_suc(u << 1, l, r, x));
if (r > mid)
res = min(res, query_suc(u << 1 | 1, l, r, x));
return res;
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i)
scanf("%d", &w[i]);
build(1, 1, n);
while (m--)
{
int op, a, b, x;
scanf("%d", &op);
if (op == 1)
{
scanf("%d%d%d", &a, &b, &x);
printf("%d\n", query(1, a, b, x) + 1); //query是"比x小的数的个数",加1为x的排名
}
else if (op == 2)
{
scanf("%d%d%d", &a, &b, &x);
//二分答案
int l = 0, r = 1e8, ans;
while (l <= r)
{
int mid = (l + r) >> 1;
if (query(1, a, b, mid) + 1 <= x)
ans = mid, l = mid + 1;
else
r = mid - 1;
}
printf("%d\n", ans);
}
else if (op == 3)
{
scanf("%d%d", &a, &x);
change(1, a, x);
w[a] = x;
}
else if (op == 4)
{
scanf("%d%d%d", &a, &b, &x);
printf("%d\n", query_pre(1, a, b, x));
}
else
{
scanf("%d%d%d", &a, &b, &x);
printf("%d\n", query_suc(1, a, b, x));
}
}
return 0;
}

变式

问题:

AcWing2306
你需要维护 $n$ 个可重整数集,集合的编号从1到 $n$ 。
这些集合初始都是空集,有 $m$ 个操作:

  1. $1\ l\ r\ c$ :表示将 $c$ 加入到编号在 $[l,r]$ 内的集合中

  2. $2\ l\ r\ c$ :表示查询编号在 $[l,r]$ 内的集合的并集中,第 $c$ 大的数是多少。

注意可重集的并是不去除重复元素的,如 ${1,1,4} \cup {5,1,4}={1,1,4,5,1,4}$ 。

分析:

先尝试上题思路:外层树线段树,套内层平衡树。
但很快啊,啪的一下,我们就发现,这样的话区间修改非常麻烦,操作一(在区间 $[l,r]$ 中的每一个集合里加一个数 $c$ )的实现过于复杂。
于是,我们(听别人说)可以用权值线段树。

权值线段树:

  1. 将数值离散化

  2. 以数值(即权值)为节点建立线段树

  3. 线段数的每一个节点上,再用一棵线段树来维护权值在该节点内的所有数的下标

  4. 对于操作一,我们在外层数上找到包含 $c$ 的节点,最多有 $\log n$ 个,然后再个节点内层的线段树上找到区间 $[l,r]$ ,打上懒标记

  5. 对于操作二,我们用类似二分的思想,每次找一半区间(即外层线段数左右节点中的一个),求该区间下套的内层树上 $[l,r]$ 的和 $k$ ,若 $k \ge c$ 就进入(外层树的)右节点,否则进入左节点

时间复杂度为 $O(\log^2 N)$

标记持久化:

我们发现的区间增加也好、区间求和也好,都是在内层树上,而且只有这两个同类型(必须同类,例如有加有乘就不行)的操作,所以我们考虑更改 $sum$ 和 $add$ 的含义,使这些标记不必向下传(即省略 $push_down$)。
(以下定义都是针对内层树)
$sum$ :定义 $sum$ 为“只考虑本节点及以下节点懒标记的情况下的区间和”。
$add$ :定义 $add$ 为“本节点的所有子节点都要加上一个 $add$ ”。
对于区间求和,只需用本节点的 $sum$ 加上所有祖宗节点的 $add$ 和与区间长度的积即可,具体实现的话,只需在递归时顺便累加 $add$ 即可。

线段树的动态开点:

我们发现,直接开线段树套线段树,空间复杂度为 $O(N^2)$ ,对于本题来说无法接受,因此,我们自然想到,只开出用到的节点,即动态开点,空间复杂度为 $O(M\log N)$

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 50000 + 5, Nn = N * 17 * 17, M = N * 4;
struct Seg_Tree1
{
LL l, r, rt;
#define l(u) tr[u].l
#define r(u) tr[u].r
#define rt(u) tr[u].rt
} tr[N];
struct Seg_Tree2
{
LL l, r, sum, add; //由于动态开点,内层数的l、r表示的是节点下标
#define ll(u) ttr[u].l
#define rr(u) ttr[u].r
#define sum(u) ttr[u].sum
#define add(u) ttr[u].add
} ttr[Nn];
struct Question //为了离散化,需要保存问题
{
LL op, a, b, c;
#define op(u) q[i].op
#define a(u) q[i].a
#define b(u) q[i].b
#define c(u) q[i].c
} q[N];
vector<LL> que;
LL n, m;
LL idx = 0;
//离散化的对应
LL get(LL x)
{
return lower_bound(que.begin(), que.end(), x) - que.begin();
}
//intersection,求两区间的交集
LL inter(LL l1, LL r1, LL l2, LL r2)
{
return min(r1, r2) - max(l1, l2) + 1;
}
//更新内层树
void update(LL u, LL l, LL r, LL pl, LL pr)
{
sum(u) += inter(l, r, pl, pr);
if (l >= pl && r <= pr)
{
add(u)++;
return;
}
LL mid = (l + r) >> 1;
if (pl <= mid)
{
if (!ll(u))
ll(u) = ++idx; //动态开点
update(ll(u), l, mid, pl, pr);
}
if (pr > mid)
{
if (!rr(u))
rr(u) = ++idx;
update(rr(u), mid + 1, r, pl, pr);
}
}
LL get_sum(LL u, LL l, LL r, LL pl, LL pr, LL add)
{
if (l >= pl && r <= pr)
return sum(u) + (r - l + 1) * add;
LL mid = (l + r) >> 1, res = 0;
add += add(u);
if (pl <= mid)
{
if (ll(u))
res += get_sum(ll(u), l, mid, pl, pr, add);
else
res += inter(l, mid, pl, pr) * add; //如果未开ll(u),说明它未修改过
}
if (pr > mid)
{
if (rr(u))
res += get_sum(rr(u), mid + 1, r, pl, pr, add);
else
res += inter(mid + 1, r, pl, pr) * add;
}
return res;
}
void build(LL u, LL l, LL r)
{
l(u) = l, r(u) = r, rt(u) = ++idx;
if (l == r)
return;
LL mid = (l + r) >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}
void change(LL u, LL l, LL r, LL x)
{
update(rt(u), 1, n, l, r);
if (l(u) == r(u))
return;
int mid = (l(u) + r(u)) >> 1;
if (x <= mid)
change(u << 1, l, r, x);
else
change(u << 1 | 1, l, r, x);
}
LL query(LL u, LL l, LL r, LL x)
{
if (l(u) == r(u))
return r(u);
LL mid = (l(u) + r(u)) >> 1;
LL k = get_sum(rt(u << 1 | 1), 1, n, l, r, 0);
if (k >= x)
return query(u << 1 | 1, l, r, x);
return query(u << 1, l, r, x - k);
}
int main()
{
scanf("%lld%lld", &n, &m);
for (LL i = 0; i < m; ++i) //为了配合vector,循环也都从0开始
{
scanf("%lld%lld%lld%lld", &op(i), &a(i), &b(i), &c(i));
if (op(i) == 1)
que.push_back(c(i));
}
//排序,去重
sort(que.begin(), que.end());
que.erase(unique(que.begin(), que.end()), que.end());

build(1, 0, que.size() - 1);
for (LL i = 0; i < m; ++i)
{
if (op(i) == 1)
change(1, a(i), b(i), get(c(i)));
else
printf("%lld\n", que[query(1, a(i), b(i), c(i))]);
}
return 0;
}