Dyd's Blog

He who has a strong enough why can bear almost any how.

K-D Tree

三维生物瑟瑟发抖

K-D Tree

K-D Tree(k-Dimension Tree,k 维树),是一种高效处理 $k$ 维空间信息的(JO级生物)数据结构,在节点数远大于 $2^k$ 时效率很高

构建

K-D Tree 具有二叉搜索树的形态,通过以下伪代码可以将 $n$ 个 $k$ 维的点建出一棵K-D Tree :

1
2
3
4
5
6
7
function build (点集 S)
if (S 只有一个点) return 该点
选择一个维度 k ,选择一个切割点 a
以点 a 的第 k 维为标准,所有第 k 维小于 a 的点归入集合 L ,其余归入 R
以点 a 为父节点,用 L 建左子树, R 建右子树
维护一些信息
end function

可以发现这个树的结构取决于选择的维度和切割点,我们按以下标准选择:

  1. 选择的维度要满足其内部点的分布的差异度最大,即每次选择的切割维度是方差最大的维度
  2. 每次在维度上选择切割点时选择该维度上的中位数,这样可以保证每次分成的左右子树大小尽量相等

可以发现这样建出的 K-D Tree 高度最大为 $O(\log n)$

考虑建树的时间,每次找中位数时,用 sort 总时间是 $O(n \log^2 n)$ 的,这显然不必要,因为我们不必每次给整个序列排序,只要中位数在正确位置,且左边都小于它,右边大于它即可,在 algorithm 库中的 nth_element() 函数可以支持该操作,它的均摊复杂度为 $O(n)$ 于是建树就变了 $O(n \log n)$

插入/删除

数据结构大部分都要求支持插入和删除,但注意到 K-D Tree 具有二叉搜索树的形态,可它又不支持旋转(或者说很难支持旋转),而 FHQ Treap 的随机优先级思想也不能保证其复杂度,我们考虑用替罪羊树的重构思想

引入重构常数 $\alpha$ ,对于节点 $x$ 如果它的一个子树的结点数占比大于 $\alpha$ ,或者未删除的结点数在以 $x$ 为根的子树中的占比小于 $\alpha$ 时,我们就重构它

在插入一个点时,先根据记录的分割维度和分割点判断应该继续插入到左子树还是右子树,如果到达了空结点,新建一个结点代替这个空结点,成功插入结点后回溯插入的过程,维护结点的信息,如果发现当前的子树不平衡,则重构当前子树

在删除一个点时,先找到它,然后打上懒标记即可

类似于替罪羊树,带重构的 K-D Tree 的树高仍然是 $O(\log n)$ 的,当然,最好是把操作离线了

例题

T1

简单题

$20 MB$ 卡掉树套树,强制在线卡了 CDQ ,于是用 K-D Tree (略微卡常)

