1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
| #include <bits/stdc++.h>
using namespace std;

// 64-bit integer type used for the pair-count accumulators and answers.
using ll = long long;

// Capacity limits: maxn sizes the per-node arrays, maxm sizes the per-operation arrays.
const int maxn = 100000 + 5;
const int maxm = 500000 + 5;
int bel[maxn]; struct Q { int l, r, tg, id; inline bool operator<(const Q &rhs) const { return bel[l] == bel[rhs.l] ? r < rhs.r : bel[l] < bel[rhs.l]; } }q[maxm * 9]; struct edge { int to, nxt; }e[maxn << 1]; int n, m; int a[maxn], ptr, lnk[maxn], dfn[maxn], rev[maxn], idx; int cpy[maxn], cnt, out[maxn], dep[maxn], lo2[maxn]; int F[20][maxn], rt, qs, blo; int qu[4], qv[4], typu[4], typv[4]; ll cur, ans[maxm], bkta[maxn], bktb[maxn];
inline int rd() { register int x = 0, c = getchar(); while (!isdigit(c)) c = getchar(); while (isdigit(c)) x = x * 10 + (c ^ 48), c = getchar(); return x; } inline void add(const int &bgn, const int &end) { e[++ptr] = (edge){end, lnk[bgn]}; lnk[bgn] = ptr; } void dfs(int x, int fa) { dep[x] = dep[fa] + 1; F[0][x] = fa; dfn[x] = ++idx; rev[idx] = a[x]; for (int i = 1; i <= 17; ++i) F[i][x] = F[i - 1][F[i - 1][x]]; for (int p = lnk[x]; p; p = e[p].nxt) { int y = e[p].to; if (y == fa) continue; dfs(y, x); } out[x] = idx; } inline int jmp(int x, int y) { for (int i = lo2[dep[y] - dep[x]]; ~i; --i) { if (dep[y] - dep[x] > (1 << i)) y = F[i][y]; } return y; }
int main() { n = rd(); m = rd(); for (int i = 1; i <= n; ++i) cpy[i] = a[i] = rd(); lo2[1] = 0; for (int i = 2; i <= n; ++i) lo2[i] = lo2[i >> 1] + 1; sort(cpy + 1, cpy + 1 + n); cnt = unique(cpy + 1, cpy + 1 + n) - (cpy + 1); for (int i = 1; i <= n; ++i) a[i] = lower_bound(cpy + 1, cpy + 1 + cnt, a[i]) - cpy; int u, v; for (int i = 1; i < n; ++i) { u = rd(); v = rd(); add(u, v); add(v, u); } dfs(1, 0); rt = 1; int opt, M = 0; while (m--) { opt = rd(); if (opt == 1) { rt = rd(); } else { ++M; u = rd(); v = rd(); int ptru = 0, t; if (u == rt) typu[ptru] = 1, qu[ptru++] = n; else if (dfn[rt] < dfn[u] || dfn[rt] > out[u]) typu[ptru] = 1, qu[ptru++] = out[u], typu[ptru] = -1, qu[ptru++] = dfn[u] - 1; else { t = jmp(u, rt); typu[ptru] = 1, qu[ptru++] = n; typu[ptru] = -1, qu[ptru++] = out[t]; typu[ptru] = 1, qu[ptru++] = dfn[t] - 1; } int ptrv = 0; if (v == rt) typv[ptrv] = 1, qv[ptrv++] = n; else if (dfn[rt] < dfn[v] || dfn[rt] > out[v]) { typv[ptrv] = 1, qv[ptrv++] = out[v]; typv[ptrv] = -1, qv[ptrv++] = dfn[v] - 1; } else { t = jmp(v, rt); typv[ptrv] = 1, qv[ptrv++] = n; typv[ptrv] = -1, qv[ptrv++] = out[t]; typv[ptrv] = 1, qv[ptrv++] = dfn[t] - 1; } for (int i = 0; i < ptru; ++i) for (int j = 0; j < ptrv; ++j) if (qu[i] && qv[j]) q[++qs] = (Q){qu[i], qv[j], typu[i] * typv[j], M}; } } blo = sqrt(n); for (int i = 1; i <= n; ++i) bel[i] = (i - 1) / blo + 1; sort(q + 1, q + 1 + qs); int L = 0, R = 0; for (int i = 1; i <= qs; ++i) { while (L < q[i].l) cur += bktb[rev[++L]], bkta[rev[L]]++; while (R < q[i].r) cur += bkta[rev[++R]], bktb[rev[R]]++; while (L > q[i].l) cur -= bktb[rev[L]], bkta[rev[L--]]--; while (R > q[i].r) cur -= bkta[rev[R]], bktb[rev[R--]]--; ans[q[i].id] += q[i].tg * cur; } for (int i = 1; i <= M; ++i) printf("%lld\n", ans[i]); return 0; }
|