《算法竞赛·快冲300题》每日一题:“01树”
《算法竞赛·快冲300题》将于2024年出版,是《算法竞赛》的辅助练习册。
所有题目放在自建的OJ New Online Judge。
用C/C++、Java、Python三种语言给出代码,以中低档题为主,适合入门、进阶。
文章目录
- 题目描述
- 题解
- C++代码
- Java代码
- Python代码
“ 01树” ,链接: http://oj.ecustacm.cn/problem.php?id=1715
题目描述
【题目描述】 现在给你一个n个节点的树,而且每个节点有一个权值为0或者1。
现在有m次询问,每次询问输入两个节点x和y,以及一个权值k。
请你判断x和y的路径中是否存在权值为k的点。(包括x和y本身)
【输入格式】 输入第一行为两个正整数n和m,均为不超过10^5次方的正整数。
第二行是一个长度为n的01字符串,表示从节点1到节点n的权值。
接下来n-1行,每行两个数字u和v,表示节点u和v之间存在边。
接下来m行,每行输入三个数字x,y,k。其中x,y不相同,k为0或者1。 。
【输出格式】 对于每一次询问,如果x和y的路径中包含权值为k的点,输出Yes,否则输出No 。
【输入样例】
5 5
11010
1 2
2 3
2 4
1 5
1 4 1
1 4 0
1 3 0
1 3 1
5 5 1
【输出样例】
Yes
No
Yes
Yes
No
题解
本题简单的做法是先建树,然后每次查询用DFS搜索路径。任意两点之间有且只有一条路径,做一次DFS能找到这条路径,计算量O(n)。一共做m次查询,总复杂度O(mn),超时。
不过,本题特殊在于每个点的权值是0或1,查询也是查有没有等于0或1的点。查询一条路径时,如果能确定所有点都是1,或所有点都是0,或有0有1,那么就得到了答案。
把所有点按0和1分成多个子集,其中一些连通的1是一个子集,一些连通的0是一个子集。最后把整棵树分成很多权值为1的子集、权值为0的子集。权值为0的子集和权值为1的子集相邻。
对一个查询“x,y,k”:
(1)如果{x,y}属于一个子集,它们必然连通,且权值相同,权值为0或1。
(2)如果{x,y}不属于一个子集,它们要么是相邻的两个不同权值的子集,要么它们之间的路径穿过了一个不同权值的子集,两种情况下的路径上有1也有0。
以上讨论的实际上是并查集的操作。下面用带路径压缩的并查集编码,一次查询约为O(1),m次查询的总复杂度约为O(m)。。
【笔记】 。
C++代码
#include<bits/stdc++.h>
using namespace std;
char str[100010];
int s[100010]; //并查集
int find_set(int x){ //查询并查集,返回x的根if(x != s[x]) s[x] = find_set(s[x]); //路径压缩return s[x];
}
void merge_set(int x, int y){ //合并x = find_set(x); y = find_set(y);if(x != y) s[x] = s[y]; //把x合并到y上,y的根成为x的根
}
int main(){int n, m;scanf("%d %d",&n,&m);scanf("%s",str+1);for(int i = 1; i <= n; i++) s[i] = i; //并查集初始化for(int i = 1; i < n; i++){int u, v; scanf("%d %d",&u,&v);if(str[u] == str[v]) merge_set(u,v); //合并}for(int i = 1; i <= m; i++){int x, y; char k; scanf("%d %d %c",&x,&y,&k);if(find_set(x) == find_set(y) && str[x] != k) //属于同一个子集,且权值不等于kputs("No"); //比cout快else //其他情况,既有0也有1puts("Yes"); //比cout快}return 0;
}
Java代码
import java.util.Scanner;
public class Main {static char[] str = new char[100010];static int[] s = new int[100010];static int findSet(int x) {if (x != s[x]) s[x] = findSet(s[x]);return s[x];}static void mergeSet(int x, int y) {x = findSet(x);y = findSet(y);if (x != y) s[x] = s[y];}public static void main(String[] args) {Scanner sc = new Scanner(System.in);int n = sc.nextInt();int m = sc.nextInt();String strInput = sc.next();strInput.getChars(0, strInput.length(), str, 1);for (int i = 1; i <= n; i++) s[i] = i;for (int i = 1; i < n; i++) {int u = sc.nextInt();int v = sc.nextInt();if (str[u] == str[v]) mergeSet(u, v);}for (int i = 1; i <= m; i++) {int x = sc.nextInt();int y = sc.nextInt();char k = sc.next().charAt(0);if (findSet(x) == findSet(y) && str[x] != k) System.out.println("No");else System.out.println("Yes");}}
}
Python代码
import sys
sys.setrecursionlimit(1000000) #注意要扩栈
str = [0] * 100010
s = [0] * 100010
def find_set(x):if x != s[x]: s[x] = find_set(s[x])return s[x]
def merge_set(x, y):x = find_set(x)y = find_set(y)if x != y: s[x] = s[y]
n, m = map(int, input().split())
str[1:n+1] = input()
for i in range(1, n+1): s[i] = i
for i in range(n-1):u, v = map(int, input().split())if str[u] == str[v]: merge_set(u, v)
for i in range(m):x, y, k = input().split()x = int(x)y = int(y)if find_set(x) == find_set(y) and str[x] != k: print("No")else: print("Yes")