对于修改,直接删除再插入即可,对于询问,记录子树每一维的最大和最小值,可以证明,这样查询 $k$ 维的最坏时间为 $O(n^{1 - \frac{1}{k}})$ ,对于本题,就是 $O(\sqrt{n})$ ,本题可以插入重复元素,然后就不需要删除

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#include <bits/stdc++.h>
#define DB double
#define IL inline
namespace Fast
{
const int L = (1 << 20) + 5;
char buf[L], out[L], *iS, *iT;
int l = 0;
#define gh() (iT == iS ? iT = (iS = buf) + fread(buf, 1, L, stdin), (iT == iS ? EOF : *iS++) : *iS++)
template<class T>
IL void read(T &x)
{
x = 0;
char ch = gh(), t = 0;
while (ch < '0' || ch > '9')
t |= ch == '-', ch = gh();
while (ch >= '0' && ch <= '9')
x = x * 10 + (ch ^ 48), ch = gh();
if (t)
x = -x;
}
IL void flus()
{
fwrite(out, 1, l, stdout);
l = 0;
}
IL void putc(char x)
{
out[l++] = x;
if (l == L - 5)
flus();
}
template<class T>
IL void write(T x)
{
if (x < 0)
putc('-'), x = -x;
if (x > 9)
write(x / 10);
out[l++] = x % 10 + 48;
if (l == L - 5)
flus();
}
}
using Fast::flus;
using Fast::putc;
using Fast::read;
using Fast::write;
using namespace std;
const int M = 2e5 + 5, D = 2;
int n;
struct Point
{
int x[D], w;
} ;
namespace KDT //K-D Tree
{
const DB a = 0.725;
int rt, top, rub[M], cur, tot; //rub:回收空间
Point p[M];
struct Node
{
int mn[D], mx[D], sum, lc, rc, si, k;
Point p;
} tr[M];
#define mn(x) tr[(x)].mn
#define mx(x) tr[(x)].mx
#define sum(x) tr[(x)].sum
#define lc(x) tr[(x)].lc
#define rc(x) tr[(x)].rc
#define si(x) tr[(x)].si
#define k(x) tr[(x)].k
#define p(x) tr[(x)].p
IL int newnode()
{
if (top) return rub[top--];
return ++cur;
}
IL void up(int u)
{
int ls = lc(u), rs = rc(u);
for (int i = 0; i < D; ++i)
{
mn(u)[i] = mx(u)[i] = p(u).x[i];
if (ls) mn(u)[i] = min(mn(u)[i], mn(ls)[i]), mx(u)[i] = max(mx(u)[i], mx(ls)[i]);
if (rs) mn(u)[i] = min(mn(u)[i], mn(rs)[i]), mx(u)[i] = max(mx(u)[i], mx(rs)[i]);
}
sum(u) = sum(ls) + sum(rs) + p(u).w, si(u) = si(ls) + si(rs) + 1;
}
IL int build(int l, int r, int k)
{
if (l > r) return 0;
int mid = l + r >> 1, u = newnode();
nth_element(p + l, p + mid, p + r + 1, [&](Point a, Point b){ return a.x[k] < b.x[k]; });
k(u) = k, p(u) = p[mid], lc(u) = build(l, mid - 1, k ^ 1), rc(u) = build(mid + 1, r, k ^ 1);
return up(u), u;
}
IL void get_p(int u)
{
if (!u) return ;
get_p(lc(u)), p[++tot] = p(u), rub[++top] = u, get_p(rc(u));
}
IL void check(int &u)
{
if (si(u) * a < max(si(lc(u)), si(rc(u)))) tot = 0, get_p(u), u = build(1, tot, k(u));
}
IL void ins(int &u, Point x)
{
if (!u)
{
u = newnode();
lc(u) = rc(u) = k(u) = 0, p(u) = x;
return up(u);
}
if (x.x[k(u)] <= p(u).x[k(u)]) ins(lc(u), x);
else ins(rc(u), x);
up(u), check(u);
}
IL bool in(int mx[D], int mn[D], int l[D], int r[D])
{
for (int i = 0; i < D; ++i) if (r[i] < mx[i] || l[i] > mn[i]) return false;
return true;
}
IL bool out(int mx[D], int mn[D], int l[D], int r[D])
{
for (int i = 0; i < D; ++i) if (l[i] > mx[i] || r[i] < mn[i]) return true;
return false;
}
IL int ask(int u, int l[D], int r[D])
{
if (!u || out(mx(u), mn(u), l, r)) return 0;
if (in(mx(u), mn(u), l, r)) return sum(u);
return (in(p(u).x, p(u).x, l, r) ? p(u).w : 0) + ask(lc(u), l, r) + ask(rc(u), l, r);
}
}
int main()
{
read(n);
for (int op, x[D], y[D], last = 0; read(op), op != 3; )
{
read(x[0]), read(x[1]), read(y[0]), x[0] ^= last, x[1] ^= last, y[0] ^= last;
if (op == 1) KDT::ins(KDT::rt, {x[0], x[1], y[0]});
else read(y[1]), y[1] ^= last, write(last = KDT::ask(KDT::rt, x, y)), putc('\n');
}
return flus(), 0;
}

T2

TATT

经典高维偏序, bitset 是 $O(\frac{n^2 k}{w})$ , CDQ 是 $O(n \log^{k - 1} n)$ , K-D Tree 是 $O(n * n^{1 - \frac{1}{k - 1}})$ (这里 $k - 1$ 是因为我们可以先排序变成 $k - 1$ 维偏序),仔细一算,K-D Tree 最优(比它们小了一个 $0$ ),我们打 K-D Tree

