1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
#include <cstdio>
#include <cstring>
#include <vector>

// Reads a tree with n vertices (numbered 1..n) and a jump length k, and prints
// the sum over all unordered vertex pairs {u, v} of ceil(dist(u, v) / k),
// where dist is the number of edges on the tree path between u and v.
//
// Input:  "n k" followed by n - 1 lines "u v" describing the edges.
// Output: the sum as a single 64-bit integer.
//
// Both tree passes are iterative, so a path-shaped tree with ~2e5 vertices
// does not depend on the size of the process stack.

using LL = long long;

namespace {

// Residue tables are indexed 1..kMaxJump; larger jump lengths are rejected.
constexpr int kMaxJump = 9;

// Summary of a set of vertices as seen from one reference vertex x:
//   cnt[r]   - how many of them lie at a distance dist >= 1 from x with
//              dist % k == r % k (slot k stands for residue 0);
//   jumps[r] - the sum of ceil(dist / k) over those same vertices.
struct Table {
    int cnt[kMaxJump + 1] = {};
    LL jumps[kMaxJump + 1] = {};
};

// Adds (sign = +1) or removes (sign = -1) the contribution of a neighbour y to
// x's table `to`, where `from` is y's table. Every vertex described by `from`
// is one edge further from x than from y, and y itself is at distance 1.
// Stepping one edge further moves residue r to r + 1 with the same jump count,
// except residue 0 (slot k), which wraps to residue 1 and costs one more jump.
// `to` and `from` must be different objects.
void merge(Table& to, const Table& from, int sign, int k) {
    to.cnt[1] += sign;  // y itself: distance 1, exactly one jump
    to.jumps[1] += sign;
    for (int r = 2; r <= k; ++r) {
        to.cnt[r] += sign * from.cnt[r - 1];
        to.jumps[r] += sign * from.jumps[r - 1];
    }
    to.cnt[1] += sign * from.cnt[k];
    to.jumps[1] += sign * (from.jumps[k] + from.cnt[k]);
}

}  // namespace

int main() {
    int n = 0, k = 0;
    if (std::scanf("%d %d", &n, &k) != 2) {
        std::fputs("expected \"n k\" on the first line\n", stderr);
        return 1;
    }
    if (k < 1 || k > kMaxJump) {
        std::fprintf(stderr, "jump length must be in [1, %d]\n", kMaxJump);
        return 1;
    }
    if (n <= 1) {  // fewer than two vertices: no pairs at all
        std::printf("%lld\n", 0LL);
        return 0;
    }

    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        int u = 0, v = 0;
        if (std::scanf("%d %d", &u, &v) != 2 || u < 1 || u > n || v < 1 || v > n) {
            std::fprintf(stderr, "invalid edge #%d\n", i);
            return 1;
        }
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    // Iterative DFS from vertex 1. Produces parent[] (0 = no parent) and an
    // order in which every vertex appears after its parent. The seen[] flag
    // guarantees termination even if the input is not actually a tree.
    std::vector<int> parent(n + 1, 0);
    std::vector<char> seen(n + 1, 0);
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> pending{1};
    seen[1] = 1;
    while (!pending.empty()) {
        const int x = pending.back();
        pending.pop_back();
        order.push_back(x);
        for (const int y : adj[x]) {
            if (!seen[y]) {
                seen[y] = 1;
                parent[y] = x;
                pending.push_back(y);
            }
        }
    }

    // Upward pass: children are handled before their parents, so table[x]
    // ends up describing the subtree of x (excluding x itself).
    std::vector<Table> table(n + 1);
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int x = *it;
        if (parent[x] != 0) merge(table[parent[x]], table[x], +1, k);
    }

    // Downward pass (rerooting): parents before children. When x is reached,
    // table[parent[x]] already describes the whole tree from the parent's
    // point of view. Removing x's subtree from a copy of it leaves the part of
    // the tree reachable from x through its parent, which is merged into x.
    LL total = 0;
    for (const int x : order) {
        const int p = parent[x];
        if (p != 0) {
            Table rest = table[p];
            merge(rest, table[x], -1, k);
            merge(table[x], rest, +1, k);
        }
        for (int r = 1; r <= k; ++r) total += table[x].jumps[r];
    }

    // Every unordered pair was counted once from each endpoint.
    std::printf("%lld\n", total / 2);
    return 0;
}
|