Dynamic Dynamic Programming 2333333.

Description

一类带修改的树上DP问题。例如【模板】动态DP, NOIP2018 保卫王国.

Solution

以模板 - 树上最大权独立集为例.

我们熟知的链上DP方程是:

$f_{x,0} = \max(f_{s,0}, f_{s,1}), f_{x,1} = v_x + f_{s,0}$.

如果每次修改都重做DP的话, 总的时间复杂度就是$O(nm)$, 我们要优化. 考虑这种加法和取max的操作能否对应成一种矩阵的运算. 如果定义矩阵运算 $\oplus$, 有$A \oplus B(i,j) = \begin{align}\max_{k=1}^{m}{\{A(i,k) + B(k,j)\}}\end{align}$

那么我们的转移矩阵就可以写成: $\begin{bmatrix}0 & 0 \ V_x & -\infty\end{bmatrix}$.

然后考虑树的问题在哪, 一个节点会对应多个孩子, 也就是.

$f_{x,0} = f_{x,0} + \max(f_{s,0}, f_{s,1}), f_{x,1} = f_{x,1} + f_{s,0}$.

如何维护矩阵呢?$\begin{bmatrix}f_{x,0} & f_{x,0} \ f_{x,1} & -\infty\end{bmatrix}$ .所以这个时候把转移矩阵和答案放一起就可以了, 取的时候直接取第一列.

我们可以每次修改从当前点到根的一条链的转移矩阵, 线段树维护, 每次查询即可. 然而这样还是可以卡成$O(nm)$. 来考虑更高效的转移方式. 使用重链剖分.

我们利用跳重链$\log$次的性质来维护转移矩阵, 每次考虑节点所在重链顶对于父亲的影响.

首先我们对每个构造一个只做了轻儿子DP的节点矩阵, 这样查询一条重链就做完了全部的DP. 先备份重链顶的矩阵. 每次修改完当前节点后, 重新查询重链顶, 就可以快速得到重链顶的状态, 这时就可以增量更新父亲的矩阵了(+= 新矩阵的结果 - 旧矩阵的结果). 接着跳上去, 对于父亲也是同样的过程. 增量更新是因为矩阵是所有轻儿子混在一起的, 无法提出一个单独的位置.

然后就做完了. 答案就是查询重链顶$1$.

NOIP2018的题类似, 有: 最小权覆盖 = 全集 - 最大独立集, 所以也是一样做的. 注意强制选/不选对应着在独立集中强制不选/选, 然后用$\infty$ 和 $-\infty$ 调整一下就行了.

贴个模板代码吧.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#include <bits/stdc++.h>

using namespace std;
typedef long long ll;

const int maxn = 1e+5 + 5;
const int inf = 0x3f3f3f3f;

//------ GRAPH STARTS----------
struct edge {
int to, nxt;
}e[maxn << 1];
int n, m, lnk[maxn], ptr, siz[maxn], mxson[maxn];
int rev[maxn], f[maxn], top[maxn], bot[maxn];
int dfn[maxn], cnt;
ll dp[maxn][2], v[maxn];

inline void InitNewEdge(int bgn, int end) {
e[++ptr] = (edge){end, lnk[bgn]};
lnk[bgn] = ptr;
}
inline int ReadInt() {
register int x = 0, f = 0, c = getchar();
while (!isdigit(c)) {
if (c == '-') f = 1;
c = getchar();
}
while (isdigit(c)) {
x = x * 10 + (c ^ 48);
c = getchar();
}
return f ? -x : x;
}
void dfs(int x, int fa) {
siz[x] = 1;
f[x] = fa;
dp[x][1] = v[x];
for (int p = lnk[x]; p; p = e[p].nxt) {
int y = e[p].to;
if (y == fa) continue;
dfs(y, x);
dp[x][1] += dp[y][0];
dp[x][0] += max(dp[y][1], dp[y][0]);
siz[x] += siz[y];
if (siz[y] > siz[mxson[x]]) mxson[x] = y;
}
}
void dfs2(int x, int init) {
top[x] = init;
dfn[x] = ++cnt; rev[cnt] = x;
if (!mxson[x]) {
bot[x] = x;
return;
}
dfs2(mxson[x], init); bot[x] = bot[mxson[x]];
for (int p = lnk[x]; p; p = e[p].nxt) {
int y = e[p].to;
if (y == f[x] || y == mxson[x]) continue;
dfs2(y, y);
}
}
//-------GRAPH ENDS------------

