• 【torchvision】torchvision 介绍


    官网地址:https://pytorch.org/vision/stable/index.html

    Torchvision 是 PyTorch 的一个视觉处理工具包,独立于PyTorch,需要另外安装

    它包括4个类,各类的主要功能如下:

    • 1)datasets:提供常用的数据集加载,设计上都是继承自torch.utils.data.Dataset,主要
      包括MMIST、CIFAR10/100、ImageNet和COCO等。
    • 2)models:提供深度学习中各种经典的网络结构以及训练好的模型(参数选择
      pretrained=True),包括AlexNet、VGG系列、ResNet系列、Inception系列等。
    • 3)transforms:常用的数据预处理操作,主要包括对Tensor及PIL Image对象的操作。
    • 4)utils:含两个函数,一个是make_grid,它能将多张图片拼接在一个网格中;另一个是 save_img,它能将 Tensor 保存成图片
      在这里插入图片描述

    1、torchvision.datasets

    1.1 常用数据集加载 MNIST等

    举例,通过torchvision下载 MNIST (mnist 全称:mixed national institute of standards and technology database)

    train_dataset = torchvision.datasets.MNIST(root, 
                                               train=True, 
                                               transform=transform, 
                                               download=True)
    
    • 1
    • 2
    • 3
    • 4

    root :需要下载至地址的根目录位置
    train:如果是True, 下载训练集 trainin.pt; 如果是False,下载测试集 test.pt; 默认是True
    transform:一系列作用在PIL图片上的转换操作,返回一个转换后的版本
    download:是否下载到 root指定的位置,如果指定的root位置已经存在该数据集,则不再下载


    1.2 自定义数据集读取 ImageFolder

    torchvision.datasets.ImageFolder(root, transform, target_transform, loader)
    
    • 1
    • root:图片存储的根目录,即各类别文件夹所在目录的上一级目录,在下面的例子中是 “…/input/data/”
    • transform:对图片进行预处理操作(函数),原始图片作为输入,返回一个转换后的图片。
    • target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
    • loader:表示数据集加载方式,通常默认加载方式即可

    另外,该 API 有以下成员变量:

    • self.classes:用一个 list 保存类别名称
    • self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
    • self.imgs:保存(img-path, class) tuple的 list,与我们自定义 Dataset类的 def getitem(self, index): 返回值类似。注意看下面实例中 dataset.imgs 的返回值

    举例:数据存储结构如下

    在这里插入图片描述

    import torchvision
    from torchvision import transforms, utils
    
    trans = transforms.Compose([transforms.RandomCrop(400), transforms.ToTensor()])
    dataset = torchvision.datasets.ImageFolder('/Users/manmi/Desktop/data/data', transform=trans)
    
    print(dataset.classes)   # ['bird', 'cat', 'dog']
    print(dataset.class_to_idx)   # {'bird': 0, 'cat': 1, 'dog': 2}
    print(dataset.imgs)   # [('/Users/manmi/Desktop/data/data/bird/bird1.jpeg', 0), ('/Users/manmi/Desktop/data/data/bird/bird2.jpeg', 0), ...]
    
    print(len(dataset))   # 11
    print(dataset[0][0].size())   # torch.Size([3, 400, 400])
    print(dataset[0][1])   # 0
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    2、torchvision.models

    后补…


    3、torchvision.transforms

    3.1 对PIL Image的常见操作

    1)转换为 tensor ToTensor()

    ToTensor() 做了三件事:

    • 把灰度范围从0-255 变换到 0-1之间,其将每一个像素值归一化到 [0,1],其归一化方法比较简单,直接除以255即可
    • 将 nump.ndarray 或 PIL.Image 转为 tensor,数据类型为 torch.FloatTensor
    • 将shape 由 (H,W, C) 转为shape为 (C, H, W)

    2)中心裁剪 CenterCrop()

    torchvision.transforms.CenterCrop(size)   # 所需裁剪的图片尺寸
    
    • 1
    from PIL import Image
    import matplotlib.pyplot as plt
    import torchvision.transforms as transforms
    
    img_src = Image.open('./bird.jpg')
    
    img_1 = transforms.CenterCrop(200)(img_src)
    img_2 = transforms.CenterCrop((200, 200))(img_src)
    img_3 = transforms.CenterCrop((300, 200))(img_src)
    img_4 = transforms.CenterCrop((500, 500))(img_src)
    
    plt.subplot(231)
    plt.imshow(img_src)
    plt.subplot(232)
    plt.imshow(img_1)
    plt.subplot(233)
    plt.imshow(img_2)
    plt.subplot(234)
    plt.imshow(img_3)
    plt.subplot(235)
    plt.imshow(img_4)
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    在这里插入图片描述

    以上例子我们可知:
    (1)如果切正方形,transforms.CenterCrop(100) 和 transforms.CenterCrop((100, 100)),两种写size的方法,效果一样
    (2)如果设置的输出的图片尺寸大于原尺寸,会在边上补黑色


    3)随机裁剪 RandomCrop()

    # 依据给定的size随机裁剪
    torchvision.transforms.RandomCrop(size, 
                          padding = None, 
                          pad_if_needed = False, 
                          fill=0, 
                          padding_mode ='constant')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    功能:
    从图片中随机裁剪出尺寸为 size 的图片,如果有 padding,那么先进行 padding,再随机裁剪 size 大小的图片。

    参数:

    • size :所需裁剪的图片尺寸
    • padding: 设置填充大小
      • 当为 a 时,上下左右均填充 a 个像素
      • 当为 (a, b) 时,左右填充 a 个像素,上下填充 b 个像素
      • 当为 (a, b, c, d) 时,左上右下分别填充 a,b,c,d
    • pad_if_needed:当图片小于设置的 size,是否填
    • padding_mode
      • constant: 像素值由 fill 设定 (默认)
      • edge: 像素值由图像边缘像素设定
      • reflect: 镜像填充,最后一个像素不镜像。([1,2,3,4] -> [3,2,1,2,3,4,3,2])
      • symmetric: 镜像填充,最后一个像素也镜像。([1,2,3,4] -> [2,1,1,2,3,4,4,4,3])
    • fill:当 padding_mode 为 constant 时,设置填充的像素值 (默认为0)

    4)其他更多图像变换操作

    其他更多的图像变换操作,看这里吧


    3.2 对 Tensor 的常见操作

    1)归一化 Normalize()

    作用: 用均值和标准差对张量图像进行归一化,
    公式: i m a g e = ( i m a g e − m e a n ) / s t d image = (image-mean) / std image=(imagemean)/std

    比如,原像素值的取值区间为 [0, 1],在使用 transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]) 进行归一化后,原像素值被分布到了 [-1, 1] 区间:

    • 原来的 0~1 最小值 0 则变成 (0 - 0.5) / 0.5 = -1
    • 最大值1则变成 (1 - 0.5) / 0.5 = 1

    其中 mean 和 std 的3个值分表表示图像的3个通道
    如果是单通道的灰度图,可以写成 transforms.Normalize(mean=[0.5], std=[0.5])

    我们可能会看到很多代码里面是这样的:
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    这一组值是怎么来的呢?答案就是通过数据集,提前抽样计算出来的


    2)转换为图像 ToPILImage()

    将 Tensor类型数据转换为图片数据 PILImage, torchvision.transforms.ToPILImage() 函数的作用是把Tensor数据变为完整原始的图片数据(保存后可以直接双击打开的那种)
    其内部处理过程为:

    • 将Tensor的每个元素乘以255
    • 将数据由Tensor转化成Uint8
    • 将Tensor转化成numpy的ndarray类型
    • 对ndarray对象做permute (1, 2, 0)的转置,将shape 由 (C, H, W) 转为shape为(H,W, C)
    • 将ndarray对象转化成PILImage数据格式
    • 输出该PILImage数据(save后可以直接打开)

    4、torchvision.utils

    4.1 图像拼接 grid

    一行最多展示8张图片

    import torch
    import torchvision
    from torchvision import transforms, utils
    from torch.utils import data
    import matplotlib.pyplot as plt
    
    trans = transforms.Compose([transforms.RandomCrop(400), transforms.ToTensor()])
    dataset = torchvision.datasets.ImageFolder('./data', transform=trans)
    train_loader = data.DataLoader(dataset, batch_size=2, shuffle=True)
    
    for (img, label) in train_loader:
        grid = utils.make_grid(img)
        plt.imshow(grid.numpy().transpose((1, 2, 0)))
        plt.show()
        break
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    在这里插入图片描述


    4.2 tensor存储为图片 save_img

    torchvision.utils.save_img(img, path)
    
    • 1

    image 的数据类型是tensor

  • 相关阅读:
    git 相关命令
    【单片机毕业设计】【mcuclub-hj-003】基于单片机的温湿度控制的设计
    13. Python数据类型之布尔类型
    提升你的Android开发技能:从AR/VR沉浸到UI设计和故障排除
    【计算机网络】UDP/TCP 协议
    【40】理解内存(上):虚拟内存和内存保护是什么?★★★★★
    角色授权 CSP 202206-3
    golang 实现四层负载均衡
    Java—多态
    开源模型 Zephyr-7B 发布——跨越三大洲的合作
  • 原文地址:https://blog.csdn.net/weixin_37804469/article/details/126348266