• Pytorch使用DataLoader, num_workers!=0时的内存泄露


    • 描述一下背景,和遇到的问题:

    我在做一个超大数据集的多分类,设备Ubuntu 22.04+i9 13900K+Nvidia 4090+64GB RAM,第一次的训练的训练集有700万张,训练成功。后面收集到更多数据集,数据增强后达到了1000万张。但第二次训练4个小时后,就被系统杀掉进程了,原因是Out of Memory。找了很久的原因,发现内存随着训练step的增加而线性增加,猜测是内存泄露,最后定位到了DataLoader的num_workers参数(只要num_workers=0就没有问题)。

    • 真正原因:

    Python(Pytorch)中的list转换成tensor时,会发生内存泄漏,要避免list的使用,可以通过使用np.array来代替list。

    • 解决办法:

    自定义DataLoader中的Dataset类,然后Dataset类中的list全部用np.array来代替。这样的话,DataLoader将np.array转换成Tensor的过程就不会发生内存泄露。

    • 下面给两个错误的示例代码和一个正确的代码:(都是我自己犯过的错误)

    1.错误的DataLoader加载数据集方法1

    1. # 加载数据
    2. train_data = datasets.ImageFolder(root=TRAIN_DIR_ARG, transform=transform)
    3. valid_data = datasets.ImageFolder(root=VALIDATION_DIR, transform=transform)
    4. test_data = datasets.ImageFolder(root=TEST_DIR, transform=transform)
    5. train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
    6. valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
    7. test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

    2.错误的DataLoader加载数据集方法2(重写了Dataset方法)

    1. class CustomDataset(Dataset):
    2. def __init__(self, data_dir, transform=None):
    3. self.data_dir = data_dir
    4. self.transform = transform
    5. self.image_paths = []
    6. self.labels = []
    7. # 遍历数据目录并收集图像文件路径和对应的标签
    8. classes = os.listdir(data_dir)
    9. for i, class_name in enumerate(classes):
    10. class_dir = os.path.join(data_dir, class_name)
    11. if os.path.isdir(class_dir):
    12. for image_name in os.listdir(class_dir):
    13. image_path = os.path.join(class_dir, image_name)
    14. self.image_paths.append(image_path)
    15. self.labels.append(i)
    16. def __len__(self):
    17. return len(self.image_paths)
    18. def __getitem__(self, idx):
    19. image_path = self.image_paths[idx]
    20. label = self.labels[idx]
    21. # # 在需要时加载图像
    22. image = Image.open(image_path)
    23. if self.transform:
    24. image = self.transform(image)
    25. return image, label
    26. train_data = CustomDataset(data_dir=TRAIN_DIR_ARG, transform=transform)
    27. valid_data = CustomDataset(data_dir=VALIDATION_DIR, transform=transform)
    28. test_data = CustomDataset(data_dir=TEST_DIR, transform=transform)
    29. train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
    30. valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
    31. test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

    3.重写Dataset的正确方法(重写了Dataset方法,list全部转成np.array)

    1. class CustomDataset(Dataset):
    2. def __init__(self, data_dir, transform=None):
    3. self.data_dir = data_dir
    4. self.transform = transform
    5. self.image_paths = [] # 使用Python列表
    6. self.labels = [] # 使用Python列表
    7. # 遍历数据目录并收集图像文件路径和对应的标签
    8. classes = os.listdir(data_dir)
    9. for i, class_name in enumerate(classes):
    10. class_dir = os.path.join(data_dir, class_name)
    11. if os.path.isdir(class_dir):
    12. for image_name in os.listdir(class_dir):
    13. image_path = os.path.join(class_dir, image_name)
    14. self.image_paths.append(image_path) # 添加到Python列表
    15. self.labels.append(i) # 添加到Python列表
    16. # 转换为NumPy数组,这里就是解决内存泄露的关键代码
    17. self.image_paths = np.array(self.image_paths)
    18. self.labels = np.array(self.labels)
    19. def __len__(self):
    20. return len(self.image_paths)
    21. def __getitem__(self, idx):
    22. image_path = self.image_paths[idx]
    23. label = self.labels[idx]
    24. # 在需要时加载图像
    25. image = Image.open(image_path)
    26. if self.transform:
    27. image = self.transform(image)
    28. # 将图像数据转换为NumPy数组
    29. image = np.array(image)
    30. return image, label
    31. train_data = CustomDataset(data_dir=TRAIN_DIR_ARG, transform=transform)
    32. valid_data = CustomDataset(data_dir=VALIDATION_DIR, transform=transform)
    33. test_data = CustomDataset(data_dir=TEST_DIR, transform=transform)
    34. train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
    35. valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
    36. test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

  • 相关阅读:
    东哥套现,大佬隐退?
    测试环境要多少?从成本与效率说起
    mysql a表like b表的某个字段,mysql一个表的字段like另外一个表的字段
    Vue3:对ref、reactive的一个性能优化API
    Unity中Commpont类获取子物体的示例
    openvino安装踩坑笔记
    NTFS安全权限
    Linux:syslog()的使用和示例
    易云维®工厂能耗管理平台系统方案,保证运营质量,推动广东制造企业节能减排
    Spring命名空间
  • 原文地址:https://blog.csdn.net/deephacking/article/details/133662577