//-------MATRIX STARTS---------
struct matrix {
ll a[2][2];
matrix(const ll &arg_1 = 0, const ll &arg_2 = 0,
const ll &arg_3 = 0, const ll &arg_4 = 0) {
a[0][0] = arg_1;
a[0][1] = arg_2;
a[1][0] = arg_3;
a[1][1] = arg_4;
}
ll* operator[](int x) {
return a[x];
}
inline matrix operator*(matrix rhs) {
matrix ret;
ret[0][0] = max(a[0][1] + rhs[1][0], a[0][0] + rhs[0][0]);
ret[0][1] = max(a[0][1] + rhs[1][1], a[0][0] + rhs[0][1]);
ret[1][0] = max(a[1][0] + rhs[0][0], a[1][1] + rhs[1][0]);
ret[1][1] = max(a[1][1] + rhs[1][1], a[1][0] + rhs[0][1]);
return ret;
}
};
matrix node[maxn << 2], pt[maxn];
void build(int cur, int l, int r) {
if (l == r) {
int R = rev[l];
ll f0 = 0, f1 = v[R];
for (int p = lnk[R]; p; p = e[p].nxt) {
int y = e[p].to;
if (y == f[R] || y == mxson[R]) continue;
f0 += max(dp[y][1], dp[y][0]);
f1 += dp[y][0];
}
node[cur] = pt[l] = matrix(f0, f0, f1, -inf);
return;
}
int mid = (l + r) >> 1;
build(cur << 1, l, mid);
build(cur << 1|1, mid + 1, r);
node[cur] = node[cur << 1] * node[cur << 1|1];
}
void modify(int cur, int l, int r, int p) {
if (l == r) {
node[cur] = pt[l];
return;
}
int mid = (l + r) >> 1;
if (p <= mid) modify(cur << 1, l, mid, p);
else modify(cur << 1|1, mid + 1, r, p);
node[cur] = node[cur << 1] * node[cur << 1|1];
}
matrix query(int cur, int l, int r, int L, int R) {
if (L == l && r <= R) return node[cur];
int mid = (l + r) >> 1;
if (R <= mid) return query(cur << 1, l, mid, L, R);
else if (L > mid) return query(cur << 1|1, mid + 1, r, L, R);
else return query(cur << 1, l, mid, L, mid) *
query(cur << 1|1, mid + 1, r, mid + 1, R);
}
//-------MATRIX ENDS-----------

//------SOLUTION STARTS--------
matrix GetChain(int x) {
return query(1, 1, n, dfn[top[x]], dfn[bot[x]]);
}
void modify(int x, int val) {
pt[dfn[x]][1][0] += val - v[x];
v[x] = val;
while (x) {
matrix Origin = GetChain(top[x]);
modify(1, 1, n, dfn[x]);
matrix New = GetChain(top[x]);
x = f[top[x]];
if (!x) break;
pt[dfn[x]][0][0] += max(New[0][0], New[1][0]) -
max(Origin[0][0], Origin[1][0]);
pt[dfn[x]][0][1] = pt[dfn[x]][0][0];
pt[dfn[x]][1][0] += New[0][0] - Origin[0][0];
}
}
//--------SOLUTION ENDS--------

#define VertexSum n
#define OptionSum m

int main(int argc, char const *argv[]) {
VertexSum = ReadInt();
OptionSum = ReadInt();
for (int i = 1; i <= VertexSum; ++i)
v[i] = ReadInt();
typedef int Vertex_t;
Vertex_t VertexFrom, VertexTo;
for (int i = 1; i < VertexSum; ++i) {
VertexFrom = ReadInt();
VertexTo = ReadInt();
InitNewEdge(VertexFrom, VertexTo);
InitNewEdge(VertexTo, VertexFrom);
}
dfs(1, 0);
dfs2(1, 1);
build(1, 1, n);
int x, y;
while (m--) {
x = ReadInt();
y = ReadInt();
modify(x, y);
matrix ans = GetChain(1);
printf("%lld\n", max(ans[0][0], ans[1][0]));
}
return 0;
}