• 深度学习(PyTorch)——torchvision中的数据集使用方法


    B站UP主“我是土堆”视频内容

    torchvision简介
    torchvision是pytorch的一个图形库,它服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。以下是torchvision的构成:

    torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
    torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
    torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
    torchvision.utils: 其他的一些有用的方法。
    torchvision.transforms
    torchvision.transforms主要是用于常见的一些图形变换。
    torchvision.transforms.Compose()类。这个类的主要作用是串联多个图片变换的操作。这个类的构造很简单:

    # 图像预处理步骤
    transform = transforms.Compose([
        transforms.Resize(96), # 缩放到 96 * 96 大小
        transforms.ToTensor(), # 转化为Tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化
    ])

    torchvision.datasets
    torchvision.datasets 是用来进行数据加载的,PyTorch团队在这个包中帮我们提前处理好了很多很多图片数据集,有以下的一些数据集:

    MNISTCOCO
    Captions
    Detection
    LSUN
    ImageFolder
    Imagenet-12
    CIFAR
    STL10
    SVHN
    PhotoTour

    1. # Image processing
    2. img_transform = transforms.Compose([
    3. transforms.ToTensor(),
    4. transforms.Normalize((0.5,), (0.5,)),
    5. ])
    6. # MNIST dataset
    7. mnist = datasets.MNIST(
    8. root='./data/', train=True, transform=img_transform, download=True)
    9. # Data loader
    10. dataloader = torch.utils.data.DataLoader(
    11. dataset=mnist, batch_size=batch_size, shuffle=True)

    torchvision.models

    torchvision.models 中为我们提供了已经训练好的模型,让我们可以加载之后,直接使用。

    torchvision.models模块的子模块中包含以下模型结构。

    AlexNet
    VGG
    ResNet
    SqueezeNet
    DenseNet

    1. import torchvision.models as models
    2. resnet18 = models.resnet18()
    3. alexnet = models.alexnet()
    4. squeezenet = models.squeezenet1_0()
    5. densenet = models.densenet_161()

    也可以通过使用 pretrained=True 来加载一个别人预训练好的模型

    1. import torchvision.models as models
    2. resnet18 = models.resnet18(pretrained=True)
    3. alexnet = models.alexnet(pretrained=True)

    下面是B站UP主“我是土堆”视频内容

    下面的pytorch的官方文档

     

    cifar10的数据集介绍如下 

     

     使用torchvision下载需要的数据集程序界面如下

     测试集第一个数据的输出如下,最后的数字3表示类别,3对应猫

     

     

     加入transforms,把图片数据转换成tensor数据

     如果数据集下载比较慢可以用迅雷下载,数据集的下载地址可以通过以下步骤去查找

     按住ctrl,点击cifar10

    复制url到迅雷当中去下载 

     

    程序如下: 

    1. import torchvision
    2. from torch.utils.tensorboard import SummaryWriter
    3. dataset_transform = torchvision.transforms.Compose([
    4. torchvision.transforms.ToTensor()
    5. ])
    6. train_set = torchvision.datasets.CIFAR10(root="./dataset_CIFAR10",train=True,transform=dataset_transform,download=True)
    7. test_set = torchvision.datasets.CIFAR10(root="./dataset_CIFAR10",train=False,transform=dataset_transform,download=True)
    8. # print(test_set[0])
    9. # print(test_set.classes)
    10. # img,target = test_set[0]
    11. # print(img)
    12. # print(target)
    13. # print(test_set.classes[target])
    14. # img.show()
    15. # print(test_set[0])
    16. writer = SummaryWriter("p10")
    17. for i in range(10):
    18. img,target = test_set[i]
    19. writer.add_image("test_set",img,i)
    20. writer.close()

    参考文献:

    https://blog.csdn.net/frighting_ing/article/details/121863387?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522166200606316781683929819%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=166200606316781683929819&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~top_click~default-2-121863387-null-null.142^v44^new_blog_pos_by_title&utm_term=torchvision&spm=1018.2226.3001.4187​​​​​​​ 

  • 相关阅读:
    JavaScript脚本操作CSS
    布隆过滤器在项目中的使用
    JDBC 和 DDT2
    华为OD机试真题-勾股数元组-2023年OD统一考试(B卷)
    前行不缀 未来可期,鸿蒙生态发展迈入全新阶段
    nginx之configure解析以及模板简介
    Eureka 概述与 Eureka Server 配置
    [TQLCTF 2022]simple_bypass
    cJson堆内存释放问题
    性能测试监控-java分析工具Arthas
  • 原文地址:https://blog.csdn.net/qq_42233059/article/details/126639835