• 8.Mobilenetv2网络代码实现


    代码如下:

    1. import math
    2. import os
    3. import numpy as np
    4. import torch
    5. import torch.nn as nn
    6. import torch.utils.model_zoo as model_zoo
    7. #1.建立带有bn的卷积网络
    8. def conv_bn(inp, oup, stride):
    9. return nn.Sequential(
    10. nn.Conv2d(inp,oup,3,stride,bias=False),
    11. nn.BatchNorm2d(oup),
    12. nn.ReLU6(inplace=True)
    13. )
    14. #2.建立卷积核是1x1的卷积网络
    15. def conv_1x1_bn(inp, oup):
    16. return nn.Sequential(
    17. nn.Conv2d(inp,oup,1,1,0,bias=False),
    18. nn.BatchNorm2d(oup),
    19. nn.ReLU6(inplace=True)
    20. )
    21. class InvertedResidual(nn.Module):
    22. def __init__(self, inp, oup, stride, expand_ratio):
    23. super(InvertedResidual,self).__init__()
    24. self.stride=stride
    25. assert stride in [1,2]
    26. hidden_dim=round(inp*expand_ratio)
    27. self.use_res_connect=self.stride==1 and inp==oup
    28. if expand_ratio == 1:
    29. self.conv=nn.Sequential(
    30. # --------------------------------------------#
    31. # 进行3x3的逐层卷积,进行跨特征点的特征提取
    32. # --------------------------------------------#
    33. nn.Conv2d(hidden_dim,hidden_dim,3,stride, 1, groups=hidden_dim, bias=False),
    34. nn.BatchNorm2d(hidden_dim),
    35. nn.ReLU6(inplace=True),
    36. # -----------------------------------#
    37. # 利用1x1卷积进行通道数的调整
    38. # -----------------------------------#
    39. nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
    40. nn.BatchNorm2d(oup),
    41. )
    42. else:
    43. self.conv=nn.Sequential(
    44. # -----------------------------------#
    45. # 利用1x1卷积进行通道数的上升
    46. # -----------------------------------#
    47. nn.Conv2d(inp,hidden_dim,1,1,0,bias=False),
    48. nn.BatchNorm2d(hidden_dim),
    49. nn.ReLU6(inplace=True),
    50. # --------------------------------------------#
    51. # 进行3x3的逐层卷积,进行跨特征点的特征提取
    52. # --------------------------------------------#
    53. nn.Conv2d(hidden_dim,hidden_dim,3,stride, 1, groups=hidden_dim, bias=False),
    54. nn.BatchNorm2d(hidden_dim),
    55. nn.ReLU6(inplace=True),
    56. # -----------------------------------#
    57. # 利用1x1卷积进行通道数的下降
    58. # -----------------------------------#
    59. nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
    60. nn.BatchNorm2d(oup)
    61. )
    62. def forward(self,x):
    63. if self.use_res_connect:
    64. return x+self.conv(x)
    65. else:
    66. return self.conv(x)
    67. #搭建MobileNetV2网络
    68. class MobileNetV2(nn.Module):
    69. def __init__(self, n_class=1000, input_size=224, width_mult=1.):
    70. super(MobileNetV2, self).__init__()
    71. block=InvertedResidual
    72. input_channel=32
    73. last_channel=1280
    74. interverted_residual_setting = [
    75. # t, c, n, s
    76. [1, 16, 1, 1], # 256, 256, 32 -> 256, 256, 16
    77. [6, 24, 2, 2], # 256, 256, 16 -> 128, 128, 24 2
    78. [6, 32, 3, 2], # 128, 128, 24 -> 64, 64, 32 4
    79. [6, 64, 4, 2], # 64, 64, 32 -> 32, 32, 64 7
    80. [6, 96, 3, 1], # 32, 32, 64 -> 32, 32, 96
    81. [6, 160, 3, 2], # 32, 32, 96 -> 16, 16, 160 14
    82. [6, 320, 1, 1], # 16, 16, 160 -> 16, 16, 320
    83. ]
    84. assert input_size % 32 == 0
    85. input_channel = int(input_channel * width_mult)
    86. self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
    87. # 512, 512, 3 -> 256, 256, 32
    88. self.features=[conv_bn(3,input_channel,2)]
    89. for t,c,n,s in interverted_residual_setting:
    90. output_channel=int(c*width_mult)
    91. for i in range(n):
    92. if i==0:
    93. self.features.append(block(input_channel,output_channel,s, expand_ratio=t))
    94. else:
    95. self.features.append(block(input_channel,output_channel,1, expand_ratio=t))
    96. # input_channel修改为该轮的输出层数
    97. input_channel = output_channel
    98. self.features.append(conv_1x1_bn(input_channel, self.last_channel))
    99. self.features=nn.Sequential(*self.features)
    100. self.classifier=nn.Sequential(
    101. nn.Dropout(0.2),
    102. nn.Linear(self.last_channel,n_class)
    103. )
    104. self._initialize_weights()
    105. def forward(self,x):
    106. x=self.features(x)
    107. x=x.mean(3).mean(2)
    108. x=self.classifier(x)
    109. return x
    110. def _initialize_weights(self):
    111. for m in self.modules():
    112. if isinstance(m, nn.Conv2d):
    113. n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    114. m.weight.data.normal_(0, math.sqrt(2. / n))
    115. if m.bias is not None:
    116. m.bias.data.zero_()
    117. elif isinstance(m, nn.BatchNorm2d):
    118. m.weight.data.fill_(1)
    119. m.bias.data.zero_()
    120. elif isinstance(m, nn.Linear):
    121. n = m.weight.size(1)
    122. m.weight.data.normal_(0, 0.01)
    123. m.bias.data.zero_()
    124. if __name__ == '__main__':
    125. print("........................................")
    126. #数据集生成
    127. input=torch.randn(1,3,224,224)
    128. print(input.shape)
    129. #MobileNetV2的输出
    130. ss=MobileNetV2()
    131. # print(ss)
    132. output=ss(input)
    133. print(output.shape)

  • 相关阅读:
    【考研复试】计算机专业考研复试英语常见问题五(兴趣爱好/实践经历篇)
    【C# 基础精讲】List 集合的使用
    C++可以这么学------>类和对象(中)
    振弦采集模块的通讯速率和软件握手( UART)
    C#的关于窗体的类库方案 - 开源研究系列文章
    12. 一文快速学懂常用工具——docker 命令
    运维 之 一键部署Tomcat
    《剑指Offer》二叉树全题——一套妙解,保证让你轻松掌握二叉树~
    No module named ‘PyQt5.QtWebEngineWidgets‘kn-----已解决
    Django定时任务之django_apscheduler使用
  • 原文地址:https://blog.csdn.net/weixin_71719718/article/details/133827440