树套树 + 卡空间.

Table of Contents

  1. Description
  2. Solution

Description

给定一个长度为$n$序列, 每个位置上有一个编号$id_i$和一个权$v_i$, 对于一个逆序对(按编号和位置), 贡献是权的和.

有$m$次交换操作, 计算每次操作后的总贡献.

$n, m \le 50000$.

Solution

也就是维护一个带权的逆序对.

50000的数据范围可以跑$log^2$, 也就是树套树了.

树状数组维护区间, 主席树维护两个和, 分别是权值和出现次数.

对于每次交换操作$(x, y)$, 假设$x < y$, 有:

交换前$[x+1, y-1]$中大于$id_x$的位置将产生贡献, 小于的位置贡献将消失.

$y$反之.

最后考虑$x,y$是否会互相产生影响, 加入或去掉贡献, 思路还是很清晰的.

最近写树套树经常遇到空间问题, 对于这种不会回滚的问题, 我们直接修改版本就好了, 没必要可持久化, 可以省下一个$log$.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>

using namespace std;

const int maxn = 5e+4 + 5;
const int maxp = maxn * 200;
const int mod = 1e+9 + 7;

int n, m;
int lp[maxn], rp[maxn];
int ls[maxp], rs[maxp], su[maxp], tu[maxp];
int id[maxn], cnt;
int val[maxn], a[maxn], tot;
int ans;

void update(int &lst, int l, int r, int p, int v1, int v2) {
if (!lst) lst = ++cnt;
tu[lst] = (tu[lst] + v1) % mod;
su[lst] = (su[lst] + v2) % mod;
if (l == r) return;
int mid = (l + r) >> 1;
if (p <= mid) update(ls[lst], l, mid, p, v1, v2);
else update(rs[lst], mid + 1, r, p, v1, v2);
}
void query(int cur, int l, int r, int L, int R, int &ans1, int &ans2) {
if (!cur) return;
if (L <= l && r <= R) {
ans1 = (ans1 + tu[cur]) % mod;
ans2 = (ans2 + su[cur]) % mod;
return;
}
int mid = (l + r) >> 1;
if (L <= mid) query(ls[cur], l, mid, L, R, ans1, ans2);
if (R > mid) query(rs[cur], mid + 1, r, L, R, ans1, ans2);
}
void modify(int x, int v, int v1, int v2) {
for (; x <= n; x += (x & -x)) update(id[x], 1, n, v, v1, v2);
}
void getAns(int l, int r, int L, int R, int &ans1, int &ans2) {
if (L > R) return;
int ret1 = 0, ret2 = 0;
l = l - 1;
for (; l; l -= (l & -l)) query(id[l], 1, n, L, R, ret1, ret2);
for (; r; r -= (r & -r)) query(id[r], 1, n, L, R, ans1, ans2);
ans1 = (ans1 - ret1 + mod) % mod;
ans2 = (ans2 - ret2 + mod) % mod;
}
inline int rd() {
register int x = 0, c = getchar();
while (!isdigit(c)) c = getchar();
while (isdigit(c)) x = x * 10 + (c ^ 48), c = getchar();
return x;
}

int main() {
n = rd(); m = rd();
for (int i = 1; i <= n; ++i) {
int valSum = 0, timSum = 0;
a[i] = rd(); val[i] = rd();
modify(i, a[i], val[i], 1);
getAns(1, i - 1, a[i] + 1, n, valSum, timSum);
ans = (ans + valSum) % mod;
ans = (ans + 1ll * timSum * val[i] % mod) % mod;
}
while (m--) {
int x = rd(), y = rd();
int valSum = 0, timSum = 0, valMinus = 0, timMinus = 0;
if (x > y) swap(x, y);
int idx = a[x], idy = a[y];
// x:
getAns(x + 1, y - 1, idx + 1, n, valSum, timSum);
getAns(x + 1, y - 1, 1, idx - 1, valMinus, timMinus);
valSum = (valSum - valMinus + mod) % mod;
timSum = (timSum - timMinus + mod) % mod;
ans = (ans + valSum) % mod;
ans = (ans + 1ll * timSum * val[x]) % mod;
// y:
valSum = timSum = valMinus = timMinus = 0;
getAns(x + 1, y - 1, 1, idy - 1, valSum, timSum);
getAns(x + 1, y - 1, idy + 1, n, valMinus, timMinus);
valSum = (valSum - valMinus + mod) % mod;
timSum = (timSum - timMinus + mod) % mod;
ans = (ans + valSum) % mod;
ans = (ans + 1ll * timSum * val[y]) % mod;
if (idx < idy) ans = (ans + val[x] + val[y]) % mod;
else if (idx > idy) ans = (ans - (val[x] + val[y]) + mod) % mod;
modify(x, idx, mod - val[x], mod - 1);
modify(y, idx, val[x], 1);
modify(y, idy, mod - val[y], mod - 1);
modify(x, idy, val[y], 1);
swap(a[x], a[y]); swap(val[x], val[y]);
printf("%d\n", ans);
}
return 0;
}