单调栈,前缀和,RMQ。

Table of Contents

  1. Description
  2. Solution

Description

给定长度为$n$的序列:$a_1,a_2,\ldots ,a_n$,记为$a[1:n]$。类似地,$a[l:r](1\le l \le r \le N)$是指序列:$a_l,a_{l+1},\ldots , a_{r-1},a_r$。若$1 \le l \le s \le t \le r \le n$,则称$a[s:t]$是$a[l:r]$的子序列。

现在有$q$个询问,每个询问给定两个数$l$和$r$,$1 \le l \le r \le n$,求$a[l:r]$的不同子序列的最小值之和。

例如,给定序列$5,2,4,1,3$,询问给定的两个数为$1和3$,那么$a[1:3]$有
6个子序列$a[1:1],a[2:2],a[3:3],a[1:2],a[2:3],a[1:3]$,这6个子序列的最小值之和为$5+2+4+2+2+2=17$。

Solution

首先有一个很常见的套路,我们不能直接统计就算贡献。

对于一个区间中的最小值 $a_p$, 所有跨越她的区间有$(r - p + 1) \times (p - l + 1)$个, 可以加入贡献, 因此我们接着要计算的就是区间$[l, p)$和$(p, r]$的答案.

但这样做下去单次最好也要$O(n)$.

我们定义一个$F[l][r]$, 为右端点在$r$, 左端点在$[l,r]$的所有答案, 对于右侧的情况, 我们需要得到的就是$F[p + 1][r]$. 首先考虑$F$的递推关系, 定义$pre_r$为$r$之前第一个小于她的位置, 那么在这之后$a_r$都是区间内最小值, 也就是: $F[l][r] = F[l][pre_r] + a_r \times (r - pre_r)$.

这是可以递推的! 但由于状态量是$n^2$的, 我们还需要做一些优化.

考虑丢掉$l$, 定义一个状态$f_p$, 归纳可得:

由于$p$是区间内最小值, 所以向前推若干位时, 一定会存在一个$x$, 使得$pre_x = p$, 那么对于一个$r$, 就有$f_r = a_r \times (r - pre_r) + a_{pre_r} \times(\ldots)+\ldots + a_x \times (x - p) + f_p$.

这时我们有$f_r - f_p$为右端点在$r$, 左端点在$(p, r]$的答案.

那么我们就可以得出每个右端点$r_i \in (p,r]$的答案了, 利用前缀和优化, 有:

$ \begin{align}ans = \sum_{i = p + 1}^{r}{f_i - f_p} = \sum_{i=p+1}^{r}{f_i} - {f_p \times (r - p)} = sum_r - sum_p - f_p\times (r - p)\end{align}$.

左侧的答案同理,反向求一遍即可。

单调栈可以$O(n)$求出每个位置的$pre$和$suf$, 每次询问用ST表查出最小值位置,前后两端求一下答案就好了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>

using namespace std;

typedef long long ll;
const int maxn = 1e+5 + 5;

int n, q;
int a[maxn], pre[maxn], suf[maxn];
int stk[maxn], top;
int ST[maxn][18], LO2[maxn], PO2[20];
ll ftr[maxn], ftl[maxn], sfr[maxn], sfl[maxn];

inline int rd() {
register int x = 0, f = 0, c = getchar();
while (!isdigit(c)) {
if (c == '-') f = 1;
c = getchar();
}
while (isdigit(c)) x = x * 10 + (c ^ 48), c = getchar();
return f ? -x : x;
}
inline int query(int l, int r) {
int delt = LO2[r - l + 1];
return (a[ST[l][delt]] < a[ST[r - PO2[delt] + 1][delt]] ?
ST[l][delt] :
ST[r - PO2[delt] + 1][delt]);
}

int main() {
n = rd(); q = rd();
LO2[1] = 0;
for (int i = 2; i <= n; ++i) LO2[i] = (LO2[i >> 1] + 1);
PO2[0] = 1;
for (int i = 1; i < 18; ++i) PO2[i] = (PO2[i - 1] << 1);
for (int i = 1; i <= n; ++i) {
a[i] = rd(); ST[i][0] = i;
}
for (int j = 1; j <= LO2[n]; ++j)
for (int i = 1; i <= n - PO2[j - 1] + 1; ++i)
ST[i][j] = (a[ST[i][j - 1]] < a[ST[i + PO2[j - 1]][j - 1]] ? ST[i][j - 1] : ST[i + PO2[j - 1]][j - 1]);

a[0] = a[n + 1] = 0x3f3f3f3f;
for (int i = 1; i <= n; ++i) {
while (top && a[stk[top]] > a[i]) suf[stk[top]] = i, top--;
pre[i] = stk[top]; stk[++top] = i;
}
while (top) pre[stk[top]] = stk[top - 1], suf[stk[top]] = n + 1, top--;
for (int i = 1; i <= n; ++i)
ftr[i] = ftr[pre[i]] + (ll)a[i] * (i - pre[i]), sfr[i] = sfr[i - 1] + ftr[i];
for (int i = n; i; --i)
ftl[i] = ftl[suf[i]] + (ll)a[i] * (suf[i] - i), sfl[i] = sfl[i + 1] + ftl[i];
int l, r, pos;
while (q--) {
l = rd(); r = rd(); pos = query(l, r);
ll ans = 1ll * a[pos] * (pos - l + 1) * (r - pos + 1) +
(ll)sfr[r] - (ll)sfr[pos] - 1ll * ftr[pos] * (r - pos) +
(ll)sfl[l] - (ll)sfl[pos] - 1ll * ftl[pos] * (pos - l);
printf("%lld\n", ans);
}
return 0;
}

然后听说有一道类似的题BZOJ4262