• 使用内存映射加快PyTorch数据集的读取


    本文将介绍如何使用内存映射文件加快PyTorch数据集的加载速度

    在使用Pytorch训练神经网络时,最常见的与速度相关的瓶颈是数据加载的模块。如果我们将数据通过网络传输,除了预取和缓存之外,没有任何其他的简单优化方式。

    但是如果数据本地存储,我们可以通过将整个数据集组合成一个文件,然后映射到内存中来优化读取操作,这样我们每次文件读取数据时就不需要访问磁盘,而是从内存中直接读取可以加快运行速度。

    什么是内存映射文件

    内存映射文件(memory-mapped file)是将完整或者部分文件加载到内存中,这样就可以通过内存地址相关的load或者store指令来操纵文件。为了支持这个功能,现代的操作系统会提供一个叫做mmap的系统调用。这个系统调用会接收一个虚拟内存地址(VA),长度(len),protection,一些标志位,一个打开文件的文件描述符,和偏移量(offset)。

    由于虚拟内存代表的附加抽象层,我们可以映射比机器的物理内存容量大得多的文件。正在运行的进程所需的内存段(称为页)从外部存储中获取,并由虚拟内存管理器自动复制到主内存中。

    使用内存映射文件可以提高I/O性能,因为通过系统调用进行的普通读/写操作比在本地内存中进行更改要慢得多,对于操作系统来说,文件以一种“惰性”的方式加载,通常一次只加载一个页,因此即使对于较大的文件,实际RAM利用率也是最低的,但是使用内存映射文件可以改善这个流程。

    什么是PyTorch数据集

    Pytorch提供了用于在训练模型时处理数据管道的两个主要模块:Dataset和DataLoader。

    DataLoader主要用作Dataset的加载,它提供了许多可配置选项,如批处理、采样、预读取、变换等,并抽象了许多方法。

    Dataset是我们进行数据集处理的实际部分,在这里我们编写训练时读取数据的过程,包括将样本加载到内存和进行必要的转换。

    对于Dataset,必须实现:

    __init_
    
    • 1

    ,

    __len__
    
    • 1

    __getitem__
    
    • 1

    三个方法

    实现自定义数据集

    接下来,我们将看到上面提到的三个方法的实现。

    最重要的部分是在

    __init__
    
    • 1

    中,我们将使用 numpy 库中的

    np.memmap()
    
    • 1

    函数来创建一个ndarray将内存缓冲区映射到本地的文件。

    在数据集初始化时,将ndarray使用可迭代对象进行填充,代码如下:

    class MMAPDataset(Dataset):
        def __init__(
            self,
            input_iter: Iterable[np.ndarray],
            labels_iter: Iterable[np.ndarray],
            mmap_path: str = None,
            size: int = None,
            transform_fn: Callable[..., Any] = None,
            
        ) -> None:
            super().__init__()
    
            self.mmap_inputs: np.ndarray = None
            self.mmap_labels: np.ndarray = None
            self.transform_fn = transform_fn
    
            if mmap_path is None:
                mmap_path = os.path.abspath(os.getcwd())
            self._mkdir(mmap_path)
    
            self.mmap_input_path = os.path.join(mmap_path, DEFAULT_INPUT_FILE_NAME)
            self.mmap_labels_path = os.path.join(mmap_path, DEFAULT_LABELS_FILE_NAME)
            self.length = size
    
            for idx, (input, label) in enumerate(zip(input_iter, labels_iter)):
                if self.mmap_inputs is None:
                    self.mmap_inputs = self._init_mmap(
                        self.mmap_input_path, input.dtype, (self.length, *input.shape)
                    )
                    self.mmap_labels = self._init_mmap(
                        self.mmap_labels_path, label.dtype, (self.length, *label.shape)
                    )
    
                self.mmap_inputs[idx][:] = input[:]
                self.mmap_labels[idx][:] = label[:]
    
        def __getitem__(self, idx: int) -> Tuple[Union[np.ndarray, torch.Tensor]]:
            if self.transform_fn:
                return self.transform_fn(self.mmap_inputs[idx]), torch.tensor(self.mmap_labels[idx]) 
            return self.mmap_inputs[idx], self.mmap_labels[idx]
    
        def __len__(self) -> int:
            return self.length
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43

    我们在上面提供的代码中还使用了两个辅助函数。

    def _mkdir(self, path: str) -> None:
            if os.path.exists(path):
                return
    
            try:
                os.makedirs(os.path.dirname(path), exist_ok=True)
                return
            except:
                raise ValueError(
                    "Failed to create the path (check the user write permissions)."
                )
    
        def _init_mmap(self, path: str, dtype: np.dtype, shape: Tuple[int], remove_existing: bool = False) -> np.ndarray:
            open_mode = "r+"
    
            if remove_existing:
                open_mode = "w+"
            
            return np.memmap(
                path,
                dtype=dtype,
                mode=open_mode,
                shape=shape,
            )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24

    可以看到,上面我们自定义数据集与一般情况的主要区别就是

    _init_mmap
    
    • 1

    中调用的

    np.memmap()
    
    • 1

    ,所以这里我们对

    np.memmap()
    
    • 1

    做一个简单的解释:

    Numpy的memmap对象,它允许将大文件分成小段进行读写,而不是一次性将整个数组读入内存。memmap也拥有跟普通数组一样的方法,基本上只要是能用于ndarray的算法就也能用于memmap。

    使用函数

    np.memmap
    
    • 1

    并传入一个文件路径、数据类型、形状以及文件模式,即可创建一个新的memmap存储在磁盘上的二进制文件创建内存映射。

    对于更多的介绍请参考Numpy的文档,这里就不做详细的解释了

    基准测试

    为了实际展示性能提升,我将内存映射数据集实现与以经典方式读取文件的普通数据集实现进行了比较。这里使用的数据集由 350 张 jpg 图像组成。

    从下面的结果中,我们可以看到我们的数据集比普通数据集快 30 倍以上:

    总结

    本文中介绍的方法在加速Pytorch的数据读取是非常有效的,尤其是使用大文件时,但是这个方法需要很大的内存,在做离线训练时是没有问题的,因为我们能够完全的控制我们的数据,但是如果想在生产中应用还需要考虑使用,因为在生产中有些数据我们是无法控制的。

    最后Numpy的文档地址如下:

    https://numpy.org/doc/stable/reference/generated/numpy.memmap.html

    https://avoid.overfit.cn/post/33d9496e1f8440d69a220fe6b9ab700c

    作者:Tudor Surdoiu

  • 相关阅读:
    SpringMVC+Shiro的基本使用
    Java2 - 数据结构
    Python 10之异常模块包
    Python3 基础语法:行与缩进
    记一次clickhouse手动更改分片数异常
    深度学习之CNN宫颈癌预测
    大模型多跳推理有解啦,北大化繁为简,用30B模型击败百亿模型
    中国现货白银中的跳空形态
    关于Fragment的生命周期,你知道多少?
    读后感读后感读后感
  • 原文地址:https://blog.csdn.net/m0_46510245/article/details/126082350