• SENet网络模型


    前言

    论文地址:https://arxiv.org/abs/1709.01507
    论文代码:https://github.com/hujie-frank/SENet

    Squeeze-and-Excitation Networks(SENet)是 一款轻量级的网络,容易扩展到其他结构中去。该网络取得最后一届 ImageNet 2017 竞赛 Image Classification 任务的冠军,在ImageNet数据集上将top-5 error降低到2.251%,原先的最好成绩是2.991%。比2016年的第一名还要低25%。

    SENet的创新点在于关注channel之间的关系自动学习不同channel特征的重要程度

    ✳✳✳SE模块主要为了提升模型对channel特征的敏感性,这个模块是轻量级的,而且可以应用在现有的网络结构中,只需要增加较少的计算量就可以带来性能的提升。

    网络结构

    Squeeze-and-Excitation(SENet)
    在这里插入图片描述
    左边为 C ′ × H ′ × W ′ C'×H'×W' C×H×W 的特征图,经过一系列卷积,pooling 操作 Ftr 之后,得到 C × H × W C×H×W C×H×W 大小的特征图。接下来进行一个 Sequeeze and Excitation block。

    Sequeeze:对 C × H × W C×H×W C×H×W 进行 global average pooling,得到 1 × 1 × C 1×1×C 1×1×C大小的特征图,这个特征图可以理解为具有全局感受野。

    Excitation :使用一个全连接神经网络,对 Sequeeze 之后的结果做一个非线性变换。

    特征重标定:使用 Excitation 得到的结果作为权重,乘到输入特征上。

    代码详解

    在这里插入图片描述

    class SELayer(nn.Module):
        def __init__(self, channel, reduction=16):
            super(SELayer, self).__init__()
            # nn.AdaptiveAvgPool2d就是自适应平均池化,指定输出(H,W)
            self.avg_pool = nn.AdaptiveAvgPool2d(1)#或者写成nn.AdaptiveAvgPool2d((1,1))
            self.fc = nn.Sequential(
                # channel // reduction:减少计算量
                nn.Linear(channel, channel // reduction, bias=False),
                nn.ReLU(inplace=True),
                # 变成原来的通道数
                nn.Linear(channel // reduction, channel, bias=False),
                # 将结果值映射到[0,1]的区间
                nn.Sigmoid()
            )
    
        def forward(self, x):
            # (batch)B,(channel)C,H,W
            b, c, _, _ = x.size()
            # 将H,W拼接池化,[B,C,H,W]=>(avg_pool)=>[B,C,1,1]=>view(b, c)=>[B,C]
            y = self.avg_pool(x).view(b, c)
            #print(y.shape)
            # [B,C]=>self.fc(y)=>[B,C/2]=>
            y = self.fc(y).view(b, c, 1, 1)
            # print(y.shape)
            # y.expand_as:[B,C,1,1]==>[B,C,H,W]
            # [B,C,H,W]*[B,C,H,W]
            # 将计算所得权重与原先张量相乘
            return x * y.expand_as(x)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    net=SELayer(128)
    print(net(torch.randn(128,128,256,256)).shape)
    
    • 1
    • 2

    对于SE-ResNet模型,只需要将SE模块加入到残差单元(应用在残差学习那一部分)就可以

    class SEBottleneck(nn.Module):
            expansion = 4
    
            def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
                super(SEBottleneck, self).__init__()
                self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
                self.bn1 = nn.BatchNorm2d(planes)
                self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                       padding=1, bias=False)
                self.bn2 = nn.BatchNorm2d(planes)
                self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
                self.bn3 = nn.BatchNorm2d(planes * 4)
                self.relu = nn.ReLU(inplace=True)
                self.se = SELayer(planes * 4, reduction)
                self.downsample = downsample
                self.stride = stride
    
            def forward(self, x):
                residual = x
    
                out = self.conv1(x)
                out = self.bn1(out)
                out = self.relu(out)
    
                out = self.conv2(out)
                out = self.bn2(out)
                out = self.relu(out)
    
                out = self.conv3(out)
                out = self.bn3(out)
                out = self.se(out)
    
                if self.downsample is not None:
                    residual = self.downsample(x)
    
                out += residual
                out = self.relu(out)
    
                return out
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39

    参考链接
    后ResNet时代:SENet与SKNet
    最后一届ImageNet冠军模型:SENet

  • 相关阅读:
    中英文说明书丨艾美捷细胞失巢凋亡检测试剂盒介绍
    简单理解三极管导通条件(从电压角度考虑)
    知觉的定义
    网络建设 之 React数据管理
    【牛客 - 剑指offer】JZ3 数组中重复的数字 两种思路 Java实现
    2021.03青少年软件编程(Python)等级考试试卷(三级)
    在windows 上安装 openSSH
    spring循环依赖-不仅仅是八股文
    〔025〕Stable Diffusion 之 接口开发 篇
    TCP实战:即时通信-端口转发
  • 原文地址:https://blog.csdn.net/weixin_42888638/article/details/126714558