Luogu P2664 Tree Game (centroid decomposition)

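Problem statement: every vertex of a tree has a color. Let \(s(i, j)\) be the number of distinct colors on the path from \(i\) to \(j\). For every vertex \(i\), output \(sum_i = \sum_{j=1}^{n} s(i, j)\).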

Problem link

Solution

A truly brilliant problem. (Credit to yyb — Orz yyb.)

Use centroid decomposition. At each centroid we only need to count the contribution of vertex pairs whose path passes through the current centroid (so the centroid acts as their \(LCA\)), plus the contribution from the centroid itself to every vertex of the component.

The key trick: for each color on the path between a pair of vertices, count it only once — at its occurrence closest to the root (the current centroid).

With this idea, the rest is careful bookkeeping. The details are tedious, but the core is to maintain the contribution of the vertices in each vertex's subtree. For example, if the color of vertex \(x\) appears for the first time on the path from the root to \(x\), then all \(siz[x]\) vertices in its subtree have that color on their path to the root, so together they contribute to the vertices outside that subtree.

When processing a child subtree, first remove that subtree's own contribution. Then DFS through it, adding the contribution of each color that appears for the first time on the current path.

Complexity: \(O(n \log n)\)

#include<bits/stdc++.h> 
// ---- Contest boilerplate: shorthand macros and global constants. ----
// In the code shown here only LL, pb, mod (through the arithmetic helpers) and
// MAXN are actually used; the remaining macros/constants are template leftovers.
#define Pair pair<int, int>
#define MP(x, y) make_pair(x, y)
#define fi first
#define se second
#define LL long long 
#define ull unsigned long long 
// Fin(x) / Fout(x): redirect stdin / stdout to the files "x.in" / "x.out".
#define Fin(x) {freopen(#x".in","r",stdin);}
#define Fout(x) {freopen(#x".out","w",stdout);}
#define pb push_back 
using namespace std;
// MAXN bounds every per-vertex / per-color array below; mod is the modulus used by
// the add/mul/fp helpers; INF and eps are not referenced by this solution.
const int MAXN = 1e6 + 10, mod = 1e9 + 7, INF = 1e9 + 10;
const double eps = 1e-9;
// chmin(a, b): if b is strictly smaller than a, store b into a.
// Returns true exactly when a was changed.
template <typename A, typename B> inline bool chmin(A &a, B b) {
    const bool improves = a > b;
    if (improves) a = b;
    return improves;
}
// chmax(a, b): if b is strictly larger than a, store b into a.
// Returns true exactly when a was changed.
template <typename A, typename B> inline bool chmax(A &a, B b) {
    const bool improves = a < b;
    if (improves) a = b;
    return improves;
}
// add(x, y): x + y folded back by a single +/- mod. The result lies in [0, mod)
// only when x + y itself lies in (-mod, 2 * mod), e.g. both operands in (-mod, mod).
template <typename A, typename B> inline LL add(A x, B y) {
    const auto s = x + y;
    if (s < 0) return s + mod;
    if (s >= mod) return s - mod;
    return s;
}
// add2(x, y): in-place form of add — x becomes the folded value of x + y.
template <typename A, typename B> inline void add2(A &x, B y) {
    const auto s = x + y;
    if (s < 0) x = s + mod;
    else if (s >= mod) x = s - mod;
    else x = s;
}
// mul(x, y): product computed in 64-bit arithmetic, then reduced by mod
// (keeps the sign of the product, as C++ % does).
template <typename A, typename B> inline LL mul(A x, B y) {
    const auto product = 1ll * x * y;
    return product % mod;
}
// mul2(x, y): in-place x = x * y mod mod, normalised into [0, mod).
template <typename A, typename B> inline void mul2(A &x, B y) {
    const auto reduced = 1ll * x * y % mod;  // may be negative for negative operands
    x = (reduced + mod) % mod;
}
// debug(a): print a followed by a newline on standard output.
template <typename A> inline void debug(A a) {
    std::cout << a << '\n';
}
// sqr(x): x squared, with the multiplication widened to long long first.
template <typename A> inline LL sqr(A x) {
    const auto widened = 1ll * x;
    return widened * x;
}
// fp(a, p, md): modular exponentiation a^p mod md by binary exponentiation.
// The md argument is honoured. Previously it was accepted but ignored, so every call
// silently reduced by the global mod. The base is first normalised into [0, md), so
// negative bases give a non-negative result. p is expected to be non-negative; a
// negative p now yields 1 % md instead of looping forever, since p >>= 1 never
// reaches 0 for negative p. md is assumed positive.
template <typename A, typename B> inline LL fp(A a, B p, int md = mod) {
    LL base = static_cast<LL>(a) % md;
    if (base < 0) base += md;
    LL result = 1 % md;  // 1 % md keeps md == 1 correct (everything is 0)
    while (p > 0) {
        if (p & 1) result = result * base % md;  // operands < md <= INT_MAX: no overflow
        base = base * base % md;
        p >>= 1;
    }
    return result;
}
// inv(x): modular inverse of x modulo the prime global mod (Fermat's little theorem).
template <typename A> A inv(A x) {return fp(x, mod - 2);}
// read(): parse the next decimal integer from stdin with getchar().
// Non-digit characters before the number are skipped. Any '-' among them makes
// the result negative. If input ends before a digit is found, it returns 0 instead
// of spinning forever, because EOF is itself "< '0'" and the old skip loop never exited.
// c is an int so that EOF is distinguishable from a 0xFF input byte.
inline int read() {
    int c = getchar(), x = 0, f = 1;
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}
// Global state of the centroid decomposition (per-vertex and per-color arrays are
// sized MAXN — assumes vertex ids and colors are < MAXN; TODO confirm input limits):
//   N      - number of vertices (read in main)
//   c[i]   - color of vertex i
//   cnt[k] - how many times color k occurs on the path currently walked by dfs/calc
//   vis[i] - set to 1 once vertex i has been used as a centroid
//   siz[i] - subtree size of i as computed by the most recent FindRoot
//   Lim    - size of the component being processed (temporarily reduced inside
//            Divide while one child subtree is excluded)
//   mx[i]  - largest piece that would remain if i were removed (FindRoot);
//            mx[0] is set to a large sentinel in main
//   root   - best centroid candidate found by FindRoot
int N, c[MAXN], cnt[MAXN], vis[MAXN], siz[MAXN], Lim, mx[MAXN], root;
//   ans[i] - accumulated answer for vertex i
//   num[k] - total siz[] over first occurrences of color k below the current centroid
//   Sum    - total of num[] over all colors
LL ans[MAXN], num[MAXN], Sum;
// v[i]: adjacency list of vertex i.
vector<int> v[MAXN];
void FindRoot(int x, int fa) {
    siz[x] = 1; mx[x] = 1;
    for(auto &to : v[x]) {
        if(to == fa || vis[to]) continue;
        FindRoot(to, x);
        siz[x] += siz[to];
        chmax(mx[x], siz[to]);
    }
    chmax(mx[x], Lim - siz[x]);
    if(mx[x] < mx[root]) 
        root = x;
}

// dfs(x, fa, opt): walk the component below x, skipping fa and used centroids.
// Whenever a color's counter on the current path becomes 1, i.e. its first
// occurrence unless the caller pre-seeded cnt, add opt * siz[x] to Sum and to
// num[color]. Use opt = +1 to insert contributions and opt = -1 to remove them.
void dfs(int x, int fa, int opt) {
    const int col = c[x];
    if (++cnt[col] == 1) {
        const int delta = siz[x] * opt;
        Sum += delta;
        num[col] += delta;
    }
    for (int nxt : v[x]) {
        if (nxt == fa || vis[nxt]) continue;
        dfs(nxt, x, opt);
    }
    --cnt[col];
}
// calc(x, fa): for every vertex of the subtree below x, add the running total Sum
// to its answer. On the first occurrence of a color along the path, Sum grows by
// Lim - num[color], which is the number of component vertices whose own path does
// not already contain that color. The increase is undone when the walk leaves
// that first occurrence.
void calc(int x, int fa) {
    const int col = c[x];
    if (++cnt[col] == 1) Sum += Lim - num[col];
    ans[x] += Sum;
    for (int nxt : v[x])
        if (nxt != fa && !vis[nxt]) calc(nxt, x);
    if (--cnt[col] == 0) Sum -= Lim - num[col];
}
// Divide(x): process the component whose centroid is x, then recurse into the
// pieces left after removing x.
//  1. dfs(x, 0, 1) fills num[] / Sum with first-occurrence contributions seen
//     from x; at that point Sum is exactly x's own answer within the component.
//  2. For every child subtree `to`: temporarily remove that subtree's
//     contributions (including its share of x's color), run calc() to add the
//     cross contributions to every vertex inside it, then restore everything.
//  3. dfs(x, 0, -1) undoes step 1, and each remaining piece gets a fresh
//     centroid from FindRoot and is divided recursively.
void Divide(int x) {
    if(vis[x]) return ; vis[x] = 1;  // mark x as a used centroid
    Sum = 0; FindRoot(x, 0);  // recompute siz[] with x as the root of this component
    dfs(x, 0, 1); ans[x] += Sum;
    for(auto &to : v[x]) {
        if(vis[to]) continue;
        // Exclude subtree `to`: x's color no longer reaches its siz[to] vertices,
        // and the component seen by calc() shrinks to Lim - siz[to].
        num[c[x]] -= siz[to]; Sum -= siz[to]; Lim -= siz[to];
        // Remove the subtree's own first-occurrence contributions. cnt[c[x]] is
        // pre-set to 1 so that repeats of x's color below x do not count as first.
        cnt[c[x]] = 1; dfs(to, x, -1); cnt[c[x]] = 0;
        calc(to, x);
        // Undo the removal above.
        cnt[c[x]] = 1; dfs(to, x, 1); cnt[c[x]] = 0;
        num[c[x]] += siz[to]; Sum += siz[to]; Lim += siz[to];
    }
    dfs(x, 0, -1);  // undo step 1: return num[] and Sum to their state before it
    for(auto &to : v[x]) 
        if(!vis[to]) {
            // siz[to] (from the FindRoot above) is the size of the piece containing `to`.
            root = 0, Lim = siz[to], FindRoot(to, x);
            Divide(root);
    }
}
// Entry point: read N, the N vertex colors and N - 1 undirected edges, run the
// centroid decomposition from the centroid of the whole tree, then print the
// answer of every vertex on its own line.
signed main() {
    //freopen("a.in", "r", stdin);freopen("b.out", "w", stdout);
    N = read();
    mx[0] = 1e9;  // sentinel: the initial root = 0 loses to any real vertex
    for (int i = 1; i <= N; ++i) c[i] = read();
    for (int e = 1; e < N; ++e) {
        const int a = read();
        const int b = read();
        v[a].push_back(b);
        v[b].push_back(a);
    }
    Lim = N;
    root = 0;
    FindRoot(1, 0);
    Divide(root);
    for (int i = 1; i <= N; ++i) std::cout << ans[i] << '\n';
    return 0;
}

Tags: C++

Posted on Sat, 30 Nov 2019 23:54:42 -0800 by nonlinear