• 使用GoogleNet网络实现花朵分类


    一.数据集准备

    新建一个项目文件夹GoogleNet,并在里面建立data_set文件夹用来保存数据集,在data_set文件夹下创建新文件夹"flower_data",点击链接下载花分类数据集https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz,会下载一个压缩包,将它解压到flower_data文件夹下,执行"split_data.py"脚本自动将数据集划分成训练集train和验证集val。

     split_data.py如下:

    1. import os
    2. from shutil import copy, rmtree
    3. import random
    4. def mk_file(file_path: str):
    5. if os.path.exists(file_path):
    6. # 如果文件夹存在,则先删除原文件夹在重新创建
    7. rmtree(file_path)
    8. os.makedirs(file_path)
    9. def main():
    10. # 保证随机可复现
    11. random.seed(0)
    12. # 将数据集中10%的数据划分到验证集中
    13. split_rate = 0.1
    14. # 指向你解压后的flower_photos文件夹
    15. cwd = os.getcwd()
    16. data_root = os.path.join(cwd, "flower_data")
    17. origin_flower_path = os.path.join(data_root, "flower_photos")
    18. assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)
    19. flower_class = [cla for cla in os.listdir(origin_flower_path)
    20. if os.path.isdir(os.path.join(origin_flower_path, cla))]
    21. # 建立保存训练集的文件夹
    22. train_root = os.path.join(data_root, "train")
    23. mk_file(train_root)
    24. for cla in flower_class:
    25. # 建立每个类别对应的文件夹
    26. mk_file(os.path.join(train_root, cla))
    27. # 建立保存验证集的文件夹
    28. val_root = os.path.join(data_root, "val")
    29. mk_file(val_root)
    30. for cla in flower_class:
    31. # 建立每个类别对应的文件夹
    32. mk_file(os.path.join(val_root, cla))
    33. for cla in flower_class:
    34. cla_path = os.path.join(origin_flower_path, cla)
    35. images = os.listdir(cla_path)
    36. num = len(images)
    37. # 随机采样验证集的索引
    38. eval_index = random.sample(images, k=int(num*split_rate))
    39. for index, image in enumerate(images):
    40. if image in eval_index:
    41. # 将分配至验证集中的文件复制到相应目录
    42. image_path = os.path.join(cla_path, image)
    43. new_path = os.path.join(val_root, cla)
    44. copy(image_path, new_path)
    45. else:
    46. # 将分配至训练集中的文件复制到相应目录
    47. image_path = os.path.join(cla_path, image)
    48. new_path = os.path.join(train_root, cla)
    49. copy(image_path, new_path)
    50. print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="") # processing bar
    51. print()
    52. print("processing done!")
    53. if __name__ == '__main__':
    54. main()

    之后会在文件夹下生成train和val数据集,到此,完成了数据集的准备。

    二.定义网络

    新建model.py,参照GoogleNet的网络结构和pytorch官方给出的代码,对代码进行略微的修改即可,在官方代码里首先定义了三个类BasicConv2d、Inception、InceptionAux,即基础卷积、Inception模块、辅助分类器三个部分,接着定义了GoogleNet类,对上述三个类进行调用,完成前向传播。

    pytorch官方示例GoogleNet代码

    1. import warnings
    2. from collections import namedtuple
    3. from functools import partial
    4. from typing import Any, Callable, List, Optional, Tuple
    5. import torch
    6. import torch.nn as nn
    7. import torch.nn.functional as F
    8. from torch import Tensor
    9. class GoogLeNet(nn.Module):
    10. def __init__(self, num_classes = 1000, aux_logits = True, transform_input = False, init_weights = True):
    11. super(GoogLeNet,self).__init__()
    12. self.aux_logits = aux_logits
    13. self.transform_input = transform_input
    14. self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3) #3为输入通道数,64为输出通道数
    15. self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
    16. self.conv2 = BasicConv2d(64, 64, kernel_size=1)
    17. self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
    18. self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
    19. self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
    20. self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
    21. self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
    22. self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
    23. self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
    24. self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
    25. self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
    26. self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
    27. self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
    28. self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
    29. self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
    30. if aux_logits:
    31. self.aux1 = InceptionAux(512, num_classes)
    32. self.aux2 = InceptionAux(528, num_classes)
    33. self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) #自适应平均池化下采样,对于任意尺寸的特征向量,都得到1*1特征矩阵
    34. self.dropout = nn.Dropout(0.4)
    35. self.fc = nn.Linear(1024, num_classes)
    36. if init_weights:
    37. for m in self.modules():
    38. if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
    39. torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)
    40. elif isinstance(m, nn.BatchNorm2d):
    41. nn.init.constant_(m.weight, 1)
    42. nn.init.constant_(m.bias, 0)
    43. def _transform_input(self, x):
    44. if self.transform_input:
    45. x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
    46. x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
    47. x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
    48. x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
    49. return x
    50. def forward(self, x):
    51. x = self._transform_input(x)
    52. # N x 3 x 224 x 224 ---- batch_size cahnnel height width
    53. x = self.conv1(x)
    54. # N x 64 x 112 x 112
    55. x = self.maxpool1(x)
    56. # N x 64 x 56 x 56
    57. x = self.conv2(x)
    58. # N x 64 x 56 x 56
    59. x = self.conv3(x)
    60. # N x 192 x 56 x 56
    61. x = self.maxpool2(x)
    62. # N x 192 x 28 x 28
    63. x = self.inception3a(x)
    64. # N x 256 x 28 x 28
    65. x = self.inception3b(x)
    66. # N x 480 x 28 x 28
    67. x = self.maxpool3(x)
    68. # N x 480 x 14 x 14
    69. x = self.inception4a(x)
    70. # N x 512 x 14 x 14
    71. if self.training and self.aux_logits:
    72. aux1 = self.aux1(x)
    73. x = self.inception4b(x)
    74. # N x 512 x 14 x 14
    75. x = self.inception4c(x)
    76. # N x 512 x 14 x 14
    77. x = self.inception4d(x)
    78. # N x 528 x 14 x 14
    79. if self.training and self.aux_logits:
    80. aux2 = self.aux2(x)
    81. x = self.inception4e(x)
    82. # N x 832 x 14 x 14
    83. x = self.maxpool4(x)
    84. # N x 832 x 7 x 7
    85. x = self.inception5a(x)
    86. # N x 832 x 7 x 7
    87. x = self.inception5b(x)
    88. # N x 1024 x 7 x 7
    89. x = self.avgpool(x)
    90. # N x 1024 x 1 x 1
    91. x = torch.flatten(x, 1)
    92. # N x 1024
    93. x = self.dropout(x)
    94. x = self.fc(x)
    95. # N x 1000 (num_classes)
    96. if self.training and self.aux_logits:
    97. return x, aux2, aux1
    98. return x
    99. class Inception(nn.Module):
    100. def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
    101. super(Inception, self).__init__()
    102. self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
    103. self.branch2 = nn.Sequential(
    104. BasicConv2d(in_channels, ch3x3red, kernel_size=1),
    105. BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1) # 保证输出大小等于输入大小
    106. )
    107. self.branch3 = nn.Sequential(
    108. BasicConv2d(in_channels, ch5x5red, kernel_size=1),
    109. BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1), # 保证输出大小等于输入大小
    110. )
    111. self.branch4 = nn.Sequential(
    112. nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
    113. BasicConv2d(in_channels, pool_proj, kernel_size=1),
    114. )
    115. def forward(self, x):
    116. branch1 = self.branch1(x)
    117. branch2 = self.branch2(x)
    118. branch3 = self.branch3(x)
    119. branch4 = self.branch4(x)
    120. outputs = [branch1, branch2, branch3, branch4]
    121. return torch.cat(outputs, 1) #batch channel hetght width,在channel上拼接
    122. class InceptionAux(nn.Module):
    123. def __init__(self, in_channels, num_classes):
    124. super(InceptionAux, self).__init__()
    125. self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
    126. self.conv = BasicConv2d(in_channels, 128, kernel_size=1) # output[batch, 128, 4, 4]
    127. self.fc1 = nn.Linear(2048, 1024)
    128. self.fc2 = nn.Linear(1024, num_classes)
    129. def forward(self, x):
    130. # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
    131. x = self.averagePool(x)
    132. # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
    133. x = self.conv(x)
    134. # N x 128 x 4 x 4
    135. x = torch.flatten(x, 1)
    136. x = F.dropout(x, 0.5, training=self.training)
    137. # N x 2048
    138. x = F.relu(self.fc1(x), inplace=True)
    139. x = F.dropout(x, 0.5, training=self.training)
    140. # N x 1024
    141. x = self.fc2(x)
    142. # N x 1000 (num_classes)
    143. return x
    144. class BasicConv2d(nn.Module):
    145. def __init__(self, in_channels, out_channels, **kwargs):
    146. super(BasicConv2d, self).__init__()
    147. self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
    148. self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
    149. def forward(self, x):
    150. x = self.conv(x)
    151. x = self.bn(x)
    152. return F.relu(x, inplace=True)
    153. if __name__ == "__main__":
    154. googlenet = GoogLeNet(num_classes = 3, aux_logits = True, transform_input = False, init_weights = True)
    155. in_data = torch.randn(1, 3, 224, 224)
    156. out = googlenet(in_data)
    157. print(out)

    完成网络的定义之后,可以单独执行一下这个文件,用来验证网络定义的是否正确。如果可以正确输出,就没问题。

    三.开始训练

     加载数据集

    首先定义一个字典,用于对train和val进行预处理,包括裁剪成224*224大小,训练集随机水平翻转(一般验证集不需要此操作),转换成张量,图像归一化。

    然后利用DataLoader模块加载数据集,并设置batch_size为32,同时,设置数据加载器的工作进程数nw,加快速度。

    1. data_transform = {
    2. "train": transforms.Compose([transforms.RandomResizedCrop(224),
    3. transforms.RandomHorizontalFlip(),
    4. transforms.ToTensor(),
    5. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    6. "val": transforms.Compose([transforms.Resize((224, 224)),
    7. transforms.ToTensor(),
    8. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
    9. # 获取数据集路径
    10. image_path = os.path.join(os.getcwd(), "data_set", "flower_data") # flower data set path
    11. assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    12. # 加载数据集,准备读取
    13. train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])
    14. validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"), transform=data_transform["val"])
    15. nw = min([os.cpu_count(), 32 if 32 > 1 else 0, 8]) # number of workers
    16. print(f'Using {nw} dataloader workers every process')
    17. # 加载数据集
    18. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=nw)
    19. validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=32, shuffle=False, num_workers=nw)
    20. train_num = len(train_dataset)
    21. val_num = len(validate_dataset)
    22. print(f"using {train_num} images for training, {val_num} images for validation.")

    生成json文件

    将训练数据集的类别标签转换为字典格式,并将其写入名为'class_indices.json'的文件中。

    1. train_dataset中获取类别标签到索引的映射关系,存储在flower_list变量中。
    2. 使用列表推导式将flower_list中的键值对反转,得到一个新的字典cla_dict,其中键是原始类别标签,值是对应的索引。
    3. 使用json.dumps()函数将cla_dict转换为JSON格式的字符串,设置缩进为4个空格。
    4. 使用with open()语句以写入模式打开名为'class_indices.json'的文件,并将JSON字符串写入文件。
    1. # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4} 雏菊 蒲公英 玫瑰 向日葵 郁金香
    2. # 从训练集中获取类别标签到索引的映射关系,存储在flower_list变量
    3. flower_list = train_dataset.class_to_idx
    4. # 使用列表推导式将flower_list中的键值对反转,得到一个新的字典cla_dict
    5. cla_dict = dict((val, key) for key, val in flower_list.items())
    6. # write dict into json file,将cla_dict转换为JSON格式的字符串
    7. json_str = json.dumps(cla_dict, indent=4)
    8. with open('class_indices.json', 'w') as json_file:
    9. json_file.write(json_str)

    定义网络,开始训练

    首先定义网络对象net,传入要分类的类别数为5,使用辅助分类器并初始化权重;在这里训练30轮,并使用train_bar = tqdm(train_loader, file=sys.stdout)来可视化训练进度条,loss计算采用了GoogleNet原论文的方法,进行加权计算,之后再进行反向传播和参数更新;同时,每一轮训练完成都要进行学习率更新;之后开始对验证集进行计算精确度,完成后保存模型。

    1. net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
    2. net.to(device)
    3. loss_function = nn.CrossEntropyLoss()
    4. optimizer = optim.Adam(net.parameters(), lr=0.0003)
    5. sculer = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
    6. epochs = 30
    7. best_acc = 0.0
    8. train_steps = len(train_loader)
    9. for epoch in range(epochs):
    10. # train
    11. net.train()
    12. running_loss = 0.0
    13. train_bar = tqdm(train_loader, file=sys.stdout)
    14. for step, data in enumerate(train_bar):
    15. imgs, labels = data
    16. optimizer.zero_grad()
    17. logits, aux_logits2, aux_logits1 = net(imgs.to(device))
    18. loss0 = loss_function(logits, labels.to(device))
    19. loss1 = loss_function(aux_logits1, labels.to(device))
    20. loss2 = loss_function(aux_logits2, labels.to(device))
    21. loss = loss0 + loss1 * 0.3 + loss2 * 0.3
    22. loss.backward()
    23. optimizer.step()
    24. # print statistics
    25. running_loss += loss.item()
    26. train_bar.desc = f"train epoch[{epoch+1}/{epochs}] loss:{loss:.3f}"
    27. sculer.step()
    28. # validate
    29. net.eval()
    30. acc = 0.0 # accumulate accurate number / epoch
    31. with torch.no_grad():
    32. val_bar = tqdm(validate_loader, file=sys.stdout)
    33. for val_data in val_bar:
    34. val_imgs, val_labels = val_data
    35. outputs = net(val_imgs.to(device)) # eval model only have last output layer
    36. predict_y = torch.max(outputs, dim=1)[1]
    37. acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
    38. val_accurate = acc / val_num
    39. print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
    40. (epoch + 1, running_loss / train_steps, val_accurate))
    41. if val_accurate > best_acc:
    42. best_acc = val_accurate
    43. torch.save(net,"./googleNet.pth")
    44. print('Finished Training')

    最后对代码进行整理,完整的train.py如下

    1. import os
    2. import sys
    3. import json
    4. import torch
    5. import torch.nn as nn
    6. from torchvision import transforms, datasets
    7. from torch.utils.data import DataLoader
    8. import torch.optim as optim
    9. from tqdm import tqdm
    10. from model import GoogLeNet
    11. def main():
    12. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    13. print(f"using {device} device.")
    14. data_transform = {
    15. "train": transforms.Compose([transforms.RandomResizedCrop(224),
    16. transforms.RandomHorizontalFlip(),
    17. transforms.ToTensor(),
    18. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    19. "val": transforms.Compose([transforms.Resize((224, 224)),
    20. transforms.ToTensor(),
    21. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
    22. # 获取数据集路径
    23. image_path = os.path.join(os.getcwd(), "data_set", "flower_data") # flower data set path
    24. assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    25. # 加载数据集,准备读取
    26. train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])
    27. validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"), transform=data_transform["val"])
    28. nw = min([os.cpu_count(), 32 if 32 > 1 else 0, 8]) # number of workers
    29. print(f'Using {nw} dataloader workers every process')
    30. # 加载数据集
    31. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=nw)
    32. validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=32, shuffle=False, num_workers=nw)
    33. train_num = len(train_dataset)
    34. val_num = len(validate_dataset)
    35. print(f"using {train_num} images for training, {val_num} images for validation.")
    36. # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4} 雏菊 蒲公英 玫瑰 向日葵 郁金香
    37. # 从训练集中获取类别标签到索引的映射关系,存储在flower_list变量
    38. flower_list = train_dataset.class_to_idx
    39. # 使用列表推导式将flower_list中的键值对反转,得到一个新的字典cla_dict
    40. cla_dict = dict((val, key) for key, val in flower_list.items())
    41. # write dict into json file,将cla_dict转换为JSON格式的字符串
    42. json_str = json.dumps(cla_dict, indent=4)
    43. with open('class_indices.json', 'w') as json_file:
    44. json_file.write(json_str)
    45. """如果要使用官方的预训练权重,注意是将权重载入官方的模型,不是我们自己实现的模型
    46. 官方的模型中使用了bn层以及改了一些参数,不能混用
    47. import torchvision
    48. net = torchvision.models.googlenet(num_classes=5)
    49. model_dict = net.state_dict()
    50. # 预训练权重下载地址: https://download.pytorch.org/models/googlenet-1378be20.pth
    51. pretrain_model = torch.load("googlenet.pth")
    52. del_list = ["aux1.fc2.weight", "aux1.fc2.bias",
    53. "aux2.fc2.weight", "aux2.fc2.bias",
    54. "fc.weight", "fc.bias"]
    55. pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}
    56. model_dict.update(pretrain_dict)
    57. net.load_state_dict(model_dict)"""
    58. net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
    59. net.to(device)
    60. loss_function = nn.CrossEntropyLoss()
    61. optimizer = optim.Adam(net.parameters(), lr=0.0003)
    62. sculer = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
    63. epochs = 30
    64. best_acc = 0.0
    65. train_steps = len(train_loader)
    66. for epoch in range(epochs):
    67. # train
    68. net.train()
    69. running_loss = 0.0
    70. train_bar = tqdm(train_loader, file=sys.stdout)
    71. for step, data in enumerate(train_bar):
    72. imgs, labels = data
    73. optimizer.zero_grad()
    74. logits, aux_logits2, aux_logits1 = net(imgs.to(device))
    75. loss0 = loss_function(logits, labels.to(device))
    76. loss1 = loss_function(aux_logits1, labels.to(device))
    77. loss2 = loss_function(aux_logits2, labels.to(device))
    78. loss = loss0 + loss1 * 0.3 + loss2 * 0.3
    79. loss.backward()
    80. optimizer.step()
    81. # print statistics
    82. running_loss += loss.item()
    83. train_bar.desc = f"train epoch[{epoch+1}/{epochs}] loss:{loss:.3f}"
    84. sculer.step()
    85. # validate
    86. net.eval()
    87. acc = 0.0 # accumulate accurate number / epoch
    88. with torch.no_grad():
    89. val_bar = tqdm(validate_loader, file=sys.stdout)
    90. for val_data in val_bar:
    91. val_imgs, val_labels = val_data
    92. outputs = net(val_imgs.to(device)) # eval model only have last output layer
    93. predict_y = torch.max(outputs, dim=1)[1]
    94. acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
    95. val_accurate = acc / val_num
    96. print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
    97. (epoch + 1, running_loss / train_steps, val_accurate))
    98. if val_accurate > best_acc:
    99. best_acc = val_accurate
    100. torch.save(net,"./googleNet.pth")
    101. print('Finished Training')
    102. if __name__ == '__main__':
    103. main()

    四.模型预测

    新建一个predict.py文件用于预测,将输入图像处理后转换成张量格式,img = torch.unsqueeze(img, dim=0)是在输入图像张量 img 的第一个维度上增加一个大小为1的维度,因此将图像张量的形状从 [通道数, 高度, 宽度 ] 转换为 [1, 通道数, 高度, 宽度]。然后加载模型进行预测,并打印出结果,同时可视化。

    1. import os
    2. import json
    3. import torch
    4. from PIL import Image
    5. from torchvision import transforms
    6. import matplotlib.pyplot as plt
    7. from model import GoogLeNet
    8. def main():
    9. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    10. data_transform = transforms.Compose(
    11. [transforms.Resize((224, 224)),
    12. transforms.ToTensor(),
    13. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    14. # load image
    15. img = Image.open("./2678588376_6ca64a4a54_n.jpg")
    16. plt.imshow(img)
    17. # [N, C, H, W]
    18. img = data_transform(img)
    19. # expand batch dimension
    20. img = torch.unsqueeze(img, dim=0)
    21. # read class_indict
    22. with open("./class_indices.json", "r") as f:
    23. class_indict = json.load(f)
    24. # create model
    25. model = GoogLeNet(num_classes=5, aux_logits=False).to(device)
    26. model=torch.load("/home/lm/GoogleNet/googleNet.pth")
    27. model.eval()
    28. with torch.no_grad():
    29. # predict class
    30. output = torch.squeeze(model(img.to(device))).cpu()
    31. predict = torch.softmax(output, dim=0)
    32. predict_class = torch.argmax(predict).numpy()
    33. print_result = f"class: {class_indict[str(predict_class)]} prob: {predict[predict_class].numpy():.3}"
    34. plt.title(print_result)
    35. for i in range(len(predict)):
    36. print(f"class: {class_indict[str(i)]:10} prob: {predict[i].numpy():.3}")
    37. plt.show()
    38. if __name__ == '__main__':
    39. main()

    预测结果

    五.模型可视化

    将生成的pth文件导入netron工具,可视化结果为

    发现很不清晰,因此将它转换成多用于嵌入式设备部署的onnx格式

    编写onnx.py

    1. import torch
    2. import torchvision
    3. from model import GoogLeNet
    4. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    5. model = GoogLeNet(num_classes=5, aux_logits=False).to(device)
    6. model=torch.load("/home/lm/GoogleNet/googleNet.pth")
    7. model.eval()
    8. example = torch.ones(1, 3, 244, 244)
    9. example = example.to(device)
    10. torch.onnx.export(model, example, "googleNet.onnx", verbose=True, opset_version=11)

     将生成的onnx文件导入,这样的可视化清晰了许多

    六.模型改进

    发现去掉学习率更新会提高准确率(从70%提升到83%),因此把train.py里面对应部分删掉。

    还有其他方法会在之后进行补充。

    源码地址:链接: https://pan.baidu.com/s/1FGcGwrNAZZSEocPORD3bZg 提取码: xsfn 复制这段内容后打开百度网盘手机App,操作更方便哦

  • 相关阅读:
    开源数字基础设施 项目 -- Speckle
    《吐血整理》高级系列教程-吃透Fiddler抓包教程(29)-Fiddler如何抓取Android7.0以上的Https包-终篇
    Android 固定WIFI热点路由IP
    国庆day3---网络编程知识点脑图整合
    既然有HTTP协议,为什么还要有RPC
    Auto.js中的悬浮窗
    如何减少电气设备漏电问题,其解决方案有哪些?
    消息队列-RabbitMQ-消息确认机制
    Spring Cloud Consul
    数据可视化(箱线图、直方图、散点图、联合分布图)
  • 原文地址:https://blog.csdn.net/qq_46454669/article/details/133960577