• Pytorch使用torch.utils.data.random_split拆分数据集,拆分后的数据集状况


    对于这个API,我最开始的预想是从 '猫1猫2猫3猫4狗1狗2狗3狗4' 中分割出 '猫1猫2狗4狗1' 和 '猫4猫3狗2狗3' ,但是打印结果和我预想的不一样

    数据集文件的存放路径如下图

    测试代码如下

    1. import torch
    2. import torchvision
    3. transform = torchvision.transforms.Compose([
    4. torchvision.transforms.Resize((512,512)), # 调整图像大小为 224x224
    5. torchvision.transforms.ToTensor(), # 转换为张量
    6. torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化
    7. ])
    8. dataset = torchvision.datasets.ImageFolder('C:\\Users\\ASUS\\PycharmProjects\\pythonProject1\\cats_and_dogs_train',
    9. transform=transform)
    10. val_ratio = 0.2
    11. val_size = int(len(dataset) * val_ratio)
    12. train_size = len(dataset) - val_size
    13. train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    14. cats_num = 0
    15. dogs_num = 0
    16. for x,y in train_dataset:
    17. if y == 0:
    18. cats_num += 1
    19. else:
    20. dogs_num += 1
    21. print("cats_num: ",cats_num)
    22. print("dogs_num: ",dogs_num)
    23. cats_num2 = 0
    24. dogs_num2 = 0
    25. for x,y in val_dataset:
    26. if y == 0:
    27. cats_num2 += 1
    28. else:
    29. dogs_num2 += 1
    30. print("cats_num2: ",cats_num2)
    31. print("dogs_num2: ",dogs_num2)

    输出如下

    可以看到总共25000张图片的数据集,分割后并不是cats_num:10000,dogs_num:10000,cats_num2:2500,dogs_num2:2500

    也就是说,分割后的状况是猫狗的数量并不一定相等,如结果为 '猫1猫2猫4狗1' 和 '狗4猫3狗2狗3'

  • 相关阅读:
    Redis之key命令
    Android之getSystemService方法实现详解
    如何跟踪网络路由链路&检测网络健康状况
    JWT学习
    [附源码]Python计算机毕业设计Django酒店客房管理系统
    rosjava零散
    React-4 组件知识
    Aethir推出其首次去中心化AI节点售卖
    Flink架构&重要概念解析-超详理解
    预定2.0 Crack ZoomCharts JavaScript 最值得探索
  • 原文地址:https://blog.csdn.net/Victor_Li_/article/details/134035619