代码如下:
- import math
- import os
- import numpy as np
-
- import torch
- import torch.nn as nn
- import torch.utils.model_zoo as model_zoo
-
# 1. 3x3 convolution + BatchNorm + ReLU6 block (used as the network stem).
def conv_bn(inp, oup, stride, padding=1):
    """Build a 3x3 Conv2d -> BatchNorm2d -> ReLU6 block.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: convolution stride (the MobileNetV2 stem uses 2).
        padding: zero-padding added on each spatial side. Defaults to 1 so a
            3x3 kernel keeps the size at stride 1 and exactly halves an even
            size at stride 2 (e.g. 224 -> 112). Without it the stem produced
            111x111 for a 224x224 input.

    Returns:
        An ``nn.Sequential`` of the three layers.
    """
    return nn.Sequential(
        # bias is omitted because the following BatchNorm2d has its own offset.
        nn.Conv2d(inp, oup, 3, stride, padding, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    )
-
# 2. Pointwise (1x1) convolution + BatchNorm + ReLU6 block.
def conv_1x1_bn(inp, oup):
    """Return a 1x1 Conv2d -> BatchNorm2d -> ReLU6 stack mapping inp -> oup channels.

    The convolution uses stride 1 and no padding, so the spatial size is
    unchanged. It has no bias because the BatchNorm2d that follows provides a
    learnable offset.
    """
    layers = [
        nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)
-
-
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block.

    Layer layout:
      * a 1x1 expansion conv (skipped when ``expand_ratio == 1``), then
      * a 3x3 depthwise conv, then
      * a 1x1 projection conv with no activation after it (linear bottleneck).

    An identity shortcut is added only when the block keeps both the spatial
    size (stride 1) and the channel count (``inp == oup``).

    Args:
        inp: input channels.
        oup: output channels.
        stride: stride of the depthwise conv; must be 1 or 2.
        expand_ratio: multiplier giving the hidden width ``round(inp * expand_ratio)``.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # 1x1 pointwise conv that raises the channel count to hidden_dim.
            layers += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        # 3x3 depthwise conv (groups == channels): per-channel spatial filtering.
        # When expand_ratio == 1, hidden_dim == inp, so this consumes the input directly.
        layers += [
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
        ]
        # 1x1 pointwise projection down to oup; no activation follows.
        layers += [
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out
-
# Build the MobileNetV2 network.
class MobileNetV2(nn.Module):
    """MobileNetV2 image classifier.

    Args:
        n_class: number of outputs of the final Linear layer.
        input_size: nominal input resolution. It is only checked to be a
            multiple of 32 (the network's overall downsampling factor) and is
            not otherwise used.
        width_mult: channel-width multiplier applied to the stem and every
            stage. The 1280-wide head is scaled only when ``width_mult > 1``.
    """

    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        # Each row is one stage:
        #   t = expansion ratio, c = output channels,
        #   n = number of repeated blocks, s = stride of the stage's first block.
        # Spatial sizes in the comments assume a 512x512 input. The trailing
        # number is the index in self.features of that stage's first
        # (downsampling) block.
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],   # 256, 256, 32 -> 256, 256, 16
            [6, 24, 2, 2],   # 256, 256, 16 -> 128, 128, 24   2
            [6, 32, 3, 2],   # 128, 128, 24 -> 64, 64, 32     4
            [6, 64, 4, 2],   # 64, 64, 32 -> 32, 32, 64       7
            [6, 96, 3, 1],   # 32, 32, 64 -> 32, 32, 96
            [6, 160, 3, 2],  # 32, 32, 96 -> 16, 16, 160      14
            [6, 320, 1, 1],  # 16, 16, 160 -> 16, 16, 320
        ]
        # Stem stride 2 times four stride-2 stages = 2**5 = 32 total downsampling.
        assert input_size % 32 == 0
        input_channel = int(input_channel * width_mult)
        # The head width is only ever scaled up, never down.
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel

        # Stem: 3 -> input_channel channels at stride 2
        # (512, 512, 3 -> 256, 256, 32 when the 3x3 conv is padded by 1).
        features = [conv_bn(3, input_channel, 2)]
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                # Only the first block of a stage may change the stride.
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                # Later blocks in the stage take this stage's width as input.
                input_channel = output_channel
        features.append(conv_1x1_bn(input_channel, self.last_channel))
        self.features = nn.Sequential(*features)

        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )
        self._initialize_weights()

    def forward(self, x):
        """Return logits of shape (N, n_class) for a 4-D (N, 3, H, W) batch."""
        x = self.features(x)
        # Global average pooling over W then H: (N, C, H, W) -> (N, C).
        x = x.mean(3).mean(2)
        return self.classifier(x)

    def _initialize_weights(self):
        """Initialise every Conv2d, BatchNorm2d and Linear layer in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-normal init with std sqrt(2 / fan_out),
                # where fan_out = k_h * k_w * out_channels.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
-
-
-
-
if __name__ == '__main__':
    print("........................................")
    # Dummy batch: one 3-channel 224x224 image.
    # Named so it does not shadow the builtin `input`.
    dummy_input = torch.randn(1, 3, 224, 224)
    print(dummy_input.shape)

    # Forward pass through MobileNetV2 and print the logits shape.
    model = MobileNetV2()
    # The demo only inspects the output shape, so skip building an autograd graph.
    with torch.no_grad():
        output = model(dummy_input)
    print(output.shape)