题干
原题链接
给定一棵 n n n 个节点的树,节点编号为 1 ∼ n 1∼n 1∼n。每个节点都被染成了黑色(用 1 1 1 表示)或白色(用 0 0 0 表示)。从黑色节点无法到达白色节点,反之亦然。因此,两个同色节点相互可达的前提是,两个同色节点之间的路径中不含另一种颜色的节点。
我们希望将树中的所有节点都染成同一种颜色(全黑或全白均可)。为此,你可以采用我们指定的染色操作。每次操作可以选择一个节点 v v v,并改变节点 v v v 以及其所有可达同色节点的颜色(黑变白、白变黑)。
例如,在下图中,点 1 1 1 和点 2 , 3 , 8 , 9 2,3,8,9 2,3,8,9 之间相互可达,但是点 1 1 1 和点 6 6 6 之间相互不可达(被点 5 5 5 挡住了),因此,如果选择点 1 1 1 进行染色操作,会将点 1 , 2 , 3 , 8 , 9 1,2,3,8,9 1,2,3,8,9 全部染黑。
请你计算,为了达成目标,至少需要进行多少次染色操作。
输入
第一行包含整数 n n n。
第二行包含 n n n 个整数 c 1 , c 2 , ⋯ , c n c_1,c_2,\cdots,c_n c1,c2,⋯,cn,其中 c i c_i ci 为节点 i i i 的颜色( 1 1 1 表示黑, 0 0 0 表示白)。
接下来 n − 1 n−1 n−1 行,每行包含两个整数 u i , v i u_i,v_i ui,vi,表示节点 u i u_i ui 和节点 v i v_i vi 之间存在一条边。
输出
一个整数,表示所需的最少染色操作次数。
思路
题目的考点是并查集+树的直径。
并查集
既然相同颜色的,互相联通的点可以在一次操作内进行染色操作(即全部由黑变白,或者由白变黑),那么不如将他们视作一个点。通过并查集可以实现这一点。
找直径
缩点后形成的树 T T T 中,相邻结点具有不同颜色。考虑树 T T T 的直径 D = max u , v ∈ T d ( u , v ) D=\max_{u,v\in T}d(u,v) D=maxu,v∈Td(u,v),则必有一条长为 D D D 的路 p p p。
- 当 D D D 为奇数时,这条路由 D − 1 2 \frac{D-1}{2} 2D−1 个黑点(或者白点)和 D + 1 2 \frac{D+1}{2} 2D+1 个白点(或者黑点)组成,最少通过 D − 1 2 = ⌊ D / 2 ⌋ \frac{D-1}{2}=\lfloor D/2\rfloor 2D−1=⌊D/2⌋ 次操作将黑点全部换成白点(或者白点全部换成黑点)即可将这条路变成同一种颜色。
- 当 D D D 为偶数时,这条路由 D / 2 D/2 D/2 个黑点和 D / 2 D/2 D/2 个白点(或者黑点)组成,最少通过 D / 2 = ⌊ D / 2 ⌋ D/2=\lfloor D/2\rfloor D/2=⌊D/2⌋ 次操作将黑点全部换成白点(或者白点全部换成黑点)即可将这条路变成同一种颜色。
因此,如果要将树 T T T 变为一种颜色,至少要将这条路 p p p 变成一种颜色,次数 a n s ≥ ⌊ D / 2 ⌋ \mathrm{ans}\geq \lfloor D/2\rfloor ans≥⌊D/2⌋。此外,我们还能知道,从这条路的中心出发,通过 ⌊ D / 2 ⌋ \lfloor D/2\rfloor ⌊D/2⌋ 次操作,还可以将 T − p T-p T−p (即路外其他结点)也转变为一种颜色。因为如果做不到,说明我们在找直径的时候就找错了。
所以答案就是 a n s = ⌊ D / 2 ⌋ \mathrm{ans}=\lfloor D/2\rfloor ans=⌊D/2⌋。对于一个数 T T T 而言,它的直径为 D = f ( T ) D=f(T) D=f(T),其中
f ( T ) = max ( 1 + d T 1 + d T 2 , f ( T 1 ) , f ( T 2 ) ) f(T)=\max(1+d_{T_1}+d_{T_2},f(T_1),f(T_2)) f(T)=max(1+dT1+dT2,f(T1),f(T2))
d T d_T dT 表示树 T T T 的深度,而 T 1 , T 2 T_1,T_2 T1,T2 表示 T T T 最深的两个子树。
Code
# include <iostream>
# include <cstring>
# include <vector> using namespace std;int n,dad[200005],dp[200005],uu[200005],vv[200005];
bool c[200005];
vector<int> nex[200005];int getdad(int node){if(dad[node] == node)return node;return dad[node] = getdad(dad[node]);
}int getdp(int node,int fa){int mdp = 0;for(int &x : nex[node])if(x != fa)mdp = max(mdp,getdp(x,node));return dp[node] = mdp + 1;
}int maxdist(int node,int fa){int dp1 = -1,dp2 = -1,ans = 1;for(int &x : nex[node])if(x != fa){ans = max(ans,maxdist(x,node));if(dp[x] > dp1){if(dp1 < dp2) dp1 = dp[x];else dp2 = dp[x];}else dp2 = max(dp2,dp[x]);}if(dp1 == -1 && dp2 == -1) return ans;if(dp1 == -1) return max(ans,1 + dp2);else if(dp2 == -1) return max(ans,1 + dp1);return max(ans,1 + dp1 + dp2);
}int main(){int u,v;cin >> n;for(int i = 1;i <= n;i++)cin >> c[i];for(int i = 1;i < n;i++){cin >> u >> v;if(dad[u] && dad[v]){if(c[u] == c[v]) dad[getdad(v)] = getdad(u);}else if(dad[u]){dad[v] = c[u] == c[v]?getdad(u):v;}else if(dad[v]){dad[u] = c[u] == c[v]?getdad(v):u;}else{dad[u] = u;dad[v] = c[u] == c[v]?u:v;}uu[i] = u;vv[i] = v;}for(int i = 1;i < n;i++){u = getdad(uu[i]);v = getdad(vv[i]);if(c[u] != c[v]){nex[u].push_back(v);nex[v].push_back(u);}}getdp(getdad(1),0);return cout << maxdist(getdad(1),0) / 2,0;
}