• Pytorch框架学习记录8——最大池化的使用


    Pytorch框架学习记录8——最大池化的使用

    1. MaxPool2d介绍

    torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

    在由多个输入平面组成的输入信号上应用 2D 最大池化。

    参数

    • kernel_size – 最大的窗口大小
    • stride——窗口的步幅。默认值为kernel_size
    • padding – 要在两边添加隐式零填充
    • dilation – 控制窗口中元素步幅的参数
    • return_indices - 如果True,将返回最大索引以及输出。torch.nn.MaxUnpool2d以后有用
    • ceil_mode – 当为 True 时,将使用ceil而不是floor来计算输出形状

    ceil_mode=True 时,如果滑动窗口在左侧填充或输入内开始,则允许滑动窗口越界。将在右侧填充区域开始的滑动窗口将被忽略。

    2. 举例

    import torch
    from torch import nn
    
    input = torch.tensor([[1, 2, 0, 3, 1],
                          [0, 1, 2, 3, 1],
                          [1, 2, 1, 0, 0],
                          [5, 2, 3, 1, 1],
                          [2, 1, 0, 1, 1]], dtype=torch.float32)
    
    input = torch.reshape(input, (-1, 1, 5, 5))
    
    
    class Test(nn.Module):
        def __init__(self):
            super(Test, self).__init__()
            self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=1, ceil_mode=True)
    
        def forward(self, input):
            output = self.maxpool1(input)
            return output
    
    
    test = Test()
    output = test(input)
    print(output)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25

    从图像上直观展示maxpooling的效果:

    import torch
    from torch import nn
    import torchvision
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    
    dataset = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(), download=True)
    dataloader = DataLoader(dataset, batch_size=64)
    
    class Test(nn.Module):
        def __init__(self):
            super(Test, self).__init__()
            self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=1, ceil_mode=True)
    
        def forward(self, input):
            output = self.maxpool1(input)
            return output
    
    
    test = Test()
    writer = SummaryWriter('logs')
    step = 0
    for data in dataloader:
        imgs, target = data
        output = test(imgs)
        writer.add_images("maxpool", output, global_step=step)
        step += 1
    
    writer.close()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29

    在这里插入图片描述

  • 相关阅读:
    cmake 学习 cmake-package(7)
    零售业变革下,数智化供应链系统精细化库存管理,构建企业数字化供应链体系
    基于蚁群优化算法的直流电机模糊PID控制(Matlab实现)
    React useContext
    FastGPT知识库结构讲解
    可变参数函数原理
    线程安全与共享资源
    视频去水印怎么去?3个简单的去水印方法分享
    Selenium4+Python3系列(八) - Cookie、截图、单选框及复选框处理、富文本框、日历控件操作
    极简工作流「GitHub 热点速览」
  • 原文地址:https://blog.csdn.net/qq_45955883/article/details/126068385