[LOJ 2236] [Luogu P3258] [LCA] [Tree Difference]

Problem summary:

Problem link:
Luogu: https://www.luogu.org/problem/P3258
LOJ: https://loj.ac/problem/2236
Given a tree and a sequence of n points to visit in order, find how many times each point is passed through. The final arrival at the n-th point of the sequence is not counted.

Approach:

This problem comes from the "Provincial Selection Arena: Heavy-Light Decomposition" problem list.
I opened it planning to practice the heavy-light decomposition template. Then I read the statement.
Isn't this just a trivial tree difference problem?????
Still, since it is filed under heavy-light decomposition, let's use heavy-light decomposition to find the LCA.
A typical heavy-light decomposition solution contains a routine like this:

void addrange(int x,int y,int k)
{
	while (top[x]!=top[y])
	{
		if (dep[top[x]]<dep[top[y]]) swap(x,y);
		Tree.update(1,id[top[x]],id[x],k);
		x=fa[top[x]];
	}
	if (id[x]>id[y]) Tree.update(1,id[y],id[x],k);
		else Tree.update(1,id[x],id[y],k);
}

When the while loop finally exits, x and y must lie on the same heavy chain.
So the LCA at that moment is simply the shallower of the two points.
This way heavy-light decomposition finds an LCA in O(log n), and the constant is very small.
Then apply tree difference: for each consecutive pair a[i], a[i+1], add 1 at both endpoints and subtract 1 at their LCA and at the LCA's parent. One final DFS that sums each subtree gives the number of paths covering every point.
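To make the tree difference step concrete, here is a minimal sketch. It assumes the tree is already given as parent and depth arrays (1-indexed, with the root's parent set to 0), and it uses a naive climbing LCA instead of heavy-light decomposition to stay short. The names naiveLca and coverPaths are made up for this illustration and are not part of the solution below.

```cpp
#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical helper: LCA by repeatedly lifting the deeper node. O(depth) per query.
int naiveLca(const std::vector<int>& parent, const std::vector<int>& depth, int u, int v) {
    while (u != v) {
        if (depth[u] < depth[v]) std::swap(u, v);
        u = parent[u];  // u is at least as deep as v, so lifting it is safe
    }
    return u;
}

// Hypothetical helper: returns cnt where cnt[x] is the number of given paths
// that pass through node x (nodes are 1..n, index 0 is a dummy slot).
std::vector<int> coverPaths(const std::vector<int>& parent,
                            const std::vector<int>& depth,
                            const std::vector<std::pair<int, int> >& paths) {
    int n = static_cast<int>(parent.size()) - 1;
    std::vector<int> diff(n + 1, 0);
    for (const auto& p : paths) {
        int l = naiveLca(parent, depth, p.first, p.second);
        diff[p.first]++;
        diff[p.second]++;
        diff[l]--;          // the LCA would otherwise be counted twice
        diff[parent[l]]--;  // stop the +1 from leaking above the LCA
    }
    // Process nodes from deepest to shallowest so that every child is
    // folded into its parent before the parent is folded further up.
    std::vector<int> order(n);
    std::iota(order.begin(), order.end(), 1);
    std::sort(order.begin(), order.end(),
              [&](int a, int b) { return depth[a] > depth[b]; });
    for (int x : order) diff[parent[x]] += diff[x];
    return diff;
}
```

Sorting by depth here stands in for the recursive subtree sum (dfs3 in the full code); both fold children into parents bottom-up.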
Note one catch: walking from x to y and then from y to z counts y twice, so its total comes out one too high. The final answer for such a point is therefore reduced by 1.
The first and the n-th points of the sequence are each the endpoint of only one path, so they are counted once and would not normally need the minus 1. But the problem says the final arrival at the n-th point is not counted, so it still needs the minus 1; only the first point does not.
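The correction can be checked by brute force on a small tree. The sketch below is only an illustration under assumptions: it takes parent and depth arrays (1-indexed, root's parent is 0), walks every path node by node instead of using tree difference, and then applies the minus 1 to every point of the sequence except the first. The name visitCounts is hypothetical.

```cpp
#include <utility>
#include <vector>

// Hypothetical brute-force checker: a is the visiting order (0-indexed vector
// of node labels). Returns ans where ans[x] is the answer for node x.
std::vector<int> visitCounts(const std::vector<int>& parent,
                             const std::vector<int>& depth,
                             const std::vector<int>& a) {
    int n = static_cast<int>(parent.size()) - 1;
    int m = static_cast<int>(a.size());
    std::vector<int> cnt(n + 1, 0);
    for (int i = 0; i + 1 < m; ++i) {
        int u = a[i], v = a[i + 1];
        // Climb from the deeper end, counting each node as we leave it.
        while (u != v) {
            if (depth[u] < depth[v]) std::swap(u, v);
            cnt[u]++;
            u = parent[u];
        }
        cnt[u]++;  // the meeting point is the LCA, counted once
    }
    // Every point after the first is either the shared endpoint of two
    // consecutive paths (counted twice) or the final stop (not counted).
    for (int i = 1; i < m; ++i) cnt[a[i]]--;
    return cnt;
}
```

On the tree with edges 1-2, 2-3, 2-4, 4-5 rooted at 1 and the visiting order 1, 4, 5, 3, 2, the raw path counts are 1, 3, 2, 3, 2 for nodes 1 to 5, and after the correction they become 1, 2, 1, 2, 1.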
Time complexity: O(n log n)

Code:

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int N=300010;
int n,tot,a[N],s[N],head[N],dep[N],son[N],fa[N],top[N],siz[N];  //siz, not size: avoids a clash with std::size under C++17

struct edge
{
	int next,to;
}e[N*2];

void add(int from,int to)
{
	e[++tot].to=to;
	e[tot].next=head[from];
	head[from]=tot;
}

//First pass: parent, depth, subtree size and heavy child of every node
void dfs1(int x,int f)
{
	fa[x]=f; dep[x]=dep[f]+1; siz[x]=1;
	for (int i=head[x];~i;i=e[i].next)
	{
		int y=e[i].to;
		if (y!=f)
		{
			dfs1(y,x);
			if (siz[y]>siz[son[x]]) son[x]=y;
			siz[x]+=siz[y];
		}
	}
}

void dfs2(int x,int tp)
{
	top[x]=tp;
	if (son[x]) dfs2(son[x],tp);
	for (int i=head[x];~i;i=e[i].next)
	{
		int y=e[i].to;
		if (y!=fa[x] && y!=son[x]) dfs2(y,y);
	}
}

void dfs3(int x)
{
	for (int i=head[x];~i;i=e[i].next)
	{
		int y=e[i].to;
		if (y!=fa[x])
		{
			dfs3(y);
			s[x]+=s[y];
		}
	}
}

int lca(int x,int y)
{
	while (top[x]!=top[y])
	{
		if (dep[top[x]]<dep[top[y]]) swap(x,y);
		x=fa[top[x]];
	}
	return dep[x]>dep[y]?y:x;
}

int main()
{ 
	memset(head,-1,sizeof(head));
	scanf("%d",&n);
	for (int i=1;i<=n;i++)
		scanf("%d",&a[i]);
	for (int i=1,x,y;i<n;i++)
	{
		scanf("%d%d",&x,&y);
		add(x,y); add(y,x);
	}
	dfs1(1,0); dfs2(1,1);
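	for (int i=1;i<n;i++)
	{
		//Add 1 to every point on the path a[i] -> a[i+1] by tree difference
		int LCA=lca(a[i],a[i+1]);
		s[a[i]]++; s[a[i+1]]++;
		s[LCA]--; s[fa[LCA]]--;
	}
	dfs3(1); s[a[1]]++;  //Subtree sums; the +1 cancels the -1 below, which the first point does not need
	for (int i=1;i<=n;i++)
		printf("%d\n",s[i]-1);  //Every other point is counted once too many, or is the uncounted final stop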
	return 0;
}

Posted on Tue, 08 Oct 2019 05:34:24 -0700 by tyr_82