区间DP(并不用分治)

Table of Contents

  1. Description
  2. Solution

Description

有$n$家洗车店从左往右排成一排,每家店都有一个正整数价格$p[i]$。

有$m$个人要来消费,第$i$个人会驶过第$a[i]$个开始一直到第$b[i]$个洗车店,且会选择这些店中最便宜的一个进行一次消费。但是如果这个最便宜的价格大于$c[i]$,那么这个人就不洗车了。

请给每家店指定一个价格,使得所有人花的钱的总和最大。

Solution

每个位置最终一定都是某一个$c$, 否则我们就可以把她调整到最近的$c$, 使答案不会变差. 这样我们就可以离散化了, 只要记录每个位置是哪一个$c$即可.

考虑枚举区间$l, r$, 并枚举区间最小值k$和出现位置$$i$(跑得过). 这里还是会出现一点分治时常见的套路, 计算跨越当前位置的区间. 这些区间在当前枚举的$l, r$下都会选择这个位置作为洗车点. 也就是$f[l][i - 1][k] + f[i + 1][r][k] + cnt[i][k] \times val[k]$, $cnt$统计了当前范围内跨越$i$并且$c \ge k$的数量. 用这个式子更新答案. 最终$f$的意义是满足$l, r$且$c$均大于等于$k$的答案, 所以还要用$f[l][r][k+1]$来更新一次答案. 记录下每个区间的决策点和决策点的状态.

第一问的答案, 根据方程意义就是$f[1][n][1]$. 有决策点和决策状态我们就可以递归地构造出整个数组的答案.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>

using namespace std;

const int maxn = 55;
const int maxm = 4005;

int n, m, val[maxm], cnt[maxn][maxm];
int f[maxn][maxn][maxm], pos[maxn][maxn][maxm], mp[maxn][maxn][maxm];
int a[maxm], b[maxm], c[maxm], cpy[maxm], tot;
int ans[maxn];

void calc(int l, int r, int v) {
if (l > r) return;
int p = pos[l][r][v], rk = mp[l][r][v];//当前位置是区间最小值v, 区间其他位置都大于等于v, 可以递归执行.
ans[p] = cpy[rk];
calc(l, p - 1, rk);
calc(p + 1, r, rk);
}

int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= m; ++i)
scanf("%d%d%d", &a[i], &b[i], &c[i]), cpy[i] = c[i];
sort(cpy + 1, cpy + 1 + m);
tot = unique(cpy + 1, cpy + 1 + m) - (cpy + 1);
for (int i = 1; i <= m; ++i)
c[i] = lower_bound(cpy + 1, cpy + 1 + tot, c[i]) - cpy;
for (int len = 1; len <= n; ++len)
for (int l = 1; l <= n - len + 1; ++l) {
int r = l + len - 1;
for (int i = l; i <= r; ++i)
for (int j = 1; j <= tot; ++j)
cnt[i][j] = 0;
for (int i = 1; i <= m; ++i)
if (a[i] >= l && b[i] <= r)
for (int j = a[i]; j <= b[i]; ++j)
cnt[j][c[i]]++;
for (int i = l; i <= r; ++i)
for (int j = tot; j; --j)
cnt[i][j] += cnt[i][j + 1];
for (int j = tot; j; --j) {
pos[l][r][j] = l, mp[l][r][j] = j;
for (int i = l; i <= r; ++i) {
int tmp = f[l][i - 1][j] + f[i + 1][r][j] + cnt[i][j] * cpy[j];
if (tmp > f[l][r][j])
f[l][r][j] = tmp, pos[l][r][j] = i;
}
if (f[l][r][j + 1] > f[l][r][j]) f[l][r][j] = f[l][r][j + 1], pos[l][r][j] = pos[l][r][j + 1], mp[l][r][j] = mp[l][r][j + 1];
}
}
printf("%d\n", f[1][n][1]);
calc(1, n, 1);
for (int i = 1; i <= n; ++i)
printf("%d ", ans[i]);
return 0;
}