import torch
from torch import nn
class MobileNetV1(nn.Module):
    """MobileNetV1-style image classifier built from depthwise-separable convolutions.

    The feature extractor (``self.mobile``) is one standard 3x3 convolution
    followed by thirteen depthwise-separable blocks. Five layers use stride 2,
    so the spatial resolution shrinks by a factor of 32 overall (224x224 in
    gives a 7x7 map with 1024 channels). That map is average-pooled to 1x1 and
    passed to a linear classifier (``self.Flatten``).

    Args:
        num_classes: Number of output logits produced by the classifier.
            Defaults to 9.
    """

    # (dim_in, dim_out, stride) for each depthwise-separable block, in order.
    _DW_BLOCKS = (
        (32, 64, 1),
        (64, 128, 2),
        (128, 128, 1),
        (128, 256, 2),
        (256, 256, 1),
        (256, 512, 2),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 1024, 2),
        (1024, 1024, 1),
    )

    def __init__(self, num_classes: int = 9):
        super().__init__()

        layers = [self._conv_bn(dim_in=3, dim_out=32, stride=2)]
        layers.extend(
            self._conv_dw(dim_in=dim_in, dim_out=dim_out, stride=stride)
            for dim_in, dim_out, stride in self._DW_BLOCKS
        )
        # Adaptive pooling always reduces the final map to 1x1. For a 224x224
        # input (7x7 final map) this matches a fixed 7x7 average pool, but it
        # also works for other input resolutions instead of crashing or
        # leaving extra spatial positions behind.
        layers.append(nn.AdaptiveAvgPool2d(1))
        self.mobile = nn.Sequential(*layers)

        # NOTE: despite its name this attribute is the linear classifier head.
        # The name is kept so existing checkpoints and callers keep working.
        self.Flatten = nn.Linear(1024, num_classes)

    @staticmethod
    def _conv_bn(dim_in: int, dim_out: int, stride: int) -> nn.Sequential:
        """Return a 3x3 convolution followed by batch norm and ReLU."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=dim_in,
                out_channels=dim_out,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(dim_out),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _conv_dw(dim_in: int, dim_out: int, stride: int) -> nn.Sequential:
        """Return a depthwise 3x3 conv then a pointwise 1x1 conv, each with BN + ReLU."""
        return nn.Sequential(
            # Depthwise: groups == in_channels gives one 3x3 filter per channel.
            nn.Conv2d(
                in_channels=dim_in,
                out_channels=dim_in,
                kernel_size=3,
                stride=stride,
                padding=1,
                groups=dim_in,
                bias=False,
            ),
            nn.BatchNorm2d(dim_in),
            nn.ReLU(inplace=True),
            # Pointwise: 1x1 convolution that mixes channels and sets dim_out.
            nn.Conv2d(
                in_channels=dim_in,
                out_channels=dim_out,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(dim_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Compute class logits for a batch of 3-channel images.

        Args:
            x: Input tensor of shape ``(N, 3, H, W)``.

        Returns:
            Tensor of shape ``(N, num_classes)``.
        """
        x = self.mobile(x)
        # Pooled features are (N, 1024, 1, 1). Flattening from dim 1 keeps the
        # batch dimension intact, unlike view(-1, 1024), which can silently
        # fold spatial positions into the batch.
        x = torch.flatten(x, 1)
        x = self.Flatten(x)
        return x
if __name__ == "__main__":
    # Smoke test: push one random 224x224 RGB image through the network and
    # print the output shape and a few summary statistics.
    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # script does not crash on hosts without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mobilenet = MobileNetV1().to(device)
    # Sample on the CPU first and then move it, as the original did.
    sample = torch.randn(1, 3, 224, 224).to(device)
    output = mobilenet(sample)
    print(output.shape)
    print(torch.sum(output))
    print(f"max : {torch.max(output)}")
    print(f"min : {torch.min(output)}")

- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63