// Reads n, m, then f[1..n] (values assumed to lie in 1..m), then w[1..m].
// Prints the largest total w[v] over the values v that occur exactly once in
// some contiguous range f[l..r] (0 if no range has a positive total).
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
typedef long long ll;

// Array capacity: indices 1..1000000 are used, plus a little slack.
const int maxn = 1000000 + 5;

int n, m;           // n = length of f, m = number of weights in w
int f[maxn];        // f[1..n]: input sequence (values assumed to lie in 1..m)
int w[maxn];        // w[1..m]: weight of each value
int lst[maxn];      // lst[v]: first index of value v in f, or n + 1 if absent
int nxt[maxn];      // nxt[i]: next index j > i with f[j] == f[i], or n + 1
ll mx[maxn << 2];   // segment tree: maximum over each node's range
ll tag[maxn << 2];  // segment tree: pending add not yet pushed to the children
ll ans;             // best total found so far (zero-initialized)

// Reads one non-negative decimal integer from stdin, skipping any non-digit
// characters before it (so a '-' sign is ignored, not honored).
// On EOF it returns the digits accumulated so far (0 if none) instead of
// looping forever waiting for a digit that will never arrive.
inline int rd() {
    int x = 0;
    int c = getchar();
    while (c != EOF && !isdigit(c)) c = getchar();
    while (isdigit(c)) {  // isdigit(EOF) is false, so EOF also ends this loop
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x;
}

// Moves node cur's pending add into both children (their max and their tag).
inline void pushdown(int cur) {
    if (tag[cur]) {
        const ll add = tag[cur];
        tag[cur << 1] += add;
        tag[cur << 1 | 1] += add;
        mx[cur << 1] += add;
        mx[cur << 1 | 1] += add;
        tag[cur] = 0;
    }
}

// Recomputes node cur's maximum from its two children.
inline void pushup(int cur) {
    mx[cur] = max(mx[cur << 1], mx[cur << 1 | 1]);
}

// Adds c to every position in [L, R] inside node cur, which covers [l, r].
// Call from the root as update(1, 1, n, L, R, c).
// Empty (L > R) or non-overlapping ranges are ignored. Without this guard an
// empty range descends to a leaf where mid == l and recurses without end.
void update(int cur, int l, int r, int L, int R, int c) {
    if (L > R || R < l || r < L) return;
    if (L <= l && r <= R) {
        // Fully covered: apply here, defer the children via the lazy tag.
        tag[cur] += c;
        mx[cur] += c;
        return;
    }
    pushdown(cur);
    const int mid = (l + r) >> 1;
    if (L <= mid) update(cur << 1, l, mid, L, R, c);
    if (R > mid) update(cur << 1 | 1, mid + 1, r, L, R, c);
    pushup(cur);
}
int main() { n = rd(); m = rd(); for (int i = 1; i <= n; ++i) f[i] = rd(); for (int i = 1; i <= m; ++i) w[i] = rd(), lst[i] = n + 1; for (int i = n; i; --i) { nxt[i] = lst[f[i]]; lst[f[i]] = i; } for (int i = 1; i <= m; ++i) { if (lst[i] != n + 1) { if (nxt[lst[i]] == n + 1) update(1, 1, n, lst[i], n, w[i]); else update(1, 1, n, lst[i], nxt[lst[i]] - 1, w[i]); } } for (int i = 1; i <= n; ++i) { ans = max(ans, mx[1]); if (nxt[i] != n + 1) { update(1, 1, n, i, nxt[i] - 1, -w[f[i]]); if (nxt[nxt[i]] != n + 1) update(1, 1, n, nxt[i], nxt[nxt[i]] - 1, w[f[i]]); else update(1, 1, n, nxt[i], n, w[f[i]]); } else update(1, 1, n, i, n, -w[f[i]]); } printf("%lld\n", ans); return 0; }