Dyd's Blog

He who has a strong enough why can bear almost any how.

luoguP4072 [SDOI2016]征途

After getting beaten up by K-D Tree for a whole day, I'm back to suffer through dp.

征途

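Problem

Given $n$ numbers, split them into $m$ contiguous groups so that the variance of the array of group sums is as small as possible. Output the variance $\times m^2$, with $n \le 3000$.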

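Solution

Start by expanding the variance, where $v_i$ is the sum of group $i$, $\overline{v}$ is the mean of the $v_i$, and $x_i$ are the input numbers: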
$$
\begin{aligned}
s^2
& = \frac{\sum_{i = 1}^{m} (v_i - \overline{v})^2}{m} \\
& = \frac{m (\overline{v})^2 - 2 \overline{v} \sum_{i = 1}^{m} v_i + \sum_{i = 1}^{m} v_i^2}{m} \\
\text{Also: } \overline{v} & = \frac{\sum_{i = 1}^{m} v_i}{m} \\
\text{Substituting: } s^2
& = \frac{m (\frac{\sum_{i = 1}^{m} v_i}{m})^2 - 2 \frac{\sum_{i = 1}^{m} v_i}{m} \sum_{i = 1}^{m} v_i + \sum_{i = 1}^{m} v_i^2}{m} \\
& = \frac{\frac{(\sum_{i = 1}^{m} v_i)^2}{m} - 2 \frac{(\sum_{i = 1}^{m} v_i)^2}{m} + \sum_{i = 1}^{m} v_i^2}{m} \\
& = \frac{\sum_{i = 1}^{m} v_i^2 - \frac{(\sum_{i = 1}^{m} v_i)^2}{m}}{m} \\
\text{Hence } Ans
& = s^2 \times m^2 \\
& = m \sum_{i = 1}^{m} v_i^2 - (\sum_{i = 1}^{m} v_i)^2 \\
& = m \sum_{i = 1}^{m} v_i^2 - (\sum_{i = 1}^{n} x_i)^2
\end{aligned}
$$
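The term to the right of the minus sign is a constant (the square of the total of all inputs), so all that is left is to minimize $\sum_{i = 1}^{m} v_i^2$.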
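As a sanity check on the closed form, here is a minimal sketch that evaluates $m \sum v_i^2 - (\sum v_i)^2$ for a given list of group sums. The function name `scaled_variance` is made up for this illustration and does not appear in the solution further down.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (not part of the original solution):
// given the m group sums v, return variance * m^2 through the closed form
// m * sum(v_i^2) - (sum v_i)^2, using integer arithmetic only.
long long scaled_variance(const std::vector<long long>& v) {
    assert(!v.empty());
    long long m = static_cast<long long>(v.size());
    long long total = 0, squares = 0;
    for (long long x : v) {
        total += x;        // sum of v_i
        squares += x * x;  // sum of v_i^2
    }
    return m * squares - total * total;
}
```

For the group sums $(8, 14)$ the mean is $11$ and the variance is $9$, so variance $\times m^2 = 36$, which is what the closed form gives: $2 \times 260 - 22^2 = 36$.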

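Consider dp: let d[i][j] be "the minimum sum of squares when the first $i$ numbers are split into $j$ segments". For the transition, precompute prefix sums and enumerate the length of the last segment in $O(n)$. The total time complexity is $O(n^2 m)$, which gets TLE and scores $80pts$.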
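The plain dp described above can be sketched as follows. This is only an illustration of the $O(n^2 m)$ transition, not the submitted code; the function name `min_square_sum_naive` and the 0-indexed input vector are assumptions of this sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>

// Hypothetical helper: minimum of sum(v_i^2) over all ways to cut `a`
// into exactly m non-empty contiguous segments, via the O(n^2 m) dp.
long long min_square_sum_naive(const std::vector<long long>& a, int m) {
    int n = static_cast<int>(a.size());
    assert(m >= 1 && m <= n);

    // sum[i] = a[0] + ... + a[i - 1]
    std::vector<long long> sum(n + 1, 0);
    for (int i = 1; i <= n; ++i) sum[i] = sum[i - 1] + a[i - 1];

    const long long INF = LLONG_MAX / 4;
    // d[i][j]: first i numbers split into j segments; INF marks unreachable states.
    std::vector<std::vector<long long>> d(n + 1, std::vector<long long>(m + 1, INF));
    d[0][0] = 0;
    for (int j = 1; j <= m; ++j)
        for (int i = j; i <= n; ++i)
            for (int k = j - 1; k < i; ++k) {  // last segment is (k, i]
                if (d[k][j - 1] == INF) continue;
                long long last = sum[i] - sum[k];
                d[i][j] = std::min(d[i][j], d[k][j - 1] + last * last);
            }
    return d[n][m];
}
```

On the input `1 2 5 8 6` with $m = 2$ the best cut is $(1, 2, 5 \mid 8, 6)$ with square sum $260$, giving the final answer $2 \times 260 - 22^2 = 36$.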

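Look at the dp equation: $d[i][j] = \min_{k < i}(d[k][j - 1] + (sum[i] - sum[k])^2)$. Consider slope optimization. Take two candidates $k < t$ (so $sum[t] > sum[k]$, since the inputs are positive, and dividing by $sum[t] - sum[k]$ keeps the direction of the inequality) and suppose $t$ is better than $k$. Then: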
$$
\begin{aligned}
d[t][j - 1] + (sum[i] - sum[t])^2 & < d[k][j - 1] + (sum[i] - sum[k])^2 \\
d[t][j - 1] + sum[t]^2 - d[k][j - 1] - sum[k]^2 & < 2 sum[i] (sum[t] - sum[k]) \\
\frac{(d[t][j - 1] + sum[t]^2) - (d[k][j - 1] + sum[k]^2)}{(sum[t] - sum[k])} & < 2 sum[i]
\end{aligned}
$$
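In other words, treat each pair $(sum[x], d[x][j - 1] + sum[x]^2)$ as a point in the plane. Then the point $(sum[t], d[t][j - 1] + sum[t]^2)$ is better than $(sum[k], d[k][j - 1] + sum[k]^2)$ if and only if (the derivation above can clearly be run backwards) the slope of the line through the two points ($k \to t$) is less than $2 sum[i]$. This suggests rewriting the equation as: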
$$
\begin{aligned}
(d[k][j - 1] + sum[k]^2) & = (2sum[i]) \times (sum[k]) + (d[i][j] - sum[i]^2) \\
y & = k \times x + b
\end{aligned}
$$
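For slope optimization with a monotone queue we also need the slope $k = 2 sum[i]$ and the abscissa $x = sum[k]$ to be monotonically increasing (the $k$ in $y = k \times x + b$ is the slope, not the index $k$). Both are prefix sums, so as long as the inputs are positive the monotonicity is obvious, and we can (painfully) happily apply slope optimization. Maintain a lower convex hull and roll away the second dimension (transitions only go from $j - 1$ to $j$): loop over the second dimension on the outside, enumerate the first dimension inside, keep the hull in a monotone queue, and take the front of the queue for each transition. The time complexity is $O(mn)$.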
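The layer-by-layer transition described above can be sketched as a standalone function. This is a variant for illustration, not the code of this post: it compares slopes by cross-multiplication in integers instead of dividing in `double`, and it assumes every input is positive (so prefix sums strictly increase) and small enough that the products fit in `long long`. The name `min_square_sum_cht` is made up here.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: same result as the O(n^2 m) dp, computed in O(nm)
// with a monotone queue over the lower convex hull.
// Assumes all a[i] > 0 and values small enough that products fit in long long.
long long min_square_sum_cht(const std::vector<long long>& a, int m) {
    int n = static_cast<int>(a.size());
    assert(m >= 1 && m <= n);
    std::vector<long long> sum(n + 1, 0);
    for (int i = 1; i <= n; ++i) sum[i] = sum[i - 1] + a[i - 1];

    // prev holds layer j - 1, cur receives layer j (rolling array).
    std::vector<long long> prev(n + 1), cur(n + 1, 0);
    for (int i = 0; i <= n; ++i) prev[i] = sum[i] * sum[i];  // layer j = 1

    std::vector<int> q(n + 1);
    for (int j = 2; j <= m; ++j) {
        // Point p of the previous layer is (sum[p], y(p)).
        auto y = [&](int p) { return prev[p] + sum[p] * sum[p]; };
        int l = 0, r = 0;
        q[0] = j - 1;  // smallest valid split point for j segments
        for (int i = j; i <= n; ++i) {
            // Drop the front while slope(q[l], q[l+1]) < 2 * sum[i].
            while (l < r && y(q[l + 1]) - y(q[l]) <
                                2 * sum[i] * (sum[q[l + 1]] - sum[q[l]]))
                ++l;
            long long last = sum[i] - sum[q[l]];
            cur[i] = prev[q[l]] + last * last;
            // Drop the back while slope(q[r-1], q[r]) >= slope(q[r], i).
            while (l < r && (y(q[r]) - y(q[r - 1])) * (sum[i] - sum[q[r]]) >=
                                (y(i) - y(q[r])) * (sum[q[r]] - sum[q[r - 1]]))
                --r;
            q[++r] = i;
        }
        prev.swap(cur);
    }
    return prev[n];
}
```

Multiplying `min_square_sum_cht(a, m)` by $m$ and subtracting the square of the total gives the answer; for `1 2 5 8 6` with $m = 2$ that is $2 \times 260 - 484 = 36$. Cross-multiplication is valid here only because the differences of prefix sums are positive.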

Code

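#include <bits/stdc++.h>
#define LL long long
#define DB double
using namespace std;
const int N = 3000 + 5;
int n, m, l, r, o = 0; // o: index of the rolling layer currently being written
LL a[N], sum[N], d[N][2], q[N];
// Slope between points x and y of layer k, where point p is (sum[p], d[p][k] + sum[p]^2)
DB get_k(int x, int y, int k)
{
    return (DB)((d[y][k] + sum[y] * sum[y]) - (d[x][k] + sum[x] * sum[x])) / (DB)(sum[y] - sum[x]);
}
int main()
{
    scanf("%d %d", &n, &m);
    for (int i = 1; i <= n; ++i) scanf("%lld", &a[i]), sum[i] = sum[i - 1] + a[i];
    for (int i = 1; i <= n; ++i) d[i][o] = sum[i] * sum[i]; // d cannot be initialized to 0 here, so compute the m = 1 layer first
    o ^= 1;
    for (int i = 2, j; i <= m; ++i, o ^= 1) // i: number of segments
        for (j = i, l = r = 1, q[l] = i - 1; j <= n; ++j) // j: prefix length; the queue starts with split point i - 1
        {
            // Pop the front while the slope to the next point is below 2 * sum[j]
            while (l < r && get_k(q[l], q[l + 1], o ^ 1) < 2 * sum[j]) ++l;
            d[j][o] = d[q[l]][o ^ 1] + (sum[j] - sum[q[l]]) * (sum[j] - sum[q[l]]);
            // Pop the back while adding j would break the lower convex hull
            while (l < r && get_k(q[r - 1], q[r], o ^ 1) > get_k(q[r], j, o ^ 1)) --r;
            q[++r] = j;
        }
    // o is flipped once more after the last layer, so the answer lives in layer o ^ 1
    printf("%lld", d[n][o ^ 1] * m - sum[n] * sum[n]);
    return 0;
}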