• torch_vision(一):数据增强和转换模块torchvision.transforms


    torchvision.transforms 学习笔记

    1. torchvision介绍

    The torchvision package consists of popular datasets, model architectures, and common image transformations for computer vision.

    torchvision包含了很多通用的数据集,模型架构,以及图像转换方法,配合pytorch使用更好搭建训练模型。

    2. TRANSFORMING AND AUGMENTING IMAGES

    图像转换和数据增强方法介绍

    1. 它们可以使用Compose链接在一起。
    2. 大多数转换类都有一个等效的函数:函数转换提供对转换的细粒度控制。
    3. 大多数变换同时接受PIL图像和张量图像,尽管有些变换只接受PIL图像,有些则只接受张量图像。
    4. 可以通过transform模块用于tensor与PIL图像之间的转换。

    3. Resize transform

    为了方便展示,首先定义一个画图函数

    from PIL import Image
    from pathlib import Path
    import matplotlib.pyplot as plt
    import numpy as np
    
    import torch
    import torchvision.transforms as T
    
    
    plt.rcParams["savefig.bbox"] = 'tight'
    orig_img = Image.open(Path('assets') / '24colormap.jpg')
    print(np.array(orig_img).shape)
    # if you change the seed, make sure that the randomly-applied transforms
    # properly show that the image can be both transformed and *not* transformed!
    torch.manual_seed(0)
    
    def plot(imgs):
        num = len(imgs)
        if num > 2:
            num_rows = 2
        else:
            num_rows = 1
    
        
        num_cols = (num + 1) // num_rows
        # print('row, col:', num_rows, num_cols)
        fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
        i = 0
        for row_idx in range(num_rows):
            for col_idx in range(num_cols):
                ax = axs[row_idx, col_idx]
                ax.imshow(np.asarray(imgs[i]))
                ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
                print(i, np.asarray(imgs[i]).shape)
                i += 1
                if i == num:
                    break
                
        plt.tight_layout()
    
    plt.imshow(orig_img)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41

    原图如下:

    请添加图片描述

    resized_imgs = [T.Resize(size=size)(orig_img) for size in ((30,30), 50, 100, orig_img.size[::-1])]
    plot(resized_imgs)
    
    • 1
    • 2

    结果展示
    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-HA5YPtwj-1666165798131)(20221019144439.png)]

    4. CenterCrop transform 中心裁剪

    center_crops = [T.CenterCrop(size=size)(orig_img) for size in (30, 50, 200, orig_img.size[::-1])]
    plot(center_crops)
    
    • 1
    • 2

    结果展示

    5. RandomRotation 随机旋转

    rotater = T.RandomRotation(degrees=(0, 180))
    rotated_imgs = [rotater(orig_img) for _ in range(4)]
    plot(rotated_imgs)
    
    • 1
    • 2
    • 3

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ldQXW6Xs-1666165747915)(20221019144813.png)] 在这里插入图片描述

    6.RandomAffine 随机仿射变换

    affine_transfomer = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
    affine_imgs = [affine_transfomer(orig_img) for _ in range(4)]
    plot(affine_imgs)
    
    • 1
    • 2
    • 3

    结果展示:
    在这里插入图片描述

    7.RandomPerspective 随机透视变换

    perspective_transformer = T.RandomPerspective(distortion_scale=0.6, p=1.0)
    perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)]
    plot(perspective_imgs)
    
    • 1
    • 2
    • 3

    结果展示:
    在这里插入图片描述

    8.RandomCrop 随机crop固定尺寸

    cropper = T.RandomCrop(size=(128, 128))
    crops = [cropper(orig_img) for _ in range(4)]
    plot(crops)
    
    • 1
    • 2
    • 3

    结果展示:

    在这里插入图片描述

    9.RandomResizedCrop 随机crop之后,再 resize到固定尺寸

    resize_cropper = T.RandomResizedCrop(size=(32, 32))
    resized_crops = [resize_cropper(orig_img) for _ in range(4)]
    plot(resized_crops)
    
    • 1
    • 2
    • 3

    结果展示:
    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bjxzpW8h-1666165747916)(20221019145427.png)]

    10.RandomPerspective 随机透视变换

    perspective_transformer = T.RandomPerspective(distortion_scale=0.6, p=1.0)
    perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)]
    plot(perspective_imgs)
    
    • 1
    • 2
    • 3

    结果展示:
    在这里插入图片描述

    11. 其他常见操作

    数据变换

    1. 线性变换
    torchvision.transforms.LinearTransformation(transformation_matrix,mean_vector)
    
    • 1
    1. 标准化:减去均值,除以标准差
    torchvision.transforms.Normalize(mean,std,inplace=False)
    
    • 1

    格式转换

    1. 最常用的就是 pil image 或者 np.ndarray 转换为tensor
    torchvision.transforms.ToTensor
    
    • 1

    图像翻转

    1. 随机水平和随机垂直翻转
    torchvision.transforms.RandomHorizontalFlip(p=0.5)
    torchvision.transforms.RandomVerticalFlip(p=0.5)
    
    • 1
    • 2

    更多数据增强方法,请参看
    [1]https://pytorch.org/vision/0.13/transforms.html
    [2]https://zhuanlan.zhihu.com/p/519919904

    12. transforms.Compose() 和 torch.nn.Sequential()

    transforms.Compose() 用于整合一系列的图像变换函数,将图片按照 Compose() 中的顺序依次处理。torch.nn.Sequential() 与 transforms.Compose() 起到相同的功能。torch.nn.Sequential() 可以和 torch.jit.script() 结合来导出模型。

    #Compose
    transform1 = transforms.Compose([
    	transforms.CenterCrop(10),
    	transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    
    #Sequential
    transform2 = torch.nn.Sequential(
    	transforms.CenterCrop(10),
    	transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    )
    scripted_transforms = torch.jit.script(transforms)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    关于torchscript 可以参看文章官网, 官网jit

    13. 函数转换

    除了以上的转换方法,一般都有对应的函数进行数据增强
    比如:

    import torchvision.transforms.functional as TF
    
    TF.adjust_brightness(orig_img, 0.2)
    TF.adjust_contrast(orig_img, 0.6)
    TF.adjust_hue(orig_img, -0.4)
    TF.adjust_saturation(orig_img, 0)
    TF.adjust_sharpness(orig_img, 2)
    TF.affine(orig_img, angle=0,translate=[150,150],scale=1, shear=0)
    TF.crop(orig_img, 300, 300, 500, 600)
    TF.erase(orig_img, 100, 200, 800, 600,0)
    TF.gaussian_blur(orig_img, 21, 5)
    TF.resize(orig_img, [400,800])
    TF.rotate(orig_img, 60)
    TF.vflip(orig_img)
    TF.hflip(orig_img)
    TF.crop(orig_img, 300, 300, 500, 600)
    TF.erase(orig_img, 100, 200, 800, 600,0)
    TF.gaussian_blur(orig_img, 21, 5)
    TF.resize(orig_img, [400,800])
    TF.rotate(orig_img, 60)
    TF.vflip(orig_img)
    TF.hflip(orig_img)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    14. 利用函数定义转换

    转换函数,可以处理多张图像

    import torchvision.transforms.functional as TF
    import random
    
    def my_segmentation_transforms(image, segmentation):
        if random.random() > 0.5:
            angle = random.randint(-30, 30)
            image = TF.rotate(image, angle)
            segmentation = TF.rotate(segmentation, angle)
        # more transforms ...
        return image, segmentation
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    转换类

    import torchvision.transforms.functional as TF
    import random
    
    class MyRotationTransform:
        """Rotate by one of the given angles."""
    
        def __init__(self, angles):
            self.angles = angles
    
        def __call__(self, x):
            angle = random.choice(self.angles)
            return TF.rotate(x, angle)
    
    rotation_transform = MyRotationTransform(angles=[-30, -15, 0, 15, 30])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    15. 参考

    https://pytorch.org/vision/0.13/transforms.html#functional-transforms

  • 相关阅读:
    基于多进制准循环稀疏校验矩阵构造方法的LDPC编译码实现
    腾讯云对象存储cors错误处理
    CentOS7 rabbitmq3.8 与 erlang22. 安装、干净卸载
    pycharm 中package, directory, sources root, resources root的区别
    算法与数据结构- 顺序表的实现
    老板叫我把几十万条Excel数据录入系统
    137. 只出现一次的数字 II
    spring中那些让你爱不释手的代码技巧(续集)
    看完这篇,你的API服务设计能力将再次进化!
    几分钟就搞定网站速度慢、网站卡等问题
  • 原文地址:https://blog.csdn.net/tywwwww/article/details/127409638