• torchvision.transforms 数据预处理:Normalize()


    1、Normalize() 的作用

    Normalize() 是pytorch中的数据预处理函数,包含在 torchvision.transforms 模块下。一般用于处理图像数据,其输入数据格式是 torch.Tensor,而不是 np.array。

    1.1 Normalize() 的源码

    看一下 Normalize() 函数的源码:

    class Normalize(torch.nn.Module):
        """Normalize a tensor image with mean and standard deviation.
        This transform does not support PIL Image.
        Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
        channels, this transform will normalize each channel of the input
        ``torch.*Tensor`` i.e.,
        ``output[channel] = (input[channel] - mean[channel]) / std[channel]``
    
        .. note::
            This transform acts out of place, i.e., it does not mutate the input tensor.
    
        Args:
            mean (sequence): Sequence of means for each channel.
            std (sequence): Sequence of standard deviations for each channel.
            inplace(bool,optional): Bool to make this operation in-place.
    
        """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    大意是:使用均值和标准差对输入的tensor的每个通道进行标准化,计算公式是:

    output[channel] = (input[channel] - mean[channel]) / std[channel]
    
    • 1

    这里要与正态分布标准化进行区分,将一个正态分布转化为标准正太分布(即高斯分布)的公式为 Z=(X-mean)/var,这里的分母是方差而不是标准差。

    1.2 代码示例

    这里用代码来演示一下Normalize()的作用:

    import numpy as np
    from torchvision import transforms
    
    data = np.array([
        [0., 5, 10, 20, 0],
        [255, 125, 180, 255, 196]
    ])    # 因为 Normalize() 的输入必须是 float 类型,所以这里定义一个 np.float64类型的 array
    tensor = transforms.ToTensor()(data)
    norm = transforms.Normalize((0.5), (0.5))   # mean=0.5   std=0.5
    
    print(f"tensor = {tensor}")
    print(f"norm(tensor) = {norm(tensor)}")
    
    """
    tensor = tensor([[[  0.,   5.,  10.,  20.,   0.],
             [255., 125., 180., 255., 196.]]], dtype=torch.float64)
    norm(tensor) = tensor([[[ -1.,   9.,  19.,  39.,  -1.],
             [509., 249., 359., 509., 391.]]], dtype=torch.float64)
    """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    很容易可以验证:

    (0 - 0.5) / 0.5 = -1
    (5 - 0.5) / 0.5 = 9
    (255 - 0.5) / 0.5 = 509
    
    • 1
    • 2
    • 3

    2、ToTensor() 和 Normalize() 的结合使用

    在图像预处理中,Normalize() 通常和 ToTensor() 一起使用。 ToTensor() 的介绍可以参考 torchvision.transforms 数据预处理:ToTensor()

    首先 ToTensor() 将 [0,255] 的像素值归一化为 [0,1],然后使用 Normalize(0.5, 0.5) 将 [0,1] 进行标准化为 [-1,1]

    ToTensor() 和Normalize() 结合使用的代码示例:

    import numpy as np
    from torchvision import transforms
    
    data = np.array([
        [0, 5, 10, 20, 0],
        [255, 125, 180, 255, 196]
    ], dtype=np.uint8)
    tensor = transforms.ToTensor()(data)
    norm = transforms.Normalize(0.5, 0.5)
    
    print(f"tensor = {tensor}")
    print(f"norm(tensor) = {norm(tensor)}")
    
    """
    tensor = tensor([[[0.0000, 0.0196, 0.0392, 0.0784, 0.0000],
             [1.0000, 0.4902, 0.7059, 1.0000, 0.7686]]])
    norm(tensor) = tensor([[[-1.0000, -0.9608, -0.9216, -0.8431, -1.0000],
             [ 1.0000, -0.0196,  0.4118,  1.0000,  0.5373]]])
    """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    使用 transforms.Compose() 函数进行图像预处理:

    from torchvision import transforms
    import cv2
    
    filePath = "Dataset/FFHQ/00000.png"
    img = cv2.imread(filePath)
    
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    img = transform(img)
    print(img)
    
    """
    tensor([[[ 0.1451,  0.1294,  0.1059,  ...,  0.2157,  0.2000,  0.1843],
             [ 0.1529,  0.1137,  0.1294,  ...,  0.1843,  0.1843,  0.1922],
             [ 0.1216,  0.1137,  0.1529,  ...,  0.2314,  0.1686,  0.1529],
             ...,
             [-0.8118, -0.7961, -0.7725,  ...,  0.0980,  0.0824,  0.0588],
             [-0.8196, -0.8196, -0.8039,  ...,  0.0588,  0.0353,  0.0275],
             [-0.8667, -0.8510, -0.8275,  ...,  0.0431,  0.0431,  0.0510]]])
    """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
  • 相关阅读:
    【ARM Trace32(劳特巴赫) 使用介绍 5 -- Trace32 scan dump 详细介绍】
    Python练习题:实现除自身以外元素的乘积
    专利-分析方法总结
    DeferredResult解决了什么问题
    oracle19c单机应用补丁-缺少包导致失败
    Java 中将多个 PDF 文件合并为一个 PDF
    4.DesignForShapes\2.AutoRoutingAddShape
    Mybatis核心配置文件中的常用标签
    前端--CSS
    一个C++工程内存泄漏问题的排查及重现工程
  • 原文地址:https://blog.csdn.net/qq_43799400/article/details/127787393