Skip to content

Commit 85892e1

Browse files
authored
Merge pull request #177 from ia7ck/union-find
union find
2 parents ee28e11 + 4692423 commit 85892e1

File tree

1 file changed

+118
-36
lines changed

1 file changed

+118
-36
lines changed

algo/union_find/src/lib.rs

Lines changed: 118 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,152 @@
11
/// Union Find はグラフの連結成分を管理します。
2+
#[derive(Clone, Debug)]
23
pub struct UnionFind {
3-
par: Vec<usize>,
4-
size: Vec<usize>,
4+
nodes: Vec<NodeKind>,
5+
groups: usize,
6+
}
7+
8+
#[derive(Clone, Copy, Debug)]
9+
enum NodeKind {
10+
Root { size: usize },
11+
Child { parent: usize },
512
}
613

714
impl UnionFind {
8-
/// グラフの頂点数 `n` を渡します
9-
pub fn new(n: usize) -> UnionFind {
10-
UnionFind {
11-
par: (0..n).collect(),
12-
size: vec![1; n],
15+
/// 頂点数を `n` として初期化します
16+
pub fn new(n: usize) -> Self {
17+
Self {
18+
nodes: vec![NodeKind::Root { size: 1 }; n],
19+
groups: n,
1320
}
1421
}
22+
1523
/// 頂点 `i` の属する連結成分の代表元を返します。
1624
///
1725
/// # Examples
26+
///
1827
/// ```
1928
/// use union_find::UnionFind;
2029
/// let mut uf = UnionFind::new(6);
2130
/// uf.unite(0, 1);
2231
/// uf.unite(1, 2);
2332
/// uf.unite(3, 4);
24-
/// let mut leaders = (0..6).map(|i| uf.find(i)).collect::<Vec<_>>();
25-
/// assert_eq!(leaders[0], leaders[0]);
26-
/// assert_eq!(leaders[0], leaders[1]);
27-
/// assert_eq!(leaders[1], leaders[2]);
28-
/// assert_eq!(leaders[0], leaders[2]);
29-
/// assert_eq!(leaders[3], leaders[4]);
30-
/// assert_ne!(leaders[0], leaders[3]);
31-
/// assert_ne!(leaders[0], leaders[5]);
33+
///
34+
/// // [(0, 1, 2), (3, 4), (5)]
35+
/// assert_eq!(uf.find(0), uf.find(0));
36+
/// assert_eq!(uf.find(0), uf.find(1));
37+
/// assert_eq!(uf.find(1), uf.find(2));
38+
/// assert_eq!(uf.find(0), uf.find(2));
39+
/// assert_eq!(uf.find(3), uf.find(4));
40+
///
41+
/// assert_ne!(uf.find(0), uf.find(3));
42+
/// assert_ne!(uf.find(0), uf.find(5));
3243
/// ```
3344
pub fn find(&mut self, i: usize) -> usize {
34-
if self.par[i] != i {
35-
self.par[i] = self.find(self.par[i]);
45+
assert!(i < self.nodes.len());
46+
47+
match self.nodes[i] {
48+
NodeKind::Root { .. } => i,
49+
NodeKind::Child { parent } => {
50+
let root = self.find(parent);
51+
if root == parent {
52+
// noop
53+
} else {
54+
// 経路圧縮
55+
self.nodes[i] = NodeKind::Child { parent: root };
56+
}
57+
root
58+
}
3659
}
37-
self.par[i]
3860
}
61+
3962
/// 頂点 `i` の属する連結成分と頂点 `j` の属する連結成分をつなげます。
40-
pub fn unite(&mut self, i: usize, j: usize) {
63+
///
64+
/// 呼び出し前に別の連結成分だった場合 true を、同じ連結成分だった場合 false を返します。
65+
///
66+
/// # Examples
67+
///
68+
/// ```
69+
/// use union_find::UnionFind;
70+
/// let mut uf = UnionFind::new(6);
71+
/// assert!(uf.unite(0, 1));
72+
/// assert!(uf.unite(1, 2));
73+
/// assert!(uf.unite(3, 4));
74+
///
75+
/// // [(0, 1, 2), (3, 4), (5)]
76+
/// assert!(!uf.unite(0, 2));
77+
/// assert!(!uf.unite(3, 3));
78+
///
79+
/// assert!(uf.unite(4, 5));
80+
/// ```
81+
pub fn unite(&mut self, i: usize, j: usize) -> bool {
4182
let i = self.find(i);
4283
let j = self.find(j);
4384
if i == j {
44-
return;
85+
return false;
4586
}
46-
let (i, j) = if self.size[i] >= self.size[j] {
47-
(i, j)
48-
} else {
49-
(j, i)
50-
};
51-
self.par[j] = i;
52-
self.size[i] += self.size[j];
87+
88+
match (self.nodes[i], self.nodes[j]) {
89+
(NodeKind::Root { size: i_size }, NodeKind::Root { size: j_size }) => {
90+
let total = i_size + j_size;
91+
// マージテク
92+
if i_size >= j_size {
93+
self.nodes[j] = NodeKind::Child { parent: i };
94+
self.nodes[i] = NodeKind::Root { size: total };
95+
} else {
96+
self.nodes[i] = NodeKind::Child { parent: j };
97+
self.nodes[j] = NodeKind::Root { size: total };
98+
}
99+
}
100+
_ => unreachable!(),
101+
}
102+
103+
self.groups -= 1;
104+
true
53105
}
106+
54107
/// 頂点 `i` の属する連結成分のサイズ (頂点数) を返します。
55108
///
56109
/// # Examples
110+
///
57111
/// ```
58112
/// use union_find::UnionFind;
59113
/// let mut uf = UnionFind::new(6);
60114
/// uf.unite(0, 1);
61115
/// uf.unite(1, 2);
62116
/// uf.unite(3, 4);
63-
/// assert_eq!(uf.get_size(0), 3);
64-
/// assert_eq!(uf.get_size(1), 3);
65-
/// assert_eq!(uf.get_size(2), 3);
66-
/// assert_eq!(uf.get_size(3), 2);
67-
/// assert_eq!(uf.get_size(4), 2);
68-
/// assert_eq!(uf.get_size(5), 1);
117+
///
118+
/// // [(0, 1, 2), (3, 4), (5)]
119+
/// assert_eq!(uf.size(0), 3);
120+
/// assert_eq!(uf.size(1), 3);
121+
/// assert_eq!(uf.size(2), 3);
122+
/// assert_eq!(uf.size(3), 2);
123+
/// assert_eq!(uf.size(4), 2);
124+
/// assert_eq!(uf.size(5), 1);
69125
/// ```
70-
pub fn get_size(&mut self, i: usize) -> usize {
71-
let p = self.find(i);
72-
self.size[p]
126+
pub fn size(&mut self, i: usize) -> usize {
127+
let root = self.find(i);
128+
match self.nodes[root] {
129+
NodeKind::Root { size } => size,
130+
_ => unreachable!(),
131+
}
73132
}
133+
74134
/// 頂点 `i` と頂点 `j` が同じ連結成分に属するかどうかを返します。
75135
///
76136
/// # Examples
137+
///
77138
/// ```
78139
/// use union_find::UnionFind;
79140
/// let mut uf = UnionFind::new(6);
80141
/// assert!(uf.same(0, 0));
81142
/// assert!(uf.same(3, 3));
82143
/// assert!(uf.same(5, 5));
144+
///
83145
/// uf.unite(0, 1);
84146
/// uf.unite(1, 2);
85147
/// uf.unite(3, 4);
148+
///
149+
/// // [(0, 1, 2), (3, 4), (5)]
86150
/// assert!(uf.same(0, 1));
87151
/// assert!(uf.same(1, 2));
88152
/// assert!(uf.same(0, 2));
@@ -91,4 +155,22 @@ impl UnionFind {
91155
pub fn same(&mut self, i: usize, j: usize) -> bool {
92156
self.find(i) == self.find(j)
93157
}
158+
159+
/// 連結成分数を返します。
160+
///
161+
/// # Examples
162+
///
163+
/// ```
164+
/// use union_find::UnionFind;
165+
/// let mut uf = UnionFind::new(6);
166+
/// uf.unite(0, 1);
167+
/// uf.unite(1, 2);
168+
/// uf.unite(3, 4);
169+
///
170+
/// // [(0, 1, 2), (3, 4), (5)]
171+
/// assert_eq!(uf.count_groups(), 3);
172+
/// ```
173+
pub fn count_groups(&self) -> usize {
174+
self.groups
175+
}
94176
}

0 commit comments

Comments
 (0)