algorithm_rs/
collection.rs

1use std::{collections::VecDeque, ops::Range};
2
/// Iterator over every subset of a vector, produced by counting a bitmask
/// from `0` up to `2^len - 1`.
///
/// Bit `i` of the mask selects `array[i]`, so subsets are yielded in
/// increasing-mask order (`[]`, `[a0]`, `[a1]`, `[a0, a1]`, ...) and each
/// subset keeps the original element order. Build one with [`bitset`].
#[derive(Debug, Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct Bitset<T> {
    /// Mask of the next subset to emit; equals `1 << len` once exhausted.
    curr: usize,
    /// Source elements.
    array: Vec<T>,
    /// Cached `array.len()`; always `< usize::BITS` (enforced by [`bitset`]).
    len: usize,
}

impl<T: Clone> Iterator for Bitset<T> {
    type Item = Vec<T>;

    fn next(&mut self) -> Option<Vec<T>> {
        if self.curr == 1usize << self.len {
            return None;
        }

        let mask = self.curr;
        // The subset has exactly one element per set bit.
        let mut subset = Vec::with_capacity(mask.count_ones() as usize);
        subset.extend(
            self.array
                .iter()
                .enumerate()
                .filter(|&(i, _)| (mask >> i) & 1 == 1)
                .map(|(_, item)| item.clone()),
        );

        self.curr += 1;
        Some(subset)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // `len < usize::BITS`, so the shift cannot overflow, and `curr` never
        // exceeds `1 << len`, so the subtraction cannot underflow.
        let remaining = (1usize << self.len) - self.curr;
        (remaining, Some(remaining))
    }
}

impl<T: Clone> ExactSizeIterator for Bitset<T> {}

// Once `curr == 1 << len`, `next` keeps returning `None`.
impl<T: Clone> std::iter::FusedIterator for Bitset<T> {}

/// Returns an iterator over all `2^a.len()` subsets of `a`.
///
/// # Panics
///
/// Panics if `a.len() >= usize::BITS`: the subset mask would not fit in a
/// `usize`, and such an enumeration could never finish anyway.
pub fn bitset<T: Clone>(a: Vec<T>) -> Bitset<T> {
    let len = a.len();
    assert!(
        len < usize::BITS as usize,
        "bitset: {} elements do not fit in a {}-bit subset mask",
        len,
        usize::BITS
    );
    Bitset { curr: 0, array: a, len }
}
33
#[cfg(test)]
mod test_bitset {
    use crate::collection::bitset;

    /// Walks every subset of `[1, 2, 3]` in bitmask order, then checks that
    /// the iterator is exhausted after `2^3` items.
    #[test]
    fn it_works() {
        let expected: [&[i32]; 8] = [
            &[],
            &[1],
            &[2],
            &[1, 2],
            &[3],
            &[1, 3],
            &[2, 3],
            &[1, 2, 3],
        ];
        let mut subsets = bitset(vec![1, 2, 3]);
        for want in expected.iter() {
            assert_eq!(subsets.next(), Some(want.to_vec()));
        }
        assert!(subsets.next().is_none());
    }
}
52
/// One step of the explicit depth-first traversal driven by [`CollectionIter`].
///
/// `Pre(i)` places `data[i]` at the current depth and queues its children;
/// the matching `Post(i)` undoes that placement once the subtree is finished.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Item {
    Pre(usize),
    Post(usize),
}

/// Which kind of selection a [`CollectionIter`] enumerates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CollectionType {
    /// Ordered selections: any index may appear at any position.
    Permutation,
    /// Unordered selections: indices never decrease along a selection.
    Combination,
}

/// Iterator over permutations or combinations of the values in `data`.
///
/// Each item is a `Vec<usize>` of values taken from `data`. Selections are
/// yielded in lexicographic order of the chosen *indices*. The search uses an
/// explicit stack instead of recursion. `allow_duplication` controls whether
/// one index may be chosen more than once within a single selection.
#[derive(Debug)]
pub struct CollectionIter<'a> {
    /// Index range of `data` (`0..n` when built by the constructors).
    pub a: Range<usize>,
    /// Number of source values, `data.len()`.
    pub n: usize,
    /// Length of each yielded selection.
    pub k: usize,
    /// Source values; yielded selections contain these values.
    pub data: &'a Vec<usize>,
    /// Number of leading positions of `permutation` currently filled.
    pub depth: usize,
    /// Pending traversal steps; the front is processed next.
    pub stack: VecDeque<Item>,
    /// Selection under construction; only the first `depth` entries are valid.
    pub permutation: Vec<usize>,
    /// Marks indices on the current path; only read for permutations
    /// without duplication (see [`CollectionIter::should_skip`]).
    pub used: Vec<bool>,
    /// Whether to enumerate permutations or combinations.
    pub collection_type: CollectionType,
    /// Whether an index may repeat within one selection.
    pub allow_duplication: bool,
}

impl<'a> CollectionIter<'a> {
    /// Iterates over sequences of length `data.len()`.
    ///
    /// Without duplication these are the `n!` permutations of `data`. With
    /// duplication every index may be reused, giving `n^n` sequences. Empty
    /// `data` yields nothing.
    pub fn permutation(data: &'a Vec<usize>, allow_duplication: bool) -> Self {
        Self::with_type(data, data.len(), CollectionType::Permutation, allow_duplication)
    }

    /// Iterates over size-`k` selections whose indices never decrease.
    ///
    /// Without duplication these are the `C(n, k)` combinations (strictly
    /// increasing indices). With duplication they are the `C(n + k - 1, k)`
    /// multisets. `k == 0` yields nothing.
    pub fn combination(data: &'a Vec<usize>, k: usize, allow_duplication: bool) -> Self {
        Self::with_type(data, k, CollectionType::Combination, allow_duplication)
    }

    /// Shared constructor: seeds the traversal with one root per index,
    /// in ascending order (`Pre(0), Post(0), Pre(1), Post(1), ...`).
    fn with_type(
        data: &'a Vec<usize>,
        k: usize,
        collection_type: CollectionType,
        allow_duplication: bool,
    ) -> Self {
        let n = data.len();
        let mut stack = VecDeque::with_capacity(2 * n);
        // A zero-length selection can never be emitted. The first `Pre`
        // already raises `depth` to 1, and writing `permutation[0]` into a
        // zero-length buffer would panic, so start with no roots instead.
        if k > 0 {
            for i in 0..n {
                stack.push_back(Item::Pre(i));
                stack.push_back(Item::Post(i));
            }
        }

        CollectionIter {
            a: 0..n,
            n,
            k,
            data,
            depth: 0,
            stack,
            permutation: vec![0; k],
            used: vec![false; n],
            collection_type,
            allow_duplication,
        }
    }

    /// Returns `true` if index `ni` must not be chosen at the next depth.
    /// This is only the case for permutations without duplication whose
    /// current path already uses `ni`.
    pub fn should_skip(&self, ni: usize) -> bool {
        self.collection_type == CollectionType::Permutation
            && !self.allow_duplication
            && self.used[ni]
    }

    /// Smallest index that may be chosen after index `ni`.
    pub fn lower(&self, ni: usize) -> usize {
        match (self.collection_type, self.allow_duplication) {
            (CollectionType::Permutation, _) => 0,
            (CollectionType::Combination, true) => ni,
            (CollectionType::Combination, false) => ni + 1,
        }
    }
}

impl Iterator for CollectionIter<'_> {
    type Item = Vec<usize>;

    /// Advances the depth-first search to the next complete selection.
    fn next(&mut self) -> Option<Self::Item> {
        while let Some(step) = self.stack.pop_front() {
            match step {
                Item::Pre(i) => {
                    self.permutation[self.depth] = self.data[i];
                    self.used[i] = true;
                    self.depth += 1;

                    // Leaf: the selection is complete. Its `Post(i)` is still
                    // queued, so the next call resumes by backtracking.
                    if self.depth == self.k {
                        return Some(self.permutation.clone());
                    }

                    // Queue children in reverse so the smallest index is
                    // explored first.
                    for ni in (self.lower(i)..self.a.end).rev() {
                        if !self.should_skip(ni) {
                            self.stack.push_front(Item::Post(ni));
                            self.stack.push_front(Item::Pre(ni));
                        }
                    }
                }
                Item::Post(i) => {
                    self.depth -= 1;
                    self.used[i] = false;
                }
            }
        }

        None
    }
}
175
#[cfg(test)]
mod test_iterator {
    use crate::collection::CollectionIter;

    /// Drains `iterator` and checks that it yields exactly `expected`, in order.
    ///
    /// `num_expected` is the closed-form count for the case under test
    /// (e.g. `n!` or `C(n, k)`). It is checked separately, so a wrong table
    /// and a wrong formula are reported as distinct failures.
    fn check(iterator: CollectionIter<'_>, num_expected: usize, expected: Vec<Vec<usize>>) {
        // Collect first so an over-long iterator produces a readable diff
        // instead of an index-out-of-bounds panic on `expected`.
        let actual: Vec<Vec<usize>> = iterator.collect();
        assert_eq!(actual, expected);
        assert_eq!(actual.len(), num_expected);
    }

    mod with_duplication {
        use crate::collection::{test_iterator::check, CollectionIter};

        #[test]
        fn it_works_permutation() {
            let data = vec![1, 2, 4];
            let num_expected = 27; // n^3
            let iterator = CollectionIter::permutation(&data, true);
            let expected = vec![
                vec![1, 1, 1],
                vec![1, 1, 2],
                vec![1, 1, 4],
                vec![1, 2, 1],
                vec![1, 2, 2],
                vec![1, 2, 4],
                vec![1, 4, 1],
                vec![1, 4, 2],
                vec![1, 4, 4],
                vec![2, 1, 1],
                vec![2, 1, 2],
                vec![2, 1, 4],
                vec![2, 2, 1],
                vec![2, 2, 2],
                vec![2, 2, 4],
                vec![2, 4, 1],
                vec![2, 4, 2],
                vec![2, 4, 4],
                vec![4, 1, 1],
                vec![4, 1, 2],
                vec![4, 1, 4],
                vec![4, 2, 1],
                vec![4, 2, 2],
                vec![4, 2, 4],
                vec![4, 4, 1],
                vec![4, 4, 2],
                vec![4, 4, 4],
            ];
            check(iterator, num_expected, expected);
        }

        #[test]
        fn it_works_combination() {
            let k: usize = 3;
            let data = vec![1, 2, 4];
            let num_expected = 10; // c(n + k - 1, k)
            let iterator = CollectionIter::combination(&data, k, true);

            let expected = vec![
                vec![1, 1, 1],
                vec![1, 1, 2],
                vec![1, 1, 4],
                vec![1, 2, 2],
                vec![1, 2, 4],
                vec![1, 4, 4],
                vec![2, 2, 2],
                vec![2, 2, 4],
                vec![2, 4, 4],
                vec![4, 4, 4],
            ];
            check(iterator, num_expected, expected);
        }
    }

    mod without_duplication {
        use crate::collection::{test_iterator::check, CollectionIter};

        #[test]
        fn it_works_permutation() {
            let data = vec![1, 2, 4, 8];
            let num_expected = 24; // 4!
            let iterator = CollectionIter::permutation(&data, false);
            let expected = vec![
                vec![1, 2, 4, 8],
                vec![1, 2, 8, 4],
                vec![1, 4, 2, 8],
                vec![1, 4, 8, 2],
                vec![1, 8, 2, 4],
                vec![1, 8, 4, 2],
                vec![2, 1, 4, 8],
                vec![2, 1, 8, 4],
                vec![2, 4, 1, 8],
                vec![2, 4, 8, 1],
                vec![2, 8, 1, 4],
                vec![2, 8, 4, 1],
                vec![4, 1, 2, 8],
                vec![4, 1, 8, 2],
                vec![4, 2, 1, 8],
                vec![4, 2, 8, 1],
                vec![4, 8, 1, 2],
                vec![4, 8, 2, 1],
                vec![8, 1, 2, 4],
                vec![8, 1, 4, 2],
                vec![8, 2, 1, 4],
                vec![8, 2, 4, 1],
                vec![8, 4, 1, 2],
                vec![8, 4, 2, 1],
            ];
            check(iterator, num_expected, expected);
        }

        #[test]
        fn it_works_combination() {
            let k: usize = 3;
            let num_expected = 4; // c(n, k)
            let data = vec![1, 2, 4, 8];
            let iterator = CollectionIter::combination(&data, k, false);
            let expected = vec![vec![1, 2, 4], vec![1, 2, 8], vec![1, 4, 8], vec![2, 4, 8]];
            check(iterator, num_expected, expected);
        }
    }
}
299
/// Builds a nested `Vec` filled with clones of a value.
///
/// `ndarray!(val; d0, d1, ..., dk)` expands to
/// `vec![vec![...vec![val; dk]...; d1]; d0]`, so the first size is the
/// outermost dimension. With no sizes (`ndarray!(val;)`) it is just `val`.
/// A trailing comma after the last size is accepted.
///
/// Each level is built with `vec![elem; n]`, so the value must be `Clone`.
#[macro_export]
macro_rules! ndarray {
    // ndarray!(val; *shape)
    ($x:expr;) => { $x };
    ($x:expr; $size:expr $( , $rest:expr )* $(,)?) => {
        // `$crate::` keeps the recursive call resolvable when callers invoke
        // the macro by path (e.g. `my_crate::ndarray!`) without importing it.
        vec![$crate::ndarray!($x; $($rest),*); $size]
    };
}
308
#[cfg(test)]
mod test_ndarray {
    /// The first size is the outermost dimension: `(val; rows, cols)`
    /// produces `rows` vectors holding `cols` copies each.
    #[test]
    fn it_works() {
        let flat: Vec<i32> = ndarray!(5; 1);
        assert_eq!(flat, vec![5]);

        let one_row: Vec<Vec<i32>> = ndarray!(5; 1, 2);
        assert_eq!(one_row, vec![vec![5, 5]]);

        let one_col: Vec<Vec<i32>> = ndarray!(5; 2, 1);
        assert_eq!(one_col, vec![vec![5], vec![5]]);
    }
}
322
/// Disjoint-set forest over the elements `0..n`, using union by size and
/// path compression.
#[derive(Debug, Clone)]
pub struct UnionFind {
    /// `parents[x]` is `x`'s parent; a root is its own parent.
    parents: Vec<usize>,
    /// `sizes[r]` is the number of elements in the set rooted at `r`;
    /// only meaningful when `r` is a root.
    sizes: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets `{0}, {1}, ..., {n - 1}`.
    pub fn new(n: usize) -> Self {
        Self {
            parents: (0..n).collect(),
            sizes: vec![1; n],
        }
    }

    /// Returns the root (representative) of the set containing `x`, and
    /// re-points every node on the path from `x` directly at that root.
    ///
    /// # Panics
    ///
    /// Panics if `x` is not in `0..n`.
    pub fn parent(&mut self, x: usize) -> usize {
        // First pass: walk up to the root.
        let mut root = x;
        while self.parents[root] != root {
            root = self.parents[root];
        }

        // Second pass: compress the path. Doing this iteratively avoids
        // recursion depth proportional to the path length.
        let mut node = x;
        while node != root {
            let next = self.parents[node];
            self.parents[node] = root;
            node = next;
        }

        root
    }

    /// Merges the sets containing `x` and `y`; a no-op if they are already
    /// in the same set.
    ///
    /// The root of the smaller set is attached under the root of the larger
    /// one. On a tie, `x`'s root stays the root.
    pub fn unite(&mut self, x: usize, y: usize) {
        let mut px = self.parent(x);
        let mut py = self.parent(y);

        if px == py {
            return;
        }

        if self.sizes[px] < self.sizes[py] {
            std::mem::swap(&mut px, &mut py);
        }

        self.sizes[px] += self.sizes[py];
        self.parents[py] = px;
    }

    /// Returns the number of elements in the set containing `x`.
    pub fn size(&mut self, x: usize) -> usize {
        let root = self.parent(x);
        self.sizes[root]
    }

    /// Returns `true` if `x` and `y` belong to the same set.
    pub fn same(&mut self, x: usize, y: usize) -> bool {
        self.parent(x) == self.parent(y)
    }
}
374
#[cfg(test)]
mod test_union_find {
    use crate::collection::UnionFind;

    /// Collects `uf.size(i)` for every element `0..n`.
    fn sizes(uf: &mut UnionFind, n: usize) -> Vec<usize> {
        let mut out = Vec::with_capacity(n);
        for i in 0..n {
            out.push(uf.size(i));
        }
        out
    }

    /// Asserts that `a` and `b` are (or are not) in the same set, checking
    /// both the `parent` roots and `same` so the two APIs agree.
    fn assert_joined(uf: &mut UnionFind, a: usize, b: usize, joined: bool) {
        let root_a = uf.parent(a);
        let root_b = uf.parent(b);
        assert_eq!(root_a == root_b, joined);
        assert_eq!(uf.same(a, b), joined);
    }

    #[test]
    fn it_works() {
        let n: usize = 5;
        let mut uf = UnionFind::new(n);
        assert_eq!(sizes(&mut uf, n), [1, 1, 1, 1, 1]);

        uf.unite(0, 1);
        assert_joined(&mut uf, 0, 1, true);
        assert_joined(&mut uf, 0, 2, false);
        assert_eq!(sizes(&mut uf, n), [2, 2, 1, 1, 1]);

        // Uniting an already-joined pair must change nothing.
        uf.unite(0, 1);
        assert_joined(&mut uf, 0, 1, true);
        assert_joined(&mut uf, 0, 2, false);
        assert_eq!(sizes(&mut uf, n), [2, 2, 1, 1, 1]);

        uf.unite(0, 2);
        assert_joined(&mut uf, 0, 2, true);
        assert_eq!(sizes(&mut uf, n), [3, 3, 3, 1, 1]);

        uf.unite(3, 4);
        assert_joined(&mut uf, 0, 3, false);
        assert_eq!(sizes(&mut uf, n), [3, 3, 3, 2, 2]);

        uf.unite(0, 3);
        assert_joined(&mut uf, 0, 3, true);
        assert_joined(&mut uf, 0, 4, true);
        assert_eq!(sizes(&mut uf, n), [5, 5, 5, 5, 5]);
    }
}