• 深度学习(PyTorch)——flatten函数的用法及其与reshape函数的区别


    Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。Flatten不影响batch的大小。

    就是把高纬度的数组按照 x轴或者y轴 进行拉伸,变成一维的数组

    为了更好的理解Flatten层作用,我把这个神经网络进行可视化如下图:(来自网络)

    flatten(),默认缺省参数为0,也就是说flatten()和flatte(0)效果一样。

    python里的flatten(dim)表示,从第dim个维度开始展开,将后面的维度转化为一维.也就是说,只保留dim之前的维度,其他维度的数据全都挤在dim这一维。

    比如一个数据的维度是(S_{0},S_{1},S_{2},S_{3},...,S_{n}),flatten(m)后的数据为(S_{0},S_{1},S_{2},S_{3},...,S_{m-2},S_{m-1},S_{m},S_{m+1},S_{m+2},...,S_{n})

     案例程序如下:

    1. import torch
    2. import torchvision
    3. from torch import nn
    4. from torch.nn import Linear
    5. from torch.utils.data import DataLoader
    6. dataset = torchvision.datasets.CIFAR10("./data_CIFAR10", train=False,
    7. transform=torchvision.transforms.ToTensor(),download=True)
    8. dataloader = DataLoader(dataset,batch_size=64)
    9. class Tudui(nn.Module):
    10. def __init__(self):
    11. super(Tudui, self).__init__()
    12. self.linear1 = Linear(196608,10)
    13. def forward(self,input):
    14. output = self.linear1(input)
    15. return output
    16. tudui = Tudui()
    17. for data in dataloader:
    18. imgs, targets = data
    19. print(imgs.shape)
    20. # output = torch.reshape(imgs,(1,1,1,-1))
    21. output = torch.flatten(imgs)
    22. print(output.shape)
    23. output = tudui(output)
    24. print(output.shape)

    运行结果如下:

    从上图可以看出,torch_size([64,3,32,32])是print(imgs.shape)打印得到的结果,表示batch_size=64,channel=3,高H=32,宽W=32

    上面的结果通过flatten后得到的结果维度大小为torch_size([196608]),其中的196608=64*32*32*3得到的

    然后经过神经网络(Tudui)得到的结果维度大小是torch_size([10]),表示输出为10个类别。

    如果把flatten改为reshape会出现什么结果呢,程序如下:

    1. import torch
    2. import torchvision
    3. from torch import nn
    4. from torch.nn import Linear
    5. from torch.utils.data import DataLoader
    6. dataset = torchvision.datasets.CIFAR10("./data_CIFAR10", train=False,
    7. transform=torchvision.transforms.ToTensor(),download=True)
    8. dataloader = DataLoader(dataset,batch_size=64)
    9. class Tudui(nn.Module):
    10. def __init__(self):
    11. super(Tudui, self).__init__()
    12. self.linear1 = Linear(196608,10)
    13. def forward(self,input):
    14. output = self.linear1(input)
    15. return output
    16. tudui = Tudui()
    17. for data in dataloader:
    18. imgs, targets = data
    19. print(imgs.shape)
    20. output = torch.reshape(imgs,(1,1,1,-1))
    21. # output = torch.flatten(imgs)
    22. print(output.shape)
    23. output = tudui(output)
    24. print(output.shape)

    运行结果如下:

    我们发现,经过了reshape后,得到的结果尺寸维度是torch_size([1,1,1,196608]),其结果表示batch_size=1,channel=1,高H=1,宽W=196608

    上面结果通过了神经网络(Tudui)后得到结果尺寸维度为torch_size([1,1,1,10]),表示输出为10个类别。

  • 相关阅读:
    数据库使用psql及jdbc进行远程连接,不定时自动断开的解决办法
    企业网络安全面临哪些困境?可以怎样应对?
    腾讯云 AI 绘画:文生图、图生图、图审图 快速入门
    vue3:获取当前路由地址
    关于尚硅谷禹神Vue视频四十二级v-cloak,delay_server服务器服务器的替代方案
    Linux(进程间通信)
    做题日记 之 pairs(HDU-5178)
    LeetCode Cookbook 链表习题 上篇
    深入了解vue2向vue3变迁过渡的知识点
    java基于ssm的图书销售库存管理入库信息系统
  • 原文地址:https://blog.csdn.net/qq_42233059/article/details/126663501