• conv2d详解--在数组和图像中的使用


    1、环境要求

    1、需要安装Pytorch依赖
    2、官方文档conv2d
    3、图片需要CIFAR10数据集

    2、原理讲解

    将原始二维数据,通过卷积核进行运算,得到运算结果,具体运算步骤:
    在这里插入图片描述
    在这里插入图片描述
    通过卷积核,覆盖输入数据,将选中的数据进行相乘后再相加,则得到输出数据
    在这里插入图片描述
    反复计算到最后,得到输出结果

    这里只是在卷积核安全覆盖在原始图像上时才进行计算,但也可以继续向四周移动,不是完全覆盖,只要有覆盖即可计算,多出的地方补0即可;
    这里也是左右上下移动都是一格一格移动,也可以每次移动两格;
    上面说的两种情况,是conv2d中的padding参数和stride参数不是默认值的情况

    3、函数要求

    函数原型:
    在这里插入图片描述
    参数要求:
    在这里插入图片描述
    最新官网上面要求输入数据为int就行了,这是针对图片数据,在数组数据中,需要tensor数据类型,详细区别见如下例子

    1. 输入要求是tensor数据类型,并且需要minibatch和输入通道,原始的二维数组没有,需要用reshape进行变换
    2. 卷积核也是相同的要求

    3、例子使用

    3.1、数组

    代码:

    
    import torch
    import torchvision
    import torch.nn.functional as F
    
    
    # 输入数据
    input = torch.tensor([[1,2,0,3,1],
                          [0,1,2,3,1],
                          [1,2,1,0,0],
                          [5,2,3,1,1],
                          [2,1,0,1,1]])
    
    print("原始input shape",input.shape)   # torch.Size([5, 5])
    input = torch.reshape(input,(1,1,5,5))    # 进行格式转换,添加前面两个参数,batchsize=1,channel=1,数据是5*5     torch.Size([1, 1, 5, 5])
    print("torch.shape后的shape",input.shape)
    
    # 卷积核
    kernel = torch.tensor([[1,2,1],
                           [0,1,0],
                           [2,1,0]])
    
    kernel = torch.reshape(kernel,(1,1,3,3))
    
    # 默认卷积使用,padding=0,stride=1
    output1 = F.conv2d(input,kernel)
    print("默认卷积",output1)
    
    # padding = 1,stride = 1
    output2 = F.conv2d(input,kernel,padding = 1,stride = 1)
    print("padding = 1,stride = 1",output2)
    
    # padding =1,stride = 2
    output3 = F.conv2d(input,kernel,padding =1,stride = 2)
    print("padding =1,stride = 2",output3)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35

    输出:
    在这里插入图片描述

    3.2、图片

    代码:

    import torch
    from torch import nn
    from torch.nn import Conv2d
    import torchvision
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    
    
    # download=False,我这里的数据集已经下载好了,就不用每次运行的时候都下载一次,可以在第一次的时候,改为True进行下载
    # "./datasetvision"为存放路径
    # transform=torchvision.transforms.ToTensor() 将图片数据使用torchvision进行格式转换
    dataset = torchvision.datasets.CIFAR10("./datasetvision",train=False,
                                           transform=torchvision.transforms.ToTensor(),download=False)
    
    # 数据预处理,batch_size=64表明每次获取的数据个数为64张
    dataloader = DataLoader(dataset,batch_size=64)
    
    
    # 简单神经网络定义
    class ConNet(nn.Module):
        def __init__(self):
            super(ConNet, self).__init__()
            # 输入通道 因为是彩色图像RGB 所以输入通道是3层,输出6层,卷积层是3*3
            self.conv2d = Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)
    
        #定义具体函数体
        def forward(self,x):
            result = self.conv2d(x)
            return result
    
    
    Work = ConNet()
    print(Work)  # 打印一下神经网络结构:  ConNet((conv2d): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1)))
    
    # 使用tensorboard进行文件夹命名
    write = SummaryWriter("logsConv2d")
    
    # data是dataloader中的元组
    step = 0
    for data in dataloader:
        imgs,target = data
    
        # print("原始图像",imgs.shape)  #前后差异
        # print(output.shape)
        write.add_images("input",imgs,step)  # 将初始图像放入tensorboard进行对比
    
        output = Work(imgs)  # 进行图像卷积
        output = torch.reshape(output,(-1,3,30,30))   # 这里因为卷积的时候,将输出通道定义为6个通道,board不知道如何展示,所以使用reshape进行转换
        write.add_images("output",output,step)
        step = step+1
    
    write.close()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52

    结果:
    在这里插入图片描述
    在这里插入图片描述

  • 相关阅读:
    git常用命令
    Servlet —— Tomcat, 初学 Servlet 程序
    【实战】Kubernetes安装持久化工具NFS-StorageClass
    类和对象(末)
    chapter 11 in C primer plus
    MySQL之事务
    yolov5+bytetrack算法在华为NPU上进行端到端开发
    Java面试题:Java中垃圾回收机制是如何工作的?请描述几种常见的垃圾回收算法
    【DaVinci Developer工具实战】02 - 软件设计编辑器
    CSS中如何在table中隐藏表格中从第4个开始的多个 <tr> 元素
  • 原文地址:https://blog.csdn.net/qq_44864833/article/details/125510091