打的时候注意指针的实参是 *&

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include <bits/stdc++.h>
using DB = double;
const int N = 5e4 + 100, K = 3, INF = 0x3f3f3f3f;
int n, ans = 0;
struct Point
{
int x[K + 1], w;
int& operator [] (int id){ return x[id]; }
} a[N];
namespace KDT
{
const DB alpha = 0.75;
struct Node
{
int si, mxw;
Point dat, mn, mx;
Node *ch[2];
void up()
{
si = ch[0]->si + ch[1]->si + 1;
mxw = std::max(std::max(ch[0]->mxw, ch[1]->mxw), dat.w);
for (int i = 0; i < K; ++i)
{
mn[i] = std::min(std::min(ch[0]->mn[i], ch[1]->mn[i]), dat[i]);
mx[i] = std::max(std::max(ch[0]->mx[i], ch[1]->mx[i]), dat[i]);
}
}
bool in(Point l, Point r)
{
for (int i = 0; i < K; ++i) if (l[i] > mn[i] || r[i] < mx[i]) return false;
return true;
}
bool out(Point l, Point r)
{
for (int i = 0; i < K; ++i) if (l[i] > mx[i] || r[i] < mn[i]) return true;
return false;
}
bool at(Point l, Point r)
{
for (int i = 0; i < K; ++i) if (dat[i] < l[i] || dat[i] > r[i]) return false;
return true;
}
bool chk(){ return ch[0]->si <= si * alpha && ch[1]->si <= si * alpha; }
} pool[N], *tail, *rt, *rub[N], *null;
Point dw, upd[N];
int res, top = 0, cnt;
void init()
{
tail = pool, null = ++tail;
for (int i = 0; i < K; ++i)
{
null->mn[i] = INF;
dw[i] = null->mx[i] = -INF;
}
null->ch[0] = null->ch[1] = null;
rt = null;
}
Node* newnd(Point x)
{
Node *p = top ? rub[top--] : ++tail;
p->ch[0] = p->ch[1] = null;
p->dat = p->mn = p->mx = x;
p->mxw = x.w, p->si = 1;
return p;
}
void walk(Node *u)
{
if (u == null) return ;
walk(u->ch[0]);
upd[++cnt] = u->dat;
rub[++top] = u;
walk(u->ch[1]);
}
Node* bd(int l, int r, int d)
{
if (l > r) return null;
int mid = (l + r) >> 1;
std::nth_element(upd + l, upd + mid, upd + r + 1, [&](Point x, Point y)
{
if (x[d] ^ y[d]) return x[d] < y[d];
for (int i = 0; i < K; ++i) if (x[i] ^ y[i]) return x[i] < y[i];
return false;
});
Node* u = newnd(upd[mid]);
if (l == r) return u;
u->ch[0] = bd(l, mid - 1, (d + 1) % K);
u->ch[1] = bd(mid + 1, r, (d + 1) % K);
return u->up(), u;
}
void reb(Node *&u){ cnt = 0, walk(u), u = bd(1, cnt, 0); }
void ask(Node *u, Point l, Point r)
{
if (u == null || u->out(l, r)) return;
if (u->in(l, r)) return void(res = std::max(u->mxw, res));
res = std::max(u->at(l, r) ? u->dat.w : 0, res);
if (u->ch[0]->mxw > res) ask(u->ch[0], l, r);
if (u->ch[1]->mxw > res) ask(u->ch[1], l, r);
}
int ask(Point up)
{
res = 0;
ask(rt, dw, up);
return res;
}
Node** ins(Node *&u, Point x, int d)
{
if (u == null) return u = newnd(x), &null;
Node** bad = ins(u->ch[u->dat[d] < x[d]], x, (d + 1) % K);
u->up();
if (!u->chk()) bad = &u;
return bad;
}
void ins(Point x)
{
Node **bad = ins(rt, x, 0);
if (*bad != null) reb(*bad);
}
}
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; ++i)
for (int j = 0; j <= K; ++j) scanf("%d", &a[i][j]);
std::sort(a + 1, a + n + 1, [&](Point x, Point y)
{
if (x[3] ^ y[3]) return x[3] < y[3];
for (int i = 0; i < K; ++i) if (x[i] ^ y[i]) return x[i] < y[i];
return false;
});
KDT::init();
for (int i = 1; i <= n; ++i)
{
ans = std::max(ans, a[i].w = KDT::ask(a[i]) + 1);
KDT::ins(a[i]);
}
printf("%d\n", ans);
return 0;
}