1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/// Disjoint-set (union–find) structure over the elements `0..n`.
///
/// Combines union by size with full path compression, so a sequence of
/// operations runs in amortized `O(α(n))` time per operation.
#[derive(Debug, Clone)]
pub struct UnionFind {
    /// `parents[x]` is the parent of `x` in its tree; a root is its own parent.
    parents: Vec<usize>,
    /// `sizes[r]` is the number of elements in the set rooted at `r`.
    /// Only meaningful when `r` is a root.
    sizes: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets `{0}, {1}, …, {n - 1}`.
    pub fn new(n: usize) -> Self {
        Self {
            parents: (0..n).collect(),
            sizes: vec![1; n],
        }
    }

    /// Returns the representative (root) of the set containing `x`.
    ///
    /// Despite the name, this is the *root* of `x`'s tree, not its immediate
    /// parent. Every node on the path from `x` to the root is re-pointed
    /// directly at the root (path compression), which is why `&mut self` is
    /// required.
    ///
    /// # Panics
    ///
    /// Panics if `x` is not in `0..n`.
    pub fn parent(&mut self, x: usize) -> usize {
        // First pass: walk up to the root.
        let mut root = x;
        while self.parents[root] != root {
            root = self.parents[root];
        }

        // Second pass: point every node on the path directly at the root.
        // Iterative, so no recursion depth is involved regardless of tree shape.
        let mut node = x;
        while node != root {
            let next = self.parents[node];
            self.parents[node] = root;
            node = next;
        }

        root
    }

    /// Merges the sets containing `x` and `y`; a no-op if they already share a set.
    ///
    /// The root of the smaller set is attached under the root of the larger
    /// one (union by size). On a tie, `x`'s root remains the root.
    ///
    /// # Panics
    ///
    /// Panics if `x` or `y` is not in `0..n`.
    pub fn unite(&mut self, x: usize, y: usize) {
        let mut px = self.parent(x);
        let mut py = self.parent(y);

        if px == py {
            return;
        }

        // Ensure `px` is the root of the larger (or equal) set.
        if self.sizes[px] < self.sizes[py] {
            std::mem::swap(&mut px, &mut py);
        }

        self.sizes[px] += self.sizes[py];
        self.parents[py] = px;
    }

    /// Returns the number of elements in the set containing `x`.
    ///
    /// # Panics
    ///
    /// Panics if `x` is not in `0..n`.
    pub fn size(&mut self, x: usize) -> usize {
        let root = self.parent(x);
        self.sizes[root]
    }

    /// Returns `true` if `x` and `y` belong to the same set.
    ///
    /// # Panics
    ///
    /// Panics if `x` or `y` is not in `0..n`.
    pub fn same(&mut self, x: usize, y: usize) -> bool {
        self.parent(x) == self.parent(y)
    }
}

#[cfg(test)]
mod test_union_find {
    // `super` instead of an absolute `crate::…` path keeps these tests
    // compiling wherever this file is mounted in the module tree.
    use super::UnionFind;

    /// Collects the size of the set containing each element in `0..n`.
    fn sizes(uf: &mut UnionFind, n: usize) -> Vec<usize> {
        (0..n).map(|i| uf.size(i)).collect()
    }

    #[test]
    fn it_works() {
        let n: usize = 5;
        let mut uf = UnionFind::new(n);
        assert_eq!(sizes(&mut uf, n), [1, 1, 1, 1, 1]);

        uf.unite(0, 1);
        assert_eq!(uf.parent(0), uf.parent(1));
        assert!(uf.same(0, 1));
        assert_ne!(uf.parent(0), uf.parent(2));
        assert!(!uf.same(0, 2));
        assert_eq!(sizes(&mut uf, n), [2, 2, 1, 1, 1]);

        // check noop
        uf.unite(0, 1);
        assert_eq!(uf.parent(0), uf.parent(1));
        assert!(uf.same(0, 1));
        assert_ne!(uf.parent(0), uf.parent(2));
        assert!(!uf.same(0, 2));
        assert_eq!(sizes(&mut uf, n), [2, 2, 1, 1, 1]);

        uf.unite(0, 2);
        assert_eq!(uf.parent(0), uf.parent(2));
        assert!(uf.same(0, 2));
        assert_eq!(sizes(&mut uf, n), [3, 3, 3, 1, 1]);

        uf.unite(3, 4);
        assert_ne!(uf.parent(0), uf.parent(3));
        assert!(!uf.same(0, 3));
        assert_eq!(sizes(&mut uf, n), [3, 3, 3, 2, 2]);

        uf.unite(0, 3);
        assert_eq!(uf.parent(0), uf.parent(3));
        assert!(uf.same(0, 3));
        assert_eq!(uf.parent(0), uf.parent(4));
        assert!(uf.same(0, 4));
        assert_eq!(sizes(&mut uf, n), [5, 5, 5, 5, 5]);
    }

    #[test]
    fn long_chain() {
        // Many sequential unions must leave everything in one set.
        let n: usize = 10_000;
        let mut uf = UnionFind::new(n);
        for i in 1..n {
            uf.unite(i - 1, i);
        }
        assert!((0..n).all(|i| uf.same(0, i)));
        assert_eq!(sizes(&mut uf, n), vec![n; n]);
    }

    #[test]
    fn empty() {
        let mut uf = UnionFind::new(0);
        assert!(sizes(&mut uf, 0).is_empty());
    }
}