Falling Factorial, Tree DP

Description

Given a tree with $n$ nodes, for every node $x$ compute $\sum_{i=1}^{n} dist(x,i)^k$ modulo $10007$.

$n \le 50000, k \le 500$.

Solution

For $k = 1$ or $k = 2$, the usual two-pass DFS (rerooting) trick is enough: maintain both the sum of distances and the sum of squared distances.

For larger $k$, however, this method gets TLE. Keeping all power sums up to $k$ in sync requires the binomial theorem, which costs $O(k^2)$ per transfer and $O(nk^2)$ in total.

There are several general ways to handle such power-sum problems; one of them is to use falling factorials.

  • Definition: the falling factorial of $x$ at $k$, denoted $x^{\underline{k}}$, equals $x(x-1)\cdots (x-k+1)$, with $x^{\underline{0}} = 1$. It is also known as $P_x^k$, the number of ordered selections of $k$ items out of $x$.

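We also have $x^k = \sum_{i=0}^{k}{S(k,i)\times x^{\underline{i}}}$, where $S(k, i)$ is the Stirling number of the second kind. So it suffices to maintain, for every node, the sums of $dist^{\underline{i}}$ for $0 \le i \le k$, and combine them with $S(k,i)$ at the end.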

For $S$ we have the recurrence $S(i,j) = S(i-1, j - 1)+ j \times S(i-1,j)$, with $S(0,0) = 1$ and $S(i,0) = 0$ for $i > 0$. All the values we need can therefore be pre-processed in $O(k^2)$ time.

Next, let's see how to shift all the falling-factorial sums of a node in $O(k)$ time.

Falling factorials satisfy the following identity for $i \ge 1$: $(x+1)^{\underline{i}} = (x+1)x^{\underline{i-1}} = (x - i + 1 + i)x^{\underline{i-1}} = x^{\underline{i}} + i\times x^{\underline{i-1}}$.

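It can also be proved combinatorially. To arrange $i$ of $x+1$ distinct items in order, either the extra item is not used ($x^{\underline{i}}$ ways), or it occupies one of the $i$ positions and the remaining positions hold an ordered choice of $i-1$ of the other $x$ items ($i\times x^{\underline{i-1}}$ ways).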

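Let $ans_x^i$ denote the sum of $dist^{\underline{i}}$ from $x$ over a set of nodes. After the first DFS that set is the subtree of $x$, and after the second DFS it is the whole tree. Whenever an answer is transferred from a child $son$ to $x$, or from the father $fa_x$ down to $x$, every distance in the transferred set increases by exactly $1$.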

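Thus for the case $son_x \rightarrow x$, every distance from the child's subtree grows by $1$, so for each $i \ge 1$ we update $ans_x^{i} \leftarrow ans_x^i + ans_{son}^{i} + i \times ans_{son}^{i-1}$, together with $ans_x^0 \leftarrow ans_x^0 + ans_{son}^0$. Initially $ans_x^0 = 1$ (the node itself, at distance $0$) and $ans_x^i = 0$ for $i \ge 1$.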

For $fa_x \rightarrow x$, start from the father's complete answer and add $1$ to every distance. This overshoots for the nodes in the subtree of $x$: their distances already received a $+1$ when they were folded into the father's answer, so after this second shift they appear as $dist + 2$. We therefore subtract the subtree's sums taken at $dist + 2$ and add back the subtree's original sums. That is:

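$\sum_{v=1}^{n}(dist(fa,v)+1)^{\underline{i}} - \sum_{v \in sub(x)}(dist(x,v)+2)^{\underline{i}} + ans_x^i$, where $sub(x)$ is the subtree of $x$. Applying the identity twice gives $(d+2)^{\underline{i}} = d^{\underline{i}} + 2i\, d^{\underline{i-1}} + i(i-1)\, d^{\underline{i-2}}$. Substituting it and cancelling $ans_x^i$, we get:

$ans_x^i = ans_{fa}^i + i \times ans_{fa}^{i-1} - 2i \times ans_{x}^{i-1} - i(i-1) \times ans_x^{i-2}$ for all $i \ge 2$,

where $ans_{fa}$ on the right-hand side is the father's complete answer and $ans_x$ is still the subtree answer from the first DFS. Done in place, the update must therefore run from $i = k$ down to $i = 2$.

When $i = 1$ it is $ans_x^1 = ans_{fa}^1 + ans_{fa}^0 - 2ans_{x}^{0}$, and $ans_x^0 = ans_{fa}^0$ for $i = 0$. Finally, the answer for $x$ is $\sum_{i=0}^{k} S(k,i) \times ans_x^i$.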

#include <bits/stdc++.h>

using namespace std;

const int maxn = 50005;
const int maxk = 505;
const int mod = 10007;

// Adjacency lists stored as linked lists of edges.
struct edge {
    int to, nxt;
} e[maxn << 1];
int n, k, ptr, lnk[maxn];
// s[i][j]: Stirling numbers of the second kind mod 10007.
// ans[x][i]: sum of dist^{\underline{i}}; after dfs it covers the subtree of x,
// after dfs2 it covers the whole tree.
int s[maxk][maxk], ans[maxn][maxk];

// Precompute S(i, j) for 0 <= i, j <= 500 with S(i, j) = S(i-1, j-1) + j * S(i-1, j).
void init() {
    s[0][0] = 1;
    for (int i = 1; i <= 500; ++i)
        for (int j = 1; j <= 500; ++j)
            s[i][j] = (s[i - 1][j - 1] + j * s[i - 1][j] % mod) % mod;
}

// Fast reader for non-negative integers.
inline int rd() {
    int x = 0, c = getchar();
    while (!isdigit(c)) c = getchar();
    while (isdigit(c)) x = x * 10 + (c ^ 48), c = getchar();
    return x;
}

inline void add(int bgn, int end) {
    e[++ptr] = edge{end, lnk[bgn]};
    lnk[bgn] = ptr;
}

// First pass (recursive, depth up to n): subtree sums. Node x itself contributes
// 0^{\underline{0}} = 1; each child's sums are shifted by 1 via
// (d+1)^{\underline{i}} = d^{\underline{i}} + i * d^{\underline{i-1}}.
void dfs(int x, int fa) {
    ans[x][0] = 1;
    for (int i = 1; i <= k; ++i) ans[x][i] = 0;
    for (int p = lnk[x]; p; p = e[p].nxt) {
        int y = e[p].to;
        if (y == fa) continue;
        dfs(y, x);
        ans[x][0] = (ans[x][0] + ans[y][0]) % mod;
        for (int i = 1; i <= k; ++i)
            ans[x][i] = (ans[x][i] + ans[y][i] + i * ans[y][i - 1]) % mod;
    }
}

// Second pass: rerooting. ans[fa] already holds whole-tree sums, while ans[x] still
// holds subtree sums, so i runs downward to read ans[x][i-1] and ans[x][i-2]
// before they are overwritten.
void dfs2(int x, int fa) {
    if (fa) {
        for (int i = k; i >= 2; --i) {
            ans[x][i] = (ans[fa][i] + i * ans[fa][i - 1] % mod -
                         2ll * i * ans[x][i - 1] % mod + mod -
                         1ll * i * (i - 1) * ans[x][i - 2] % mod + mod) % mod;
        }
        ans[x][1] = (ans[fa][1] + ans[fa][0] - 2 * ans[x][0] % mod + mod) % mod;
        ans[x][0] = ans[fa][0];
    }
    for (int p = lnk[x]; p; p = e[p].nxt) {
        int y = e[p].to;
        if (y == fa) continue;
        dfs2(y, x);
    }
}

int main() {
    init();
    int T = rd();
    while (T--) {
        ptr = 0;
        memset(lnk, 0, sizeof lnk);
        n = rd(); k = rd();
        for (int i = 1; i < n; ++i) {
            int u = rd(), v = rd();
            add(u, v);
            add(v, u);
        }
        dfs(1, 0);
        dfs2(1, 0);
        // dist^k = sum_j S(k, j) * dist^{\underline{j}}
        for (int i = 1; i <= n; ++i) {
            int sum = 0;
            for (int j = 0; j <= k; ++j)
                sum = (sum + ans[i][j] * s[k][j] % mod) % mod;
            printf("%d\n", sum);
        }
    }
    return 0;
}