Classification Networks 2: Reproducing the ResNet Model


    Contents

    ResNet Network Architecture

    ResNet Implementation Code


    ResNet Network Architecture

    Original paper: https://arxiv.org/pdf/1512.03385.pdf

    The Residual Network (ResNet) was proposed by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun of Microsoft Research. By introducing residual learning, it addresses the degradation problem encountered when training very deep networks. The architecture has three main characteristics: extremely deep structures (more than 1000 layers are possible), the residual module, and the use of Batch Normalization to accelerate training (dropout is not used). The overall model structure of ResNet is shown below:

    ResNet uses "skip connections" (shortcut connections): the output of an earlier layer bypasses several intermediate layers and is fed directly into the input of a later layer. If the layers in between turn out to be redundant, the network can simply route the signal through the shortcut and skip them. Writing the block output as H(x) = F(x) + x, the stacked layers only need to push the residual F(x) toward zero to realize the identity mapping H(x) = x, rather than having to fit an identity mapping from scratch. This connection scheme also preserves the integrity of the information (avoiding the loss that can occur when convolutional layers are simply stacked), and the network only has to learn the difference between input and output. It thereby alleviates the lower learning efficiency and stagnating accuracy that come with increasing depth (degradation, vanishing or exploding gradients).
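    As a minimal sketch of this idea (the class name and shapes here are only illustrative, not the blocks implemented later), a residual block computes H(x) = F(x) + x, so it only has to learn the residual branch F(x):

    import torch
    import torch.nn as nn

    class TinyResidualBlock(nn.Module):
        """Minimal residual block: output = ReLU(F(x) + x)."""
        def __init__(self, channels):
            super().__init__()
            # F(x): two 3x3 convs with BN; if F(x) learns to be ~0, the block acts as an identity
            self.f = nn.Sequential(
                nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(channels),
            )
        def forward(self, x):
            return torch.relu(self.f(x) + x)  # skip connection: add the input back in

    x = torch.randn(1, 64, 56, 56)
    print(TinyResidualBlock(64)(x).shape)  # torch.Size([1, 64, 56, 56])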

    The residual module is illustrated in the figure below:

    There are two kinds of residual structure: the basic (regular) residual block and the bottleneck residual block.

    Basic residual block: built from two stacked 3x3 convolutional layers. When the input and output dimensions match, the input can be added directly to the output, which amounts to an identity mapping and introduces no extra parameters and no additional computational complexity. (As the network gets deeper, this block becomes less cost-effective in practice.)

    Bottleneck residual block: built from a 1x1, a 3x3, and another 1x1 convolutional layer in sequence. The 1x1 convolutions reduce and then restore the channel dimension, so the 3x3 convolution operates on a relatively low-dimensional input, which improves computational efficiency (see the rough comparison below).
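    A back-of-the-envelope comparison (ignoring BN parameters and biases, and assuming a 256-channel input as in Figure 5 of the paper) shows why the bottleneck design is cheaper at high channel counts:

    # Two stacked 3x3 convs at 256 channels vs. a 1x1-3x3-1x1 bottleneck that
    # squeezes 256 channels down to 64 before the 3x3 and expands back to 256.
    regular    = 3*3*256*256 + 3*3*256*256            # ~1.18M weights
    bottleneck = 1*1*256*64 + 3*3*64*64 + 1*1*64*256  # ~0.07M weights
    print(regular, bottleneck)                        # 1179648 69632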

    The detailed configuration of the ResNet family is as follows:
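    For reference, the per-stage block counts (conv2_x through conv5_x) from Table 1 of the paper can be summarized as follows; these are the same values used by the factory functions at the end of the code section:

    # Number of residual blocks per stage; ResNet-18/34 use the basic block,
    # ResNet-50/101/152 use the bottleneck block.
    resnet_block_counts = {
        'resnet18':  [2, 2, 2, 2],
        'resnet34':  [3, 4, 6, 3],
        'resnet50':  [3, 4, 6, 3],
        'resnet101': [3, 4, 23, 3],
        'resnet152': [3, 8, 36, 3],
    }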

    When building the network, a stride-2 convolution first reduces the image resolution (downsampling), followed by a max-pooling layer and then the stacked residual stages. At the end of the network, a global average pooling layer aggregates the feature maps, followed by a 1000-way fully connected layer whose output is passed through softmax for multi-class classification. Thanks to the residual connections, the deepest variant reaches 152 layers, and the 50-, 101-, and 152-layer variants all use bottleneck residual blocks.
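    As a quick sanity check of that downsampling path (assuming the standard 224x224 ImageNet input), the feature-map sizes can be traced with the usual convolution output-size formula:

    # Spatial size after each downsampling step: floor((size + 2*pad - kernel) / stride) + 1
    size = 224
    size = (size + 2*3 - 7) // 2 + 1       # conv1: 7x7, stride 2, pad 3 -> 112
    size = (size + 2*1 - 3) // 2 + 1       # max pool: 3x3, stride 2, pad 1 -> 56
    for _ in range(3):                     # layer2, layer3, layer4 each halve the resolution
        size = (size + 2*1 - 3) // 2 + 1   # stride-2 3x3 conv -> 28, 14, 7
    print(size)                            # 7; global average pooling then reduces it to 1x1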

    ResNet Implementation Code

    As usual, straight to the code. I recommend reading it alongside the network architecture above.

    import torch
    import torch.nn as nn

    __all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']


    class ConvBNReLU(nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
            super(ConvBNReLU, self).__init__()
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=kernel_size, stride=stride, padding=padding)
            self.bn = nn.BatchNorm2d(num_features=out_channels)
            self.relu = nn.ReLU(inplace=True)

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            x = self.relu(x)
            return x


    class ConvBN(nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
            super(ConvBN, self).__init__()
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=kernel_size, stride=stride, padding=padding)
            self.bn = nn.BatchNorm2d(num_features=out_channels)

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return x


    def conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1):
        """3x3 convolution with padding: captures local features and spatial
        correlations, learning more complex and abstract representations."""
        return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                         padding=dilation, groups=groups, bias=False, dilation=dilation)


    def conv1x1(in_channels, out_channels, stride=1):
        """1x1 convolution: reduces or expands the channel dimension and applies a
        linear transformation across channels."""
        return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
    class BasicBlock(nn.Module):
        expansion = 1

        def __init__(self, in_channels, out_channels, stride=1, downsample=None):
            super(BasicBlock, self).__init__()
            self.convbnrelu1 = ConvBNReLU(in_channels, out_channels, kernel_size=3, stride=stride)
            self.convbn1 = ConvBN(out_channels, out_channels, kernel_size=3)
            self.relu = nn.ReLU(inplace=True)
            self.downsample = downsample
            self.stride = stride
            # 1x1 conv + BN shortcut, used when the spatial size or channel count changes
            self.conv_down = nn.Sequential(
                conv1x1(in_channels, out_channels * self.expansion, self.stride),
                nn.BatchNorm2d(out_channels * self.expansion),
            )

        def forward(self, x):
            residual = x
            out = self.convbnrelu1(x)
            out = self.convbn1(out)
            if self.downsample:
                residual = self.conv_down(x)
            out += residual
            out = self.relu(out)
            return out
    class Bottleneck(nn.Module):
        expansion = 4

        def __init__(self, in_channels, out_channels, stride=1, downsample=None):
            super(Bottleneck, self).__init__()
            groups = 1
            base_width = 64
            dilation = 1
            width = int(out_channels * (base_width / 64.)) * groups  # width == out_channels here
            self.conv1 = conv1x1(in_channels, width)  # 1x1 conv: reduce the channel dimension
            self.bn1 = nn.BatchNorm2d(width)
            self.conv2 = conv3x3(width, width, stride, groups, dilation)
            self.bn2 = nn.BatchNorm2d(width)
            self.conv3 = conv1x1(width, out_channels * self.expansion)  # 1x1 conv: expand the channel dimension
            self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
            self.relu = nn.ReLU(inplace=True)
            self.downsample = downsample
            self.stride = stride
            # 1x1 conv + BN shortcut, used when the spatial size or channel count changes
            self.conv_down = nn.Sequential(
                conv1x1(in_channels, out_channels * self.expansion, self.stride),
                nn.BatchNorm2d(out_channels * self.expansion),
            )

        def forward(self, x):
            residual = x
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
            out = self.conv3(out)
            out = self.bn3(out)
            if self.downsample:
                residual = self.conv_down(x)
            out += residual
            out = self.relu(out)
            return out
    class ResNet(nn.Module):
        def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                     groups=1, width_per_group=64):
            super(ResNet, self).__init__()
            self.inplanes = 64
            self.dilation = 1
            replace_stride_with_dilation = [False, False, False]
            self.groups = groups
            self.base_width = width_per_group
            self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
            self.bn1 = nn.BatchNorm2d(self.inplanes)
            self.relu = nn.ReLU(inplace=True)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.layer1 = self._make_layer(block, 64, layers[0])
            self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                           dilate=replace_stride_with_dilation[0])
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           dilate=replace_stride_with_dilation[1])
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                           dilate=replace_stride_with_dilation[2])
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(512 * block.expansion, num_classes)
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
            # Zero-initialize the last BN in each residual branch,
            # so that the residual branch starts with zeros, and each residual block behaves like an identity.
            # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
            if zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        nn.init.constant_(m.bn3.weight, 0)
                    elif isinstance(m, BasicBlock):
                        nn.init.constant_(m.bn2.weight, 0)

        def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
            downsample = False
            if dilate:
                self.dilation *= stride
                stride = 1
            if stride != 1 or self.inplanes != planes * block.expansion:
                downsample = True
            layers = nn.ModuleList()
            layers.append(block(self.inplanes, planes, stride, downsample))
            self.inplanes = planes * block.expansion
            for _ in range(1, blocks):
                layers.append(block(self.inplanes, planes))
            return layers

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
            for layer in self.layer1:
                x = layer(x)
            for layer in self.layer2:
                x = layer(x)
            for layer in self.layer3:
                x = layer(x)
            for layer in self.layer4:
                x = layer(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
            return x
    def resnet18(num_classes, **kwargs):
        return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, **kwargs)


    def resnet34(num_classes, **kwargs):
        return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, **kwargs)


    def resnet50(num_classes, **kwargs):
        return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, **kwargs)


    def resnet101(num_classes, **kwargs):
        return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, **kwargs)


    def resnet152(num_classes, **kwargs):
        return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, **kwargs)


    if __name__ == "__main__":
        import torchsummary
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        inputs = torch.ones(2, 3, 224, 224).to(device)
        net = resnet50(num_classes=4)
        net = net.to(device)
        out = net(inputs)
        print(out)
        print(out.shape)
        # run the summary on the same device as the model
        torchsummary.summary(net, input_size=(3, 224, 224), device=device)
        # Total params: 134,285,380

    Hope this helps!

    Original post: https://blog.csdn.net/m0_73228309/article/details/138046422