• Rust实现基于Tokio的限制内存占用的channel


    Rust实现基于Tokio的限制内存占用的channel

    简介

    本文介绍如何基于tokio的channel实现一个限制内存占用的channel。

    Tokio提供了多种协程间同步的接口,用于在不同的协程中同步数据。
    常用的channel有两种:boundedunbounded,其中ubbounded的channel可以无限的发送数据,而bounded的channel则有限的发送数据。两种channel都没有对自身的内存占用做出限制。

    异步网络编程中常用一个channel连接两个task,其中业务task与业务交互:将要发送的数据发送到channel,而网络task与操作系统交互:从channel中接收数据并写入socket。单有时候带宽有限或者对端接收速率过慢时,而网络task从channel中接收的速度小于业务task向channel中发送的速度时,会造成大量的数据阻塞在channel中,如果不对channel的占用内存做限制,则会造成内存占用过多甚至进程被OOM

    实现

    1. 获取数据大小

      要想限制channel总的内存占用,必须要直到每个数据的大小。比较常见的作法是所有需要发送到channel的内容都必须实现一个Trait,此Trait中定义了一个get_size方法,用于获取数据的大小。

      pub trait GetSize {
       /// get total size
       fn get_size(&self) -> usize;
      }
      
      • 1
      • 2
      • 3
      • 4

      要发送的内容必须实现GetSize的Trait,并实现get_size方法。注意:get_size方法获取到的大小需包括栈空间和堆空间,例如:

       struct MyData {
           data: Vec<u8>,
       }
      
       impl GetSize for MyData {
           fn get_size(&self) -> usize {
               return std::mem::size_of::<MyData>() + self.data.len();//stack size + heap size
           }
       }
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
    2. 创建SizedSenderSizedReceiver

      SizedSenderSizedReceiver都可以基于tokio的UnboundedSenderUnboundedReceiver实现。在tokio的基础上,需要共享一个条件变量用于在sender和receiver之间同步当前是否还有可用空间。

         
      pub struct SizedSender<T: GetSize> {
          inner: mpsc::UnboundedSender<T>,
          size_semaphore: Arc<(Semaphore, usize)>,
      }   
      
      
      pub struct SizedReceiver<T: GetSize> {
          inner: mpsc::UnboundedReceiver<T>,
          size_semaphore: Arc<(Semaphore, usize)>,
      }
      
      
      /// Limit space usage but not limit the number of messages, bytes_size must bigger than 0.
      pub fn sized_channel<T: GetSize>(bytes_size: usize) -> (SizedSender<T>, SizedReceiver<T>) {
          let (tx, rx) = mpsc::unbounded_channel::<T>();
          let semaphore = Arc::new((Semaphore::new(bytes_size), bytes_size));
          (
              SizedSender::new(tx, semaphore.clone()),
              SizedReceiver::new(rx, semaphore),
          )
      }          
      
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
      • 15
      • 16
      • 17
      • 18
      • 19
      • 20
      • 21
      • 22
      • 23
    3. SizedSender实现

      发送端发送时需要调用get_size方法获取数据的大小,然后调用Semaphore::available_permits方法获取可用空间,如果可用空间大于数据大小,则发送成功,否则发送失败。

      impl<T: GetSize> SizedSender<T> {
       pub fn new(inner: mpsc::UnboundedSender<T>, size_semaphore: Arc<(Semaphore, usize)>) -> Self {
           Self {
               inner,
               size_semaphore,
           }
       }
      
       fn do_send(
           &self,
           message: T,
           permits: Option<SemaphorePermit<'_>>,
       ) -> Result<(), SendError<T>> {
           match self.inner.send(message) {
               Ok(r) => {
                   if let Some(permits) = permits {
                       permits.forget();
                   }
      
                   Ok(r)
               }
               Err(e) => {
                   log::debug!("send value error!");
                   Err(e)
               }
           }
       }
       pub async fn send(&self, message: T) -> Result<(), SendError<T>> {
           let message_size = message.get_size();
      
           if message_size > self.size_semaphore.1 {
               return Err(SendError(message));
           }
           let size = match u32::try_from(message_size) {
               Ok(size) => size,
               Err(_) => {
                   return Err(SendError(message));
               }
           };
      
           if self.size_semaphore.0.available_permits() < size as usize {
               // The buffer is about to be depleted, sending may be blocked.
           }
      
           let permits = match self.size_semaphore.0.acquire_many(size).await {
               Ok(perimits) => Some(perimits),
               Err(_) => {
                   return Err(SendError(message));
               }
           };
      
           self.do_send(message, permits)
       }
       }
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
      • 15
      • 16
      • 17
      • 18
      • 19
      • 20
      • 21
      • 22
      • 23
      • 24
      • 25
      • 26
      • 27
      • 28
      • 29
      • 30
      • 31
      • 32
      • 33
      • 34
      • 35
      • 36
      • 37
      • 38
      • 39
      • 40
      • 41
      • 42
      • 43
      • 44
      • 45
      • 46
      • 47
      • 48
      • 49
      • 50
      • 51
      • 52
      • 53
      • 54
    4. SizedReceiver的实现

      接收端接收时需要调用get_size方法获取数据的大小,然后将相应大小的permits还给信号量即可。

      impl<T: GetSize> SizedReceiver<T> {
      pub fn new(inner: mpsc::UnboundedReceiver<T>, size_semaphore: Arc<(Semaphore, usize)>) -> Self {
          Self {
              inner,
              size_semaphore,
          }
      }
      
      pub async fn recv(&mut self) -> Option<T> {
          self.inner.recv().await.map(|r| {
              let message_size = r.get_size();
      
              self.size_semaphore.0.add_permits(message_size);
      
              r
          })
      }
      }
      
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
      • 15
      • 16
      • 17
      • 18
      • 19
    5. 其他

      在上述实现的基础上,还可以实现更多方法,比如try_sendtry_recv等。

  • 相关阅读:
    【JVM】内存模型:原子性、可见性、有序性的问题引出与解决
    Object转List<>,转List<Map<>>
    C++ 带你吃透string容器的使用
    神经网络卷积层
    python提取网页指定内容
    聚甲基丙烯酰氧乙基三甲基氯化铵(Poly-MAC)表面接枝聚苯乙烯树脂微球相关研究
    共享内存和信号量的配合机制
    【面向对象】【0x00】 Python面向对象介绍
    Airtest 点击按钮前后两张图片的相似度,判断按钮是否可以被点击
    【Vue3】定义全局变量和全局函数
  • 原文地址:https://blog.csdn.net/luchengtao11/article/details/134020454