• Rust实现线段树和懒标记


    参考各家代码,用Rust实现了线段树和懒标记。

    由于使用了泛型,很多操作都要用闭包自定义实现。

    看代码。

    // 线段树定义
    pub struct SegmentTreeClone>
    {
        pub data: Vec,
        tree: Vec<Option>,
        marker: Vec,                               //懒标记。
        query_op: Box<dyn Fn(T, T) -> T>, //查询时,对所有查询元素做的操作。比如加法,就是求区间的所有元素的和。
        marker_marker_op: Box<dyn Fn(T, T) -> T>, //marker加到marker上时,对marker的操作。通常我们要marker[i] += p; 来更新标记,但是泛型实现不了,并且考虑到有些用户有别的需求,所以用闭包包装。
        marker_t_op: Box<dyn Fn(T, T) -> T>, //marker应用到T时,对T的操作。考虑到有些用户有别的需求,所以用闭包包装。
        marker_mul_usize: Box<dyn Fn(T, usize) -> T>, //marker乘usize的方法。这个没法通过要求满足Mul trait自动实现。由于使用了泛型,连乘法都要交给闭包实现。。。
    }
    
    implClone + Default + Copy + PartialEq> SegmentTree {
        pub fn new(
            data: Vec,
            query_op: Box<dyn Fn(T, T) -> T>,
            marker_marker_op: Box<dyn Fn(T, T) -> T>,
            marker_t_op: Box<dyn Fn(T, T) -> T>,
            marker_mul_usize: Box<dyn Fn(T, usize) -> T>,
        ) -> Self {
            let data_len = data.len();
            let mut tr = Self {
                data,
                marker: vec![T::default(); 4 * data_len], //四倍原数据大小
                tree: vec![None; 4 * data_len],           //四倍原数据大小
                query_op,
                marker_marker_op,
                marker_t_op,
                marker_mul_usize,
            };
            tr.build();
            tr
        }
    
        #[inline]
        pub fn get(&self, index: usize) -> Option<&T> {
            self.data.get(index)
        }
    
        #[inline]
        pub fn len(&self) -> usize {
            self.data.len()
        }
    
        #[inline]
        fn left_child(index: usize) -> usize {
            2 * index + 1
        }
    
        #[inline]
        fn right_child(index: usize) -> usize {
            2 * index + 2
        }
    
        #[inline]
        fn build(&mut self) {
            self.build_segment_tree(0, 0, self.data.len() - 1);
        }
    
        // 递归Build
        fn build_segment_tree(&mut self, tree_index: usize, left: usize, right: usize) {
            if left == right {
                self.tree[tree_index] = Some(self.data[left]);
                return;
            }
            let left_tree_index = Self::left_child(tree_index);
            let right_tree_index = Self::right_child(tree_index);
            let mid = (right - left) / 2 + left;
            self.build_segment_tree(left_tree_index, left, mid);
            self.build_segment_tree(right_tree_index, mid + 1, right);
            // 左右子树数据处理方式
            if let Some(l) = self.tree[left_tree_index] {
                if let Some(r) = self.tree[right_tree_index] {
                    self.tree[tree_index] = Some((self.query_op)(l, r))
                }
            }
        }
    
        // 返回对线段树的全部元素做query_op操作的结果
        #[inline]
        pub fn query_all(&mut self) -> T {
            self.recursion_query(0, self.data.len() - 1, 0, 0, self.data.len() - 1)
        }
    
        // 返回对线段树的[l..r]范围全部元素做query_op操作的结果
        pub fn query(&mut self, l: usize, r: usize) -> Result'static str> {
            if l > self.data.len() || r > self.data.len() || l > r {
                return Err("索引错误");
            }
            if l == r {
                return Ok(self.data[l]);
            }
            Ok(self.recursion_query(l, r, 0, 0, self.data.len() - 1))
        }
    
        // 在index表示的[current_left,current_right]范围中查询[l..r]值
        fn recursion_query(
            &mut self,
            l: usize,
            r: usize,
            index: usize,
            current_left: usize,
            current_right: usize,
        ) -> T {
            if l > current_right || r < current_left {
                return T::default();
            }
            if l == current_left && r == current_right {
                if let Some(d) = self.tree[index] {
                    if l == r {
                        self.data[l] = d;
                    }
                    return d;
                }
                return T::default();
            }
            self.push_down(index, current_right - current_left + 1);
            let mid = current_left + (current_right - current_left) / 2;
            if l >= mid + 1 {
                return self.recursion_query(l, r, Self::right_child(index), mid + 1, current_right);
            } else if r <= mid {
                return self.recursion_query(l, r, Self::left_child(index), current_left, mid);
            }
            let l_res = self.recursion_query(l, mid, Self::left_child(index), current_left, mid);
            let r_res =
                self.recursion_query(mid + 1, r, Self::right_child(index), mid + 1, current_right);
            (self.query_op)(l_res, r_res)
        }
    
        // 更新index为val
        pub fn set(&mut self, index: usize, val: T) -> Result<(), &'static str> {
            if index >= self.data.len() {
                return Err("索引超过线段树长度");
            }
            // 更新数据
            self.data[index] = val;
            // 递归更新树
            self.recursion_set(0, 0, self.data.len() - 1, index, val);
            Ok(())
        }
    
        // 递归更新树
        fn recursion_set(&mut self, index_tree: usize, l: usize, r: usize, index: usize, val: T) {
            if l == r {
                self.tree[index_tree] = Some(val);
                return;
            }
            let mid = l + (r - l) / 2;
            let left_child = Self::left_child(index_tree);
            let right_child = Self::right_child(index_tree);
            if index >= mid + 1 {
                self.recursion_set(right_child, mid + 1, r, index, val);
            } else {
                self.recursion_set(left_child, l, mid, index, val);
            }
            // 左右子树数据求和
            if let Some(l_d) = self.tree[left_child] {
                if let Some(r_d) = self.tree[right_child] {
                    self.tree[index_tree] = Some((self.query_op)(l_d, r_d));
                }
            }
        }
    
        // 应用所有懒标记到data数组上
        #[inline]
        pub fn apply_marker_all(&mut self) {
            self.apply_marker_lr(0, self.data.len() - 1);
        }
    
        // 应用懒标记到[l:r]数据范围
        #[inline]
        pub fn apply_marker_lr(&mut self, l: usize, r: usize) {
            self.apply_marker(l, r, 0, 0, self.data.len() - 1);
        }
    
        fn apply_marker(
            &mut self,
            l: usize,
            r: usize,
            index: usize,
            current_l: usize,
            current_r: usize,
        ) {
            if current_l > r || current_r < l || r >= self.data.len() {
                return; // 区间无交集
            } else {
                // 与目标区间有交集,但不包含于其中
                if current_l == current_r {
                    if let Some(d) = self.tree[index] {
                        self.data[current_l] = d;
                    }
                    return;
                }
                let mid = (current_l + current_r) / 2;
                self.push_down(index, current_r - current_l + 1);
                self.apply_marker(l, r, Self::left_child(index), current_l, mid); // 递归地往下寻找
                self.apply_marker(l, r, Self::right_child(index), mid + 1, current_r);
                self.tree[index] = Some((self.query_op)(
                    self.tree[Self::left_child(index)].unwrap(),
                    self.tree[Self::right_child(index)].unwrap(),
                ));
                // 根据子节点更新当前节点的值
            }
        }
        #[inline]
        pub fn update_interval(&mut self, l: usize, r: usize, delta: T) {
            self.update(l, r, delta, 0, 0, self.data.len() - 1);
        }
    
        // 传递marker到下级
        fn push_down(&mut self, index: usize, len: usize) {
            self.marker[Self::left_child(index)] =
                (self.marker_marker_op)(self.marker[index], self.marker[Self::left_child(index)]); // 标记向下传递
            self.marker[Self::right_child(index)] =
                (self.marker_marker_op)(self.marker[index], self.marker[Self::right_child(index)]);
            if self.tree[Self::left_child(index)].is_some() {
                self.tree[Self::left_child(index)] = Some((self.marker_t_op)(
                    (self.marker_mul_usize)(self.marker[index], len - (len / 2)),
                    self.tree[Self::left_child(index)].unwrap(),
                ));
            }
            if self.tree[Self::right_child(index)].is_some() {
                self.tree[Self::right_child(index)] = Some((self.marker_t_op)(
                    (self.marker_mul_usize)(self.marker[index], len / 2),
                    self.tree[Self::right_child(index)].unwrap(),
                ));
            }
            self.marker[index] = T::default(); // 清除标记
        }
    
        fn update(
            &mut self,
            l: usize,
            r: usize,
            delta: T,
            index: usize,
            current_l: usize,
            current_r: usize,
        ) {
            if current_l > r || current_r < l {
                return; // 区间无交集
            } else if current_l >= l && current_r <= r {
                // 当前节点对应的区间包含在目标区间中
                if self.tree[index].is_some() {
                    // 更新当前区间的值
                    self.tree[index] = Some((self.query_op)(
                        self.tree[index].unwrap(),
                        (self.marker_mul_usize)(delta, current_r - current_l + 1),
                    ));
                }
                // 如果不是叶子节点
                if current_r > current_l {
                    // 给当前区间打上标记
                    self.marker[index] = (self.marker_marker_op)(delta, self.marker[index]);
                }
            } else {
                // 与目标区间有交集,但不包含于其中
                let mid = (current_l + current_r) / 2;
                self.push_down(index, current_r - current_l + 1);
                self.update(l, r, delta, Self::left_child(index), current_l, mid); // 递归地往下寻找
                self.update(l, r, delta, Self::right_child(index), mid + 1, current_r);
                self.tree[index] = Some((self.query_op)(
                    self.tree[Self::left_child(index)].unwrap(),
                    self.tree[Self::right_child(index)].unwrap(),
                )); // 根据子节点更新当前节点的值
            }
        }
    }
    
    fn main() {
        let mut tr: SegmentTree<i32> = SegmentTree::new(
            vec![1, 3, 4, 0, 0, 4, 5, 0],
            Box::new(|a, b| a + b),
            Box::new(|a, b| a + b),
            Box::new(|a, b| a + b),
            Box::new(|a, b| a * (b as i32)),
        );
        let _ = tr.set(1, 2); //点更新,即把data[1]设为2
        tr.update_interval(0, 2, -1); //区间更新,即[0:2]每个元素减1
        tr.update_interval(1, 3, 2); //区间更新,即[1:3]每个元素加2
        tr.apply_marker_all(); //应用全部marker到data数组
        println!("{}", tr.query_all()); //输出19,即全部元素的和
        println!("{:?}", tr.data); //输出[0, 3, 5, 2, 0, 4, 5, 0]
    }
    

    做一道题验证一下这个线段树的正确性,直接看我写的1589. 所有排列中的最大和题解即可(虽然这道题用差分数组最快,但是作为线段树验证还是很方便的)。

  • 相关阅读:
    Git学习笔记 - Idea集成GitHub、Gitee
    Lumen/Laravel - 数据库读写分离原理 - 探究
    MySQL根据备注查询表、字段
    javascirpt中,什么是词法,什么是词法作用域
    【C++】超详细入门 —— 一文带你搞懂const限定符
    GET 和 POST 有什么区别?
    如何在Android平板上远程连接Ubuntu服务器code-server进行代码开发?
    Time Series Data Augmentation for Deep Learning: A Survey 之论文阅读
    前端面试题-找出两个数组的不同值(差集)
    30_ue4进阶末日生存游戏开发[为道具添加碰撞系统]
  • 原文地址:https://www.cnblogs.com/mariocanfly/p/17938072