线段树。

Table of Contents

  1. Description
  2. Solution

Description

共有$m$部电影,编号为$1 - m$,第$i$部电影的好看值为$w[i]$。

在$n$天之中(从$1 - n$编号)每天会放映一部电影,第$i$天放映的是第$f[i]$部。

你可以选择$l,r(1 \le l \le r \le n)$,并观看第$l,l+1,\ldots,r$天内所有的电影。如果同一部电影你观看多于一次,你会感到无聊,于是无法获得这部电影的好看值。所以你希望最大化观看且仅观看过一次的电影的好看值的总和。

Solution

类似采花(HEOI2012)一样的操作.

预处理每个位置的颜色下一次出现的位置.

枚举每个左端点, 线段树上维护该点作为右端点时的答案. 当离开一个点时, 从 $i$ 到 $nxt[i] - 1$ 的答案会减少$w[f[i]]$. 而从 $nxt[i]$ 到$nxt[nxt[i]] - 1$的位置答案会增加相同的值, 我们用线段树支持区间加和区间查询最大值就好了.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>

using namespace std;

typedef long long ll;

const int maxn = 1e+6 + 5;

int n, m, f[maxn], w[maxn];
int lst[maxn], nxt[maxn];
ll mx[maxn << 2], tag[maxn << 2], ans;

inline int rd() {
register int x = 0, c = getchar();
while (!isdigit(c)) c = getchar();
while (isdigit(c)) x = x * 10 + (c ^ 48), c = getchar();
return x;
}
inline void pushdown(int cur) {
if (tag[cur]) {
//s[cur << 1] += tag[cur] * (len >> (len >> 1));
//s[cur << 1|1] += tag[cur] * (len >> 1);
tag[cur << 1] += tag[cur];
tag[cur << 1|1] += tag[cur];
mx[cur << 1] += tag[cur];
mx[cur << 1|1] += tag[cur];
tag[cur] = 0;
}
}
inline void pushup(int cur) {
//s[cur] = s[cur << 1] + s[cur << 1|1];
mx[cur] = max(mx[cur << 1], mx[cur << 1|1]);
}
void update(int cur, int l, int r, int L, int R, int c) {
if (L <= l && r <= R) {
//s[cur] += c * (r - l + 1);
tag[cur] += c;
mx[cur] += c;
return;
}
int mid = (l + r) >> 1; pushdown(cur);
if (L <= mid) update(cur << 1, l, mid, L, R, c);
if (R > mid) update(cur << 1|1, mid + 1, r, L, R, c);
pushup(cur);
}

int main() {
n = rd(); m = rd();
for (int i = 1; i <= n; ++i) f[i] = rd();
for (int i = 1; i <= m; ++i) w[i] = rd(), lst[i] = n + 1;
for (int i = n; i; --i) {
nxt[i] = lst[f[i]];
lst[f[i]] = i;
}
for (int i = 1; i <= m; ++i) {
if (lst[i] != n + 1) {
if (nxt[lst[i]] == n + 1) update(1, 1, n, lst[i], n, w[i]);
else update(1, 1, n, lst[i], nxt[lst[i]] - 1, w[i]);
}
}
for (int i = 1; i <= n; ++i) {
ans = max(ans, mx[1]);
if (nxt[i] != n + 1) {
update(1, 1, n, i, nxt[i] - 1, -w[f[i]]);
if (nxt[nxt[i]] != n + 1) update(1, 1, n, nxt[i], nxt[nxt[i]] - 1, w[f[i]]);
else update(1, 1, n, nxt[i], n, w[f[i]]);
} else update(1, 1, n, i, n, -w[f[i]]);
}
printf("%lld\n", ans);
return 0;
}