// algorithm_rs/tree.rs

1use std::ops::Add;
2
/// Types that can produce an additive "zero" value.
///
/// The trees in this file use it as the identity element for [`Op::Add`].
pub trait Zero {
    /// Returns the zero value of the implementing type.
    fn zero() -> Self;
}
6
/// Types that expose a smallest and a largest value.
///
/// The trees in this file use `min_value` as the identity element for
/// [`Op::Max`] and `max_value` as the identity element for [`Op::Min`].
pub trait Bounded {
    /// Returns the smallest value of the implementing type.
    fn min_value() -> Self;

    /// Returns the largest value of the implementing type.
    fn max_value() -> Self;
}
11
12macro_rules! impl_monoid_traits {
13    ($t:ty) => {
14        impl Zero for $t {
15            fn zero() -> Self {
16                0
17            }
18        }
19        impl Bounded for $t {
20            fn min_value() -> Self {
21                <$t>::MIN
22            }
23            fn max_value() -> Self {
24                <$t>::MAX
25            }
26        }
27    };
28}
29
30impl_monoid_traits!(i8);
31impl_monoid_traits!(i16);
32impl_monoid_traits!(i32);
33impl_monoid_traits!(i64);
34impl_monoid_traits!(i128);
35impl_monoid_traits!(u8);
36impl_monoid_traits!(u16);
37impl_monoid_traits!(u32);
38impl_monoid_traits!(u64);
39impl_monoid_traits!(u128);
40impl_monoid_traits!(isize);
41impl_monoid_traits!(usize);
42
/// Associative operation a tree uses to combine values.
///
/// The enum carries no data, so it is `Copy` and can be passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Op {
    /// Keep the larger of two values.
    Max,
    /// Keep the smaller of two values.
    Min,
    /// Sum two values.
    Add,
}
49
50#[derive(Debug, Clone)]
51pub struct RangeGetTree<T> {
52    data: Vec<T>,
53    op: Op,
54}
55
56impl<T> RangeGetTree<T>
57where
58    T: Clone + Add<Output = T> + Ord + Zero + Bounded,
59{
60    const SEQ_LEN: usize = 1 << 20;
61
62    pub fn new(op: Op) -> Self {
63        let identity = Self::identity_for(&op);
64        Self {
65            data: vec![identity; 2 * Self::SEQ_LEN],
66            op,
67        }
68    }
69
70    fn identity_for(op: &Op) -> T {
71        match op {
72            Op::Add => T::zero(),
73            Op::Max => T::min_value(),
74            Op::Min => T::max_value(),
75        }
76    }
77
78    pub fn identity(&self) -> T {
79        Self::identity_for(&self.op)
80    }
81
82    fn operate(&self, a: T, b: T) -> T {
83        match &self.op {
84            Op::Add => a + b,
85            Op::Max => a.max(b),
86            Op::Min => a.min(b),
87        }
88    }
89
90    pub fn get(&self, l: usize, r: usize) -> T {
91        self.range_query_recursive(l, r, 0, Self::SEQ_LEN, 1)
92    }
93
94    pub fn update(&mut self, mut index: usize, value: T) {
95        index += Self::SEQ_LEN;
96        self.data[index] = self.operate(self.data[index].clone(), value);
97        while index > 1 {
98            index /= 2;
99            let lv = self.data[index * 2].clone();
100            let rv = self.data[index * 2 + 1].clone();
101            self.data[index] = self.operate(lv, rv);
102        }
103    }
104
105    fn range_query_recursive(&self, ql: usize, qr: usize, sl: usize, sr: usize, pos: usize) -> T {
106        if qr <= sl || sr <= ql {
107            return self.identity();
108        }
109        if ql <= sl && sr <= qr {
110            return self.data[pos].clone();
111        }
112        let sm = (sl + sr) / 2;
113        let lv = self.range_query_recursive(ql, qr, sl, sm, pos * 2);
114        let rv = self.range_query_recursive(ql, qr, sm, sr, pos * 2 + 1);
115        self.operate(lv, rv)
116    }
117}
118
#[cfg(test)]
mod test_range_get_tree {
    use super::{Op, RangeGetTree};

