「CTSC2018」暴力写挂

题目链接:https://loj.ac/problem/2553

题目背景真有意思

这题之前我从来没有写过边分治。。。。

把式子变一下:

可以发现第一棵树可以搞成无根树了,那么我们对第一棵树边分治。

在分治树上每个节点都是一条边,对每条边记两个东西表示左边(或右边)的${\rm depth}(x)+{\rm distance}(x,a)$的最大值,$a$是这条边任意一个点,那么两边最大值一加就是经过这条边的式子前半部分最大值的两倍。

我们考虑枚举式子最后的$\rm LCA$,那么我们可以这样做:我们想办法得到以每个点为子树的边分树,换句话说是在边分树结构上激活子树里的这些点,其他的全搞成负无穷。

注意到边分树是一棵二叉树,换句话说这玩意和线段树的结构相同,那么我们可以直接把线段树合并那一套抄过来。

所以做法是这样的:

  • 我们遍历第二棵树,假设当前遍历到了$x$,那么我们统计以$x$为$\rm LCA$的答案。
  • 我们先在$x$为根的边分树中激活$x$。
  • 首先递归求出所有儿子的边分树,然后依次合并,合并的过程中统计出一个点在$x$点边分树,一个点在儿子节点的边分树上的答案,这样两个点$\rm LCA$一定是$x$,并且不会漏掉一些情况。
  • 最后更新一下两个点都是$x$的答案,这个情况上面算不到。

根据线段树合并可以知道这玩意的复杂度等于建出来的点数,也就是$O(n\log n)$。

其实代码还挺好写的。

注意边分治基础操作,分治前要把树搞成二叉的,不然会被菊花图卡成$O(n^2)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include<bits/stdc++.h>
using namespace std;

void read(int &x) {
x=0;int f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}

void print(int x) {
if(x<0) putchar('-'),x=-x;
if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}

#define lf double
#define ll long long

#define pii pair<int,int >
#define vec vector<int >

#define pb push_back
#define mp make_pair
#define fr first
#define sc second

#define FOR(i,l,r) for(int i=l,i##_r=r;i<=i##_r;i++)

const int maxn = 3.67e5+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int mod = 1e9+7;

int n,rt[maxn];
ll ans=-1e18;

namespace Tree {
int cnt,head[maxn<<1],tot=1,ban[maxn<<2],sz[maxn<<1],size,ed,xx,yy,ccc,now[maxn],ww,dir[maxn];
int ls[maxn<<5],rs[maxn<<5];
ll lmx[maxn<<5],rmx[maxn<<5],dep[maxn<<1];
struct edge{int to,nxt,w;}e[maxn<<2];
vector<pii > r[maxn];

void add(int u,int v,int w) {e[++tot]=(edge){v,head[u],w},head[u]=tot;}
void ins(int u,int v,int w) {
// printf("ins :: %d %d %d\n",u,v,w);fflush(stdout);
add(u,v,w),add(v,u,w);
}

void inss(int u,int v,int w) {r[u].pb(mp(v,w)),r[v].pb(mp(u,w));}

void build(int x,int fa) {
int pre=0;
// printf("build :: %d %d\n",x,fa);cerr<<"OK"<<endl;
for(int i=0;i<(int)r[x].size();i++) {
int v=r[x][i].fr,w=r[x][i].sc;
if(v==fa) continue;
if(!pre) ins(x,v,w),pre=x;
else ins(pre,++cnt,0),ins(cnt,v,w),dep[cnt]=dep[x],pre=cnt;
dep[v]=dep[x]+w;build(v,x);
}
}

void get_rt(int x,int fa) {
sz[x]=1;
for(int v,i=head[x];i;i=e[i].nxt) {
if(ban[i]||(v=e[i].to)==fa) continue;
get_rt(v,x),sz[x]+=sz[v];
int p=max(sz[v],size-sz[v]);
if(p<ww) {ww=p,ed=i,xx=x,yy=v;}
}
}

void dfs(int x,int fa,ll d,int c) {
if(x<=n) { //注意这里和上面说的有点不一样
++ccc; //因为激活第一个点的时候一定会建出一条链,所以直接在这里遍历这个块的所有点的时候在下面挂一个点也是一样。
if(!rt[x]) rt[x]=ccc;
else (dir[x]?rs:ls)[now[x]]=ccc;
now[x]=ccc,dir[x]=c;
(dir[x]?rmx:lmx)[ccc]=d+dep[x];
(dir[x]?lmx:rmx)[ccc]=-1e18;
}
for(int i=head[x];i;i=e[i].nxt)
if(!ban[i]&&e[i].to!=fa) dfs(e[i].to,x,d+e[i].w,c);
}

void solve(int x,int ss) {
// printf("solve :: %d %d\n",x,ss);
if(ss==1) return ;
size=ss,ww=1e9;get_rt(x,0);ban[ed]=ban[ed^1]=1;
// cerr<<xx<<' '<<yy<<endl;;
dfs(xx,0,0,0);dfs(yy,0,e[ed].w,1);int tmp=sz[yy],tt=yy;
solve(xx,size-sz[yy]),solve(tt,tmp);
}

int merge(int x,int y,ll d) {
if(!x||!y) return x+y;
ans=max(ans,(lmx[x]+rmx[y])/2-d);
ans=max(ans,(lmx[y]+rmx[x])/2-d);
lmx[x]=max(lmx[x],lmx[y]);
rmx[x]=max(rmx[x],rmx[y]);
ls[x]=merge(ls[x],ls[y],d);
rs[x]=merge(rs[x],rs[y],d);return x;
}

// void debug(int x) {
// printf("now :: %d %lld %lld\n",x,lmx[x],rmx[x]);
// if(ls[x]) printf("ls :: %d\n",ls[x]),debug(ls[x]);
// if(rs[x]) printf("rs :: %d\n",rs[x]),debug(rs[x]);
// }
}

int head[maxn],tot;
struct edge{int to,nxt,w;}e[maxn<<1];

void add(int u,int v,int w) {e[++tot]=(edge){v,head[u],w},head[u]=tot;}
void ins(int u,int v,int w) {add(u,v,w),add(v,u,w);}

void dfs(int x,int fa,ll d) {
ans=max(ans,Tree::dep[x]-d);
// printf("dfs :: %d %d %lld\n",x,fa,d);
for(int i=head[x];i;i=e[i].nxt) {
if(e[i].to==fa) continue;
dfs(e[i].to,x,d+e[i].w);
rt[x]=Tree::merge(rt[x],rt[e[i].to],d);
}
}

int main() {
read(n);Tree::cnt=n;
for(int i=1,x,y,z;i<n;i++) read(x),read(y),read(z),Tree::inss(x,y,z);
for(int i=1,x,y,z;i<n;i++) read(x),read(y),read(z),ins(x,y,z);
Tree::build(1,0);//cerr<<"OK"<<endl;
Tree::solve(1,Tree::cnt);
// for(int i=1;i<=n;i++) {
// printf("debug :: %d\n",i);
// Tree::debug(rt[i]);
// }
dfs(1,0,0);
printf("%lld\n",ans);
return 0;
}