ShuffleNet v2网络结构复现

from torch import nn
from torch.nn import functional
import torch
from torchsummary import summary
def channel_shuffle(x, groups):
    """Interleave the channels of ``x`` across ``groups`` groups.

    The channel axis (dim 1) is reshaped to ``(groups, channels_per_group)``,
    those two axes are swapped, and the result is flattened back. The output
    channel order is therefore ``[g0c0, g1c0, ..., g0c1, g1c1, ...]``.

    Args:
        x: 4-D tensor laid out as (batch, channels, height, width). The
            channel count must be divisible by ``groups`` for the reshape
            to succeed.
        groups: number of channel groups to interleave.

    Returns:
        Tensor with the same shape as ``x`` and its channels permuted.
    """
    # Query the shape directly from the tensor. ``x.data`` is a legacy
    # autograd escape hatch and is not needed just to read sizes.
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups
    # (B, C, H, W) -> (B, G, C/G, H, W)
    x = x.view(batchsize, groups, channels_per_group, height, width)
    # Swap the group and per-group axes. contiguous() is required because
    # view() cannot reshape the non-contiguous transposed tensor.
    x = x.transpose(1, 2).contiguous()
    # Flatten back to (B, C, H, W).
    return x.view(batchsize, -1, height, width)
class CBRM(nn.Module):
    """Network stem: Conv(3x3, stride 2) -> BatchNorm -> ReLU -> MaxPool(3x3, stride 2).

    Each of the two stride-2 stages halves the spatial size, so the stem
    downsamples height and width by 4 overall.

    Args:
        c1: number of input channels.
        c2: number of output channels.
    """

    def __init__(self, c1, c2):
        super().__init__()
        # Layers are created in the same order as the Sequential children
        # (indices 0, 1, 2) so parameter names stay conv.0.*, conv.1.*.
        stem_conv = nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1, bias=False)
        stem_bn = nn.BatchNorm2d(c2)
        stem_act = nn.ReLU(inplace=True)
        self.conv = nn.Sequential(stem_conv, stem_bn, stem_act)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

    def forward(self, x):
        features = self.conv(x)
        pooled = self.maxpool(features)
        return pooled
class Shuffle_Block(nn.Module):
    """ShuffleNet v2 unit: two branches, concatenation, then a channel shuffle.

    * ``stride == 1``: the input is split into two equal halves along dim 1.
      The first half passes through untouched, the second goes through
      ``branch2``, and the two are concatenated.
    * ``stride > 1``: the full input goes through both ``branch1`` and
      ``branch2`` (each applies a strided depthwise conv), and the results
      are concatenated.

    In both cases the concatenated tensor is channel-shuffled with 2 groups.

    Args:
        inp: number of input channels.
        oup: number of output channels. Each branch produces ``oup // 2``,
            so an odd ``oup`` with ``stride > 1`` yields ``oup - 1`` channels.
        stride: stride of the depthwise convolutions; must be 1, 2 or 3.

    Raises:
        ValueError: if ``stride`` is outside ``[1, 3]``, or if ``stride == 1``
            and ``inp != 2 * (oup // 2)``.
    """

    def __init__(self, inp, oup, stride):
        super(Shuffle_Block, self).__init__()
        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride
        branch_features = oup // 2
        # The stride-1 path splits the input in half and passes one half
        # through unchanged, so ``inp`` must equal both branch widths combined.
        # This is an explicit check rather than ``assert`` so it is not
        # stripped under ``python -O``. Without it, the failure would surface
        # later as a confusing shape error in forward().
        if self.stride == 1 and inp != branch_features << 1:
            raise ValueError(
                f'stride 1 requires inp == 2 * (oup // 2), got inp={inp}, oup={oup}'
            )
        if self.stride > 1:
            # Downsampling shortcut branch: strided depthwise conv (no
            # activation), then a pointwise conv to branch_features channels.
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        # Main branch: 1x1 conv -> 3x3 depthwise conv -> 1x1 conv. With
        # stride 1 it only receives the second half of the channels.
        self.branch2 = nn.Sequential(
            nn.Conv2d(inp if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        """Return a Conv2d with ``groups=i`` (one filter group per input channel)."""
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            # Identity half + transformed half.
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        # Mix information between the two branches' channels.
        out = channel_shuffle(out, 2)
        return out
class ShuffleNetV2(nn.Module):
    """Truncated ShuffleNet v2 backbone.

    A ``CBRM`` stem (3 -> 32 channels, spatial /4) is followed by three stages.
    Each stage is a stride-2 ``Shuffle_Block`` followed by a stride-1 one,
    with widths 128, 256 and 512. Overall spatial stride is 32 and the output
    has 512 channels.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride) for each Shuffle_Block, in order.
        block_specs = (
            (32, 128, 2),
            (128, 128, 1),
            (128, 256, 2),
            (256, 256, 1),
            (256, 512, 2),
            (512, 512, 1),
        )
        layers = [CBRM(3, 32)]
        for c_in, c_out, s in block_specs:
            layers.append(Shuffle_Block(c_in, c_out, s))
        # Attribute name is kept as-is: it is the state_dict key prefix.
        self.MobileNet_01 = nn.Sequential(*layers)

    def forward(self, x):
        return self.MobileNet_01(x)
if __name__ == '__main__':
    # Build the backbone and print a per-layer summary for a 640x640 RGB input.
    shufflenetv2 = ShuffleNetV2()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = shufflenetv2.to(device)
    # torchsummary creates its dummy input on the device named here. The
    # previous hard-coded "cuda" crashed on CPU-only machines even though the
    # model had been moved to CPU. Forward the selected device type
    # ("cuda" or "cpu") instead.
    summary(model, (3, 640, 640), batch_size=1, device=device.type)

- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
- 115
- 116
- 117
----------------------------------------------------------------
Layer (type) Output Shape Param
================================================================
Conv2d-1 [1, 32, 320, 320] 864
BatchNorm2d-2 [1, 32, 320, 320] 64
ReLU-3 [1, 32, 320, 320] 0
MaxPool2d-4 [1, 32, 160, 160] 0
CBRM-5 [1, 32, 160, 160] 0
Conv2d-6 [1, 32, 80, 80] 288
BatchNorm2d-7 [1, 32, 80, 80] 64
Conv2d-8 [1, 64, 80, 80] 2,048
BatchNorm2d-9 [1, 64, 80, 80] 128
ReLU-10 [1, 64, 80, 80] 0
Conv2d-11 [1, 64, 160, 160] 2,048
BatchNorm2d-12 [1, 64, 160, 160] 128
ReLU-13 [1, 64, 160, 160] 0
Conv2d-14 [1, 64, 80, 80] 576
BatchNorm2d-15 [1, 64, 80, 80] 128
Conv2d-16 [1, 64, 80, 80] 4,096
BatchNorm2d-17 [1, 64, 80, 80] 128
ReLU-18 [1, 64, 80, 80] 0
Shuffle_Block-19 [1, 128, 80, 80] 0
Conv2d-20 [1, 64, 80, 80] 4,096
BatchNorm2d-21 [1, 64, 80, 80] 128
ReLU-22 [1, 64, 80, 80] 0
Conv2d-23 [1, 64, 80, 80] 576
BatchNorm2d-24 [1, 64, 80, 80] 128
Conv2d-25 [1, 64, 80, 80] 4,096
BatchNorm2d-26 [1, 64, 80, 80] 128
ReLU-27 [1, 64, 80, 80] 0
Shuffle_Block-28 [1, 128, 80, 80] 0
Conv2d-29 [1, 128, 40, 40] 1,152
BatchNorm2d-30 [1, 128, 40, 40] 256
Conv2d-31 [1, 128, 40, 40] 16,384
BatchNorm2d-32 [1, 128, 40, 40] 256
ReLU-33 [1, 128, 40, 40] 0
Conv2d-34 [1, 128, 80, 80] 16,384
BatchNorm2d-35 [1, 128, 80, 80] 256
ReLU-36 [1, 128, 80, 80] 0
Conv2d-37 [1, 128, 40, 40] 1,152
BatchNorm2d-38 [1, 128, 40, 40] 256
Conv2d-39 [1, 128, 40, 40] 16,384
BatchNorm2d-40 [1, 128, 40, 40] 256
ReLU-41 [1, 128, 40, 40] 0
Shuffle_Block-42 [1, 256, 40, 40] 0
Conv2d-43 [1, 128, 40, 40] 16,384
BatchNorm2d-44 [1, 128, 40, 40] 256
ReLU-45 [1, 128, 40, 40] 0
Conv2d-46 [1, 128, 40, 40] 1,152
BatchNorm2d-47 [1, 128, 40, 40] 256
Conv2d-48 [1, 128, 40, 40] 16,384
BatchNorm2d-49 [1, 128, 40, 40] 256
ReLU-50 [1, 128, 40, 40] 0
Shuffle_Block-51 [1, 256, 40, 40] 0
Conv2d-52 [1, 256, 20, 20] 2,304
BatchNorm2d-53 [1, 256, 20, 20] 512
Conv2d-54 [1, 256, 20, 20] 65,536
BatchNorm2d-55 [1, 256, 20, 20] 512
ReLU-56 [1, 256, 20, 20] 0
Conv2d-57 [1, 256, 40, 40] 65,536
BatchNorm2d-58 [1, 256, 40, 40] 512
ReLU-59 [1, 256, 40, 40] 0
Conv2d-60 [1, 256, 20, 20] 2,304
BatchNorm2d-61 [1, 256, 20, 20] 512
Conv2d-62 [1, 256, 20, 20] 65,536
BatchNorm2d-63 [1, 256, 20, 20] 512
ReLU-64 [1, 256, 20, 20] 0
Shuffle_Block-65 [1, 512, 20, 20] 0
Conv2d-66 [1, 256, 20, 20] 65,536
BatchNorm2d-67 [1, 256, 20, 20] 512
ReLU-68 [1, 256, 20, 20] 0
Conv2d-69 [1, 256, 20, 20] 2,304
BatchNorm2d-70 [1, 256, 20, 20] 512
Conv2d-71 [1, 256, 20, 20] 65,536
BatchNorm2d-72 [1, 256, 20, 20] 512
ReLU-73 [1, 256, 20, 20] 0
Shuffle_Block-74 [1, 512, 20, 20] 0
================================================================
Total params: 445,824
Trainable params: 445,824
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 4.69
Forward/backward pass size (MB): 270.31
Params size (MB): 1.70
Estimated Total Size (MB): 276.70
----------------------------------------------------------------

- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
ShuffleNetV2(
(MobileNet_01): Sequential(
(0): CBRM(
(conv): Sequential(
(0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
(1): Shuffle_Block(
(branch1): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): ReLU(inplace=True)
)
(branch2): Sequential(
(0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): ReLU(inplace=True)
)
)
(2): Shuffle_Block(
(branch2): Sequential(
(0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): ReLU(inplace=True)
)
)
(3): Shuffle_Block(
(branch1): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=128, bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): ReLU(inplace=True)
)
(branch2): Sequential(
(0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=128, bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): ReLU(inplace=True)
)
)
(4): Shuffle_Block(
(branch2): Sequential(
(0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): ReLU(inplace=True)
)
)
(5): Shuffle_Block(
(branch1): Sequential(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256, bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): ReLU(inplace=True)
)
(branch2): Sequential(
(0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256, bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): ReLU(inplace=True)
)
)
(6): Shuffle_Block(
(branch2): Sequential(
(0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): ReLU(inplace=True)
)
)
)
)

- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106