    /// Applies `updates` as `(index, value)` pairs to a fresh tree for `op`,
    /// checks every `((l, r), expected)` query, and finally checks that an
    /// empty range yields the identity element.
    fn check(op: Op, updates: &[(usize, i64)], queries: &[((usize, usize), i64)]) {
        let mut tree: RangeGetTree<i64> = RangeGetTree::new(op);
        for &(index, value) in updates {
            tree.update(index, value);
        }
        for &((l, r), expected) in queries {
            assert_eq!(tree.get(l, r), expected);
        }
        assert_eq!(tree.get(0, 0), tree.identity());
    }

    /// Range-sum queries.
    #[test]
    fn it_works_rsq() {
        check(
            Op::Add,
            &[(0, 3), (2, 3), (3, 1), (4, 4)],
            &[((0, 3), 6), ((1, 3), 3), ((2, 4), 4), ((3, 4), 1), ((1, 6), 8)],
        );
    }

    /// Range-maximum queries.
    #[test]
    fn it_works_rmaxq() {
        check(
            Op::Max,
            &[(0, 10), (2, 101), (100, 1001)],
            &[((0, 1), 10), ((0, 2), 10), ((0, 3), 101), ((0, 100100), 1001)],
        );
    }

    /// Range-minimum queries.
    #[test]
    fn it_works_rminq() {
        check(
            Op::Min,
            &[(0, 101), (2, 10), (100, 1001)],
            &[((0, 1), 101), ((0, 2), 101), ((0, 3), 10), ((0, 100100), 10)],
        );
    }
}
164
165#[derive(Debug, Clone)]
166pub struct RangeUpdateTree<T> {
167    data: Vec<T>,
168    op: Op,
169}
170
171impl<T> RangeUpdateTree<T>
172where
173    T: Clone + Add<Output = T> + Ord + Zero + Bounded,
174{
175    const SEQ_LEN: usize = 1 << 20;
176
177    pub fn new(op: Op) -> Self {
178        let identity = Self::identity_for(&op);
179        Self {
180            data: vec![identity; 2 * Self::SEQ_LEN],
181            op,
182        }
183    }
184
185    fn identity_for(op: &Op) -> T {
186        match op {
187            Op::Add => T::zero(),
188            _ => panic!("Unsupported op for RangeUpdateTree: {:?}", op),
189        }
190    }
191
192    pub fn identity(&self) -> T {
193        Self::identity_for(&self.op)
194    }
195
196    fn operate(&self, a: T, b: T) -> T {
197        match &self.op {
198            Op::Add => a + b,
199            _ => panic!("Unsupported op for RangeUpdateTree: {:?}", &self.op),
200        }
201    }
202
203    pub fn get(&self, mut index: usize) -> T {
204        index += Self::SEQ_LEN;
205        let mut ret = self.operate(self.identity(), self.data[index].clone());
206        while index > 1 {
207            index /= 2;
208            ret = self.operate(ret, self.data[index].clone());
209        }
210        ret
211    }
212
213    pub fn update(&mut self, mut l: usize, mut r: usize, value: T) {
214        l += Self::SEQ_LEN;
215        r += Self::SEQ_LEN;
216        while l < r {
217            if l % 2 == 1 {
218                self.data[l] = self.operate(self.data[l].clone(), value.clone());
219                l += 1;
220            }
221            l /= 2;
222            if r % 2 == 1 {
223                self.data[r - 1] = self.operate(self.data[r - 1].clone(), value.clone());
224                r -= 1;
225            }
226            r /= 2;
227        }
228    }
229}
230
#[cfg(test)]
mod test_range_update_tree {
    use super::{Op, RangeUpdateTree};

    /// Range-add updates followed by point queries.
    #[test]
    fn it_works_raq() {
        let mut tree: RangeUpdateTree<i64> = RangeUpdateTree::new(Op::Add);

        // Each entry is (l, r, value): add `value` to the half-open range [l, r).
        let updates: [(usize, usize, i64); 3] = [(1, 2, 1), (2, 4, 2), (3, 4, 3)];
        for &(l, r, value) in updates.iter() {
            tree.update(l, r, value);
        }

        // Position 0 was never touched, so it still holds the identity.
        assert_eq!(tree.get(0), tree.identity());

        let expected: [(usize, i64); 2] = [(2, 2), (3, 5)];
        for &(index, want) in expected.iter() {
            assert_eq!(tree.get(index), want);
        }
    }
}