• GoogLeNet网络


    目录

    1. 创新点

    1.1 引入Inception结构

    1.2 1×1卷积降维

    1.3 两个辅助分类器 

    1.4 丢弃全连接层,使用平均池化层

    2. 网络结构

    3. 知识点

    3.1 torch.cat

    3.2 关于self.training

    3.3 关于load_state_dict中的strict

    4. 代码 

    4.1 model.py

    4.2 train.py

    4.3 predict.py

    5. 结果


    1. 创新点

    1.1 引入Inception结构

    作用:融合不同尺度的特征信息

    注意:每个分支所得特征矩阵的宽、高必须相同

    下图来自:Going deeper with convolutions

    1.2 1×1卷积降维

    channels: 512

    a.不使用1×1卷积核降维

    使用:64个5×5卷积核进行卷积

    参数:5×5×512×64=819,200

    b.使用1×1卷积核降维

    使用:先用24个1×1卷积核降维,再用64个5×5卷积核进行卷积

    参数:1×1×512×24+5×5×24×64=12,288+38,400=50,688(约为不降维时819,200的1/16)

    1.3 两个辅助分类器 

    内容:GoogLeNet有三个输出层(两个为辅助分类层)

     Going deeper with convolutions文章里:

    • An average pooling layer with 5×5 filter size and stride 3, resulting in an 4×4×512 output for the (4a), and 4×4×528 for the (4d) stage.
    • A 1×1 convolution with 128 filters for dimension reduction and rectified linear activation.
    • A fully connected layer with 1024 units and rectified linear activation.
    • A dropout layer with 70% ratio of dropped outputs.
    • A linear layer with softmax loss as the classifier (predicting the same 1000 classes as the main classifier, but removed at inference time).

    1.4 丢弃全连接层,使用平均池化层

    作用:大大减少模型的参数

    2. 网络结构

    Inception层太多,列出几个:

    3. 知识点

    3.1 torch.cat

    import torch

    # torch.cat joins a sequence of tensors along an existing dimension.
    a = torch.tensor([1.0, 2.0, 3.0])
    b = torch.tensor([4.0, 5.0, 6.0])
    c = [a, b]
    print(torch.cat(c, dim=0))
    # tensor([1., 2., 3., 4., 5., 6.])

    3.2 关于self.training

    使用model.train()和model.eval()控制模型的状态

    在model.train()模式下self.training=True

    在model.eval()模式下self.training=False

    3.3 关于load_state_dict中的strict

    为True(默认):state_dict中的键必须与模型自身的键完全一致,出现缺失或多余的键都会报错

    为False:只加载键名能匹配上的参数,缺失或多余的键会被忽略,不会报错

    返回值中的missing_keys和unexpected_keys:前者是模型需要但state_dict中缺失的键,后者是state_dict中有但模型不需要的键

    4. 代码 

    4.1 model.py

    1. import torch
    2. import torch.nn as nn
    3. import torch.nn.functional as F
    class GoogLeNet(nn.Module):
        """GoogLeNet (Inception v1) classifier with optional auxiliary heads.

        Args:
            num_classes: Number of classes predicted by the main head and by
                both auxiliary heads.
            aux_use: When True, the two auxiliary classifiers are built and
                their logits are returned by ``forward`` in training mode.
            init_weight: When True, every ``nn.Conv2d`` and ``nn.Linear`` layer
                is re-initialised with Xavier-uniform weights and zero bias.
        """

        def __init__(self, num_classes=1000, aux_use=True, init_weight=False):
            super().__init__()
            self.aux_use = aux_use

            self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
            # ceil_mode=True rounds the output size up (the default rounds down).
            self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

            self.conv2 = BasicConv2d(64, 64, kernel_size=1)
            self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
            self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

            # Each Inception block outputs ch11 + ch33 + ch55 + pool_proj channels,
            # which is the in_channels of the block that follows it.
            self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
            self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
            self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

            self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
            self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
            self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
            self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
            self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
            self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

            self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
            self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

            if self.aux_use:
                # Attached to the outputs of inception4a (512 channels) and
                # inception4d (528 channels).
                self.aux1 = InceptionAux(512, num_classes)
                self.aux2 = InceptionAux(528, num_classes)

            # Adaptive average pooling: the output size (H, W) is given directly.
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.dropout = nn.Dropout(0.4)
            self.fc = nn.Linear(1024, num_classes)

            if init_weight:
                self._initialize_weights_()

        def forward(self, x):
            """Run the network.

            Returns:
                ``(logits, aux1_logits, aux2_logits)`` when the module is in
                training mode and ``aux_use`` is True, otherwise just the main
                ``logits``. Shape comments below assume an N x 3 x 224 x 224
                input.
            """
            # N x 3 x 224 x 224
            x = self.conv1(x)        # N x 64 x 112 x 112
            x = self.maxpool1(x)     # N x 64 x 56 x 56
            x = self.conv2(x)        # N x 64 x 56 x 56
            x = self.conv3(x)        # N x 192 x 56 x 56
            x = self.maxpool2(x)     # N x 192 x 28 x 28

            x = self.inception3a(x)  # N x 256 x 28 x 28
            x = self.inception3b(x)  # N x 480 x 28 x 28
            x = self.maxpool3(x)     # N x 480 x 14 x 14

            x = self.inception4a(x)  # N x 512 x 14 x 14
            # The auxiliary heads only contribute during training.
            if self.training and self.aux_use:
                aux1 = self.aux1(x)

            x = self.inception4b(x)  # N x 512 x 14 x 14
            x = self.inception4c(x)  # N x 512 x 14 x 14
            x = self.inception4d(x)  # N x 528 x 14 x 14
            if self.training and self.aux_use:
                aux2 = self.aux2(x)

            x = self.inception4e(x)  # N x 832 x 14 x 14
            x = self.maxpool4(x)     # N x 832 x 7 x 7
            x = self.inception5a(x)  # N x 832 x 7 x 7
            x = self.inception5b(x)  # N x 1024 x 7 x 7

            x = self.avgpool(x)      # N x 1024 x 1 x 1
            x = torch.flatten(x, start_dim=1)  # N x 1024
            x = self.dropout(x)
            x = self.fc(x)           # N x num_classes

            if self.training and self.aux_use:
                return x, aux1, aux2
            return x

        def _initialize_weights_(self):
            """Xavier-uniform weights and zero biases for all conv/linear layers."""
            for v in self.modules():
                if isinstance(v, (nn.Conv2d, nn.Linear)):
                    nn.init.xavier_uniform_(v.weight)
                    if v.bias is not None:
                        nn.init.constant_(v.bias, 0)
    # Basic building block: convolution followed by ReLU
    class BasicConv2d(nn.Module):
        """A ``nn.Conv2d`` followed by an in-place ReLU.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            **kwargs: Forwarded unchanged to ``nn.Conv2d`` (for example
                ``kernel_size``, ``stride``, ``padding``).
        """

        def __init__(self, in_channels, out_channels, **kwargs):
            super().__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
            self.relu = nn.ReLU(inplace=True)

        def forward(self, x):
            """Return ``relu(conv(x))``."""
            x = self.conv(x)
            x = self.relu(x)
            return x
    # Inception block
    class Inception(nn.Module):
        """Inception block: four parallel branches concatenated along channels.

        All branches keep the input's height and width (stride 1 with matching
        padding), which is required for the channel-wise concatenation.

        Args:
            in_channels: Number of input channels.
            ch11: Output channels of the 1x1 branch.
            ch33_reduce: Channels of the 1x1 reduction before the 3x3 conv.
            ch33: Output channels of the 3x3 branch.
            ch55_reduce: Channels of the 1x1 reduction before the 5x5 conv.
            ch55: Output channels of the 5x5 branch.
            pool_proj: Output channels of the 1x1 conv after the max-pool.
        """

        def __init__(self, in_channels, ch11, ch33_reduce, ch33, ch55_reduce, ch55, pool_proj):
            super().__init__()

            # Branch 1: 1x1 conv
            self.branch1 = BasicConv2d(in_channels, ch11, kernel_size=1)

            # Branch 2: 1x1 reduction -> 3x3 conv (padding=1 keeps H and W)
            self.branch2 = nn.Sequential(
                BasicConv2d(in_channels, ch33_reduce, kernel_size=1),
                BasicConv2d(ch33_reduce, ch33, kernel_size=3, padding=1)
            )

            # Branch 3: 1x1 reduction -> 5x5 conv (padding=2 keeps H and W)
            self.branch3 = nn.Sequential(
                BasicConv2d(in_channels, ch55_reduce, kernel_size=1),
                BasicConv2d(ch55_reduce, ch55, kernel_size=5, padding=2)
            )

            # Branch 4: 3x3 max-pool (stride 1, padding 1) -> 1x1 projection
            self.branch4 = nn.Sequential(
                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
                BasicConv2d(in_channels, pool_proj, kernel_size=1)
            )

        def forward(self, x):
            """Return the branch outputs concatenated on the channel dimension.

            The result has ``ch11 + ch33 + ch55 + pool_proj`` channels.
            """
            branch1 = self.branch1(x)
            branch2 = self.branch2(x)
            branch3 = self.branch3(x)
            branch4 = self.branch4(x)

            outputs = [branch1, branch2, branch3, branch4]
            return torch.cat(outputs, dim=1)
    # Auxiliary classifier
    class InceptionAux(nn.Module):
        """Auxiliary classification head used during training.

        ``fc1`` expects 2048 = 128 * 4 * 4 input features, so the incoming
        feature map must be 14 x 14 (5x5 average pooling with stride 3 turns
        14 x 14 into 4 x 4).

        Args:
            in_channels: Channels of the incoming feature map.
            num_classes: Number of output classes.
        """

        def __init__(self, in_channels, num_classes):
            super().__init__()
            self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
            self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output: [batch, 128, 4, 4]

            self.fc1 = nn.Linear(2048, 1024)
            self.fc2 = nn.Linear(1024, num_classes)

        def forward(self, x):
            """Return class logits of shape ``(batch, num_classes)``."""
            # Input:  aux1 (batch, 512, 14, 14), aux2 (batch, 528, 14, 14)
            x = self.averagePool(x)
            # Output: aux1 (batch, 512, 4, 4),   aux2 (batch, 528, 4, 4)
            x = self.conv(x)
            # Output: aux1, aux2 (batch, 128, 4, 4)
            x = torch.flatten(x, start_dim=1)
            # NOTE(review): dropout is p=0.5 and is applied both before and after
            # fc1; the paper text quoted in this article describes a single 70%
            # dropout layer after the fully connected layer.
            x = F.dropout(x, 0.5, training=self.training)
            # batch x 2048
            x = F.relu(self.fc1(x), inplace=True)
            x = F.dropout(x, 0.5, training=self.training)
            # batch x 1024
            x = self.fc2(x)
            # batch x num_classes
            return x

    4.2 train.py

    1. import os
    2. import sys
    3. import torch
    4. import torch.nn as nn
    5. import torchvision
    6. from torchvision import transforms, datasets
    7. import json
    8. from model import GoogLeNet
    9. import torch.optim as optim
    10. from tqdm import tqdm
    def main():
        """Train GoogLeNet on the flower dataset and save the best weights.

        Reads images from ``<cwd>/data_set/flower_data/{train,val}``, writes the
        index-to-class mapping to ``./class_indices.json`` and saves the
        state_dict with the highest validation accuracy to ``./GoogLeNet.pth``.
        """
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        data_transform = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]),
            'val': transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
        }

        data_root = os.path.abspath(os.getcwd())
        image_path = os.path.join(data_root, 'data_set', 'flower_data')
        assert os.path.exists(image_path), 'file:{} is not exist!'.format(image_path)

        # Datasets
        train_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'train'),
                                             transform=data_transform['train'])
        val_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'val'),
                                           transform=data_transform['val'])
        train_num = len(train_dataset)
        val_num = len(val_dataset)

        # Persist the index -> class-name mapping so predict.py can decode outputs.
        flower_list = train_dataset.class_to_idx
        class_dict = {index: name for name, index in flower_list.items()}
        json_str = json.dumps(class_dict, indent=4)
        with open('./class_indices.json', 'w') as file:
            file.write(json_str)

        # Data loaders
        batch_size = 32
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                                   shuffle=True, num_workers=0)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size,
                                                 shuffle=False, num_workers=0)
        print('using {} images for training, {} images for validation.'.format(train_num, val_num))

        net = GoogLeNet(num_classes=5, aux_use=True, init_weight=True)
        net.to(device)
        loss_function = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=0.0003)

        epochs = 30
        best_acc = 0.0
        save_path = './GoogLeNet.pth'
        train_steps = len(train_loader)

        for epoch in range(epochs):
            # Train: in train mode the network also returns both auxiliary outputs.
            net.train()
            epoch_loss = 0.0
            train_bar = tqdm(train_loader)
            for images, labels in train_bar:
                # Move the batch to the device once and reuse it for all three losses.
                images = images.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                output, aux1_output, aux2_output = net(images)
                loss0 = loss_function(output, labels)
                loss1 = loss_function(aux1_output, labels)
                loss2 = loss_function(aux2_output, labels)
                # Auxiliary losses are down-weighted by 0.3.
                loss = loss0 + 0.3 * loss1 + 0.3 * loss2
                loss.backward()
                optimizer.step()

                # Statistics
                batch_loss = loss.item()
                epoch_loss += batch_loss
                train_bar.desc = 'train epoch[{}/{}] loss:{:.3f}'.format(epoch + 1, epochs, batch_loss)

            # Validate: in eval mode the network returns only the main output.
            net.eval()
            acc = 0.0
            with torch.no_grad():
                val_bar = tqdm(val_loader)
                for val_images, val_labels in val_bar:
                    outputs = net(val_images.to(device))
                    predict_y = torch.argmax(outputs, dim=1)
                    acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

            val_acc = acc / val_num
            print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
                  (epoch + 1, epoch_loss / train_steps, val_acc))

            # Keep only the best-performing weights.
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(net.state_dict(), save_path)

        print('Finished Training!')


    if __name__ == '__main__':
        main()

    4.3 predict.py

    1. import os
    2. import torch
    3. import torchvision
    4. from torchvision import transforms
    5. from PIL import Image
    6. import matplotlib.pyplot as plt
    7. import json
    8. from model import GoogLeNet
    def main():
        """Classify ``./sunflower.jpg`` with the trained GoogLeNet and plot the result.

        Expects ``./class_indices.json`` and ``./GoogLeNet.pth`` in the working
        directory. The model is built without auxiliary heads, so the
        checkpoint is loaded with ``strict=False``.
        """
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        img_path = './sunflower.jpg'
        assert os.path.exists(img_path), 'file:{} is not exist!'.format(img_path)
        # Force 3 channels: grayscale or RGBA files would otherwise fail in the
        # 3-channel Normalize / first convolution.
        img = Image.open(img_path).convert('RGB')
        plt.imshow(img)

        # [C, H, W] -> [N, C, H, W]
        img = transform(img)
        img = torch.unsqueeze(img, dim=0)

        # Read the index -> class-name mapping
        json_path = './class_indices.json'
        assert os.path.exists(json_path), 'file:{} is not exist!'.format(json_path)
        with open(json_path, 'r') as file:
            class_dict = json.load(file)

        # Create the model
        net = GoogLeNet(num_classes=5, aux_use=False).to(device)

        # Load the model weights
        weight_path = './GoogLeNet.pth'
        assert os.path.exists(weight_path), 'file:{} is not exist!'.format(weight_path)
        # unexpected_keys holds the aux1/aux2 weights stored in the checkpoint,
        # which this aux-free model does not have.
        missing_keys, unexpected_keys = net.load_state_dict(
            torch.load(weight_path, map_location=device), strict=False)

        net.eval()
        with torch.no_grad():
            outputs = torch.squeeze(net(img.to(device))).cpu()
            predict = torch.softmax(outputs, dim=0)
            predict_class = torch.argmax(predict).item()

        print_res = 'class:{} probability:{:.3f}'.format(class_dict[str(predict_class)],
                                                         predict[predict_class].item())
        plt.title(print_res)
        for idx, prob in enumerate(predict.tolist()):
            print('class:{:10} probability:{:.3f}'.format(class_dict[str(idx)], prob))
        plt.show()


    if __name__ == '__main__':
        main()

    5. 结果

  • 相关阅读:
    JVM篇---第一篇
    有氧运动与无氧运动的区别
    《AI聊天类工具之五——Copilot》
    面试算法39:直方图最大矩形面积
    计算params的参数量和flops
    JAVA中类和对象的认识
    JavaScript:实现ExponentialSearch指数搜索算法(附完整源码)
    AtCoder Beginner Contest 212 E(DP)
    登录界面代码
    快速上手SpringBoot
  • 原文地址:https://blog.csdn.net/qq_61706112/article/details/133135548