• 经典网络解析(四) ResNet | 残差模块,网络结构代码实现全解析


    1 设计初衷

    我们之前讲了VGG等网络,在之前网络的研究中,研究者感觉

    网络越深,分类准确率越高,但是随着网络的加深,科学家们发现分类准确率反而会下降,无论是在训练集上还是测试集上。

    ResNet的作者团队发现了这种现象的真正原因是:

    训练过程中网络的正、反向的信息流动不顺畅,网络没有被充分训练,他们称之为“退化”

    2.网络结构

    2.1 残差块

    解决方式:

    ​ 构建了残差模块,通过堆叠残差模块,可以构建任意深度的神经网络,而不会出现退化的现象

    ​ 提出了批归一化对抗梯度消失,该方法降低了网络训练过程中对于权重初始化的依赖

    H ( x ) = F ( x ) + x H(x)=F(x)+x H(x)=F(x)+x

    我们网络要学习的是F(x)

    F ( x ) = H ( x ) − x F(x)=H(x)-x F(x)=H(x)x

    F(x)实际上就是输出与输入的差异,所以叫做残差模块

    在这里插入图片描述

    2.2 中间的卷积网络特征提取块

    中间的块有两种可能

    1 两层3×3卷积层

    在这里插入图片描述

    2 先1×1卷积层,再3×3卷积层,再3×3卷积层

    第一个用了1×1把卷积通道降下去(减少运算量),第二个用了1×1把卷积通道再升上去(便于和输入x连接)

    在这里插入图片描述

    2.3 结构总览表格

    表中展示了18层,34层,50层,101层,152层的ResNet的结构

    [ ]方括号中即是我们上面讲的两个特征提取块,×几代表堆叠几个

    最后经过一个全局平均池化,一个全连接层

    在这里插入图片描述

    3 为什么残差模块有效?

    3.1 前向传播

    1 前向传播过程中重要信息不消失

    通过残差网络的设计,我们可以理解为原来的有机会信息可以维持不变,对分类有帮助的信息得到加强

    能够避免卷积层堆叠存在的信息丢失

    3.2 反向传播

    2 反向传播中梯度可以控制不消失

    在典型的残差模块中,输入数据被分为两部分,一部分通过一个或多个神经层进行变换,而另一部分直接传递给输出。这个直接传递的部分是输入数据的恒等映射,即没有变换。这意味着至少一部分的信息在经过神经网络之后保持不变。

    当进行反向传播以更新神经网络参数时,梯度是根据损失函数计算的。在传统的深度神经网络中,由于多层的网络梯度相乘,梯度可以逐渐变小并导致梯度消失。但在残差模块中,由于存在恒等映射,至少一部分梯度可以直接通过跳过变换的路径传播,而不会受到变换的影响。

    H ( x ) = F ( x ) + x H(x)=F(x)+x H(x)=F(x)+x

    比如这个式子对x求偏导 ∂ F / ∂ x + 1 ∂F/∂x+1 F/x+1

    这时候保证了梯度至少会加1 不让梯度连乘逐渐变小

    3.3 恒等映射

    3,可以理解为当网络变深之后,非线性变得很强,网络很难学会简单的恒等映射,残差模块可以解决这个问题

    3.4 集成模型

    4 残差网络可以看做一种集成模型

    可以看做很多简单或复杂的子网络的组合求和!!!

    在这里插入图片描述

    但是这样可能会造成冗余,因为其中还可能会有很多不需要的信息,这便是后来的DenseNet,会让速度提升

    4.代码实现

    实现的是

    import torch
    from torch import nn
    from torch.nn import functional as F
    
    class Residual_block(nn.Module):
        def __init__(self,input_channels,output_channels,first=False):
            super().__init__()
            self.first=first
            if first==True:
                self.conv1=nn.Conv2d(input_channels,output_channels,stride=2,kernel_size=3,padding=1)
                self.conv3=nn.Conv2d(input_channels,output_channels,kernel_size=1,stride=2)
            else:
                self.conv1=nn.Conv2d(output_channels,output_channels,kernel_size=3,padding=1)
            self.bn1=nn.BatchNorm2d(output_channels)
            self.conv2=nn.Conv2d(output_channels,output_channels,kernel_size=3,padding=1)
            self.bn2=nn.BatchNorm2d(output_channels)
            
        def forward(self,x):
            Y=F.relu(self.bn1(self.conv1(x)))
            Y=self.bn2(self.conv2(Y))
            if self.first==True:
                x=self.conv3(x)
            Y=x+Y
            return F.relu(Y)
    def resnet_block(input_channels,output_channels,num_residual_block,special=False):
        blk=[]
        for i in range(num_residual_block):
            if i==0 and special==True:
                blk.append(Residual_block(input_channels,input_channels))
            if i==0 and special==False:
                blk.append(Residual_block(input_channels,output_channels,first=True))
            else:
                blk.append(Residual_block(output_channels,output_channels))
        return blk
    
    
    b1=nn.Sequential(
        nn.Conv2d(kernel_size=7,in_channels=3,out_channels=64,stride=2,padding=3),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3,stride=2,padding=1),
    )
    R1=Residual_block(64,64)
    x=torch.ones(1,3,224,224)
    for layer in b1:
        x=layer(x)
    
    b18_2=nn.Sequential(*resnet_block(64,64,2,special=True))
    b18_3=nn.Sequential(*resnet_block(64,128,2))
    b18_4=nn.Sequential(*resnet_block(128,256,2))
    b18_5=nn.Sequential(*resnet_block(256,512,2))
    
    b34_2=nn.Sequential(*resnet_block(64,64,3,special=True))
    b34_3=nn.Sequential(*resnet_block(64,128,4))
    b34_4=nn.Sequential(*resnet_block(128,256,6))
    b34_5=nn.Sequential(*resnet_block(256,512,3))
    
    
    #Resnet-18
    Resnet_18=nn.Sequential(b1,b18_2,b18_3,b18_4,b18_5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(),nn.Linear(512,10))
    #Resnet-34
    Resnet_18=nn.Sequential(b1,b34_2,b34_3,b34_4,b34_5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(),nn.Linear(512,10))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
  • 相关阅读:
    深度学习实战07-卷积神经网络(Xception)实现动物识别
    Github标星35K+超火的Spring Boot实战项目,附超全教程文档
    31、CSS进阶——@规则
    Launcher3介绍
    【华为机试真题 JAVA】服务器广播-200
    【MySQL高级篇】一文带你吃透数据库的约束
    <Linux开发>驱动开发 -之-基于pinctrl/gpio子系统的beep驱动
    手机联系人恢复:3个方法的选择和比较
    思考如何完成一个审批流
    华为海思校园招聘-芯片-数字 IC 方向 题目分享——第四套
  • 原文地址:https://blog.csdn.net/Q52099999/article/details/133281650