• 第P9周-YOLOv5Backbone模块


    CSP Bottleneck块和C3 类的设计使其非常适合目标检测任务,充分考虑了多尺度特征融合、梯度流动和计算效率等因素。C3 类以及CSP(Cross Stage Partial) Bottleneck块作为YOLOv5中的一部分,相对于传统的普通神经网络,具有以下优势:

    1. 特定任务定制:C3 类和CSP Bottleneck块是专门为目标检测任务设计的。它们在特征提取和融合方面进行了特定的优化,有助于提高目标检测性能。传统的神经网络结构通常用于通用的图像分类任务,不一定适用于目标检测。

    2. 多尺度特征融合:CSP Bottleneck块采用跨阶段局部(Cross Stage Partial)结构:输入特征经两个1×1卷积分成两条分支,一条经过若干个Bottleneck,另一条直接旁路,最后在通道维度上拼接并融合,使不同深度(不同感受野)的特征信息得到有效融合。这对于目标检测非常关键,因为目标可能具有不同尺寸和比例,需要在不同尺度下进行检测。

    3. 快捷连接:CSP Bottleneck块允许引入快捷连接,这对于信息流动和梯度传播至关重要。传统的神经网络通常没有这种结构,而YOLOv5中的C3块通过快捷连接促进了梯度的传递,减轻了梯度消失问题,有助于训练的稳定性。

    4. 高效性能:尽管CSP Bottleneck块具有更复杂的结构,但它们经过有效的设计,不会引入过多的计算复杂性。这使得YOLOv5能够在相对较短的时间内完成目标检测任务。

    一、前期工作

    1.1 导入数据集

    1. import torch
    2. import torch.nn as nn
    3. import torchvision.transforms as transforms
    4. import torchvision
    5. from torchvision import transforms, datasets
    6. import os,PIL,pathlib,warnings
    warnings.filterwarnings("ignore")  # ignore warning messages
    # Select the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    import os, PIL, random, pathlib

    # Dataset root: one sub-directory per class (ImageFolder layout).
    # Kept identical to the directory that datasets.ImageFolder loads in 1.2,
    # so the printed class names describe the data actually used.
    data_dir = pathlib.Path('D:/P8/weather_photos/')
    data_paths = list(data_dir.glob('*'))
    # Path.name is the last path component, independent of the OS path
    # separator and of how many levels deep data_dir is. Splitting str(path)
    # on "\\" and taking index 2 only worked on Windows for a 2-level root.
    classeNames = [path.name for path in data_paths]
    print(classeNames)

     1.2  标准化处理

    # More on transforms.Compose: https://blog.csdn.net/qq_38251616/article/details/124878863
    # Per-channel normalization constants. These are the widely used ImageNet
    # channel statistics.
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]

    train_transforms = transforms.Compose([
        transforms.Resize([224, 224]),  # resize every image to one fixed size
        # transforms.RandomHorizontalFlip(),  # optional random horizontal flip
        transforms.ToTensor(),  # PIL Image / ndarray -> float tensor scaled to [0, 1]
        # Standardize each channel: (x - mean) / std, which helps convergence.
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    test_transform = transforms.Compose([
        transforms.Resize([224, 224]),  # resize every image to one fixed size
        transforms.ToTensor(),  # PIL Image / ndarray -> float tensor scaled to [0, 1]
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    # NOTE: the whole dataset is built with train_transforms, so the test split
    # produced by random_split also uses it and test_transform is currently
    # unused. Both pipelines are identical while the flip stays commented out.
    total_data = datasets.ImageFolder("D:/P8/weather_photos/", transform=train_transforms)
    print(total_data)

    1.3 划分数据集

    n_total = len(total_data)
    train_size = int(0.8 * n_total)  # 80% of the samples go to training
    test_size = n_total - train_size  # the remainder is the test set
    # Random, non-overlapping split of the dataset into the two subsets.
    train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
    print(train_dataset, test_dataset)

     

    1.4 设置数据加载器

    batch_size = 4
    # Training loader: reshuffled every epoch, one background worker process.
    train_dl = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=1)
    # Test loader: same settings as in the complete program of section 4.
    test_dl = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=1)

     二、搭建包含CSP Bottleneck块和C3 类的YOLOv5的主干网络

    1. import torch.nn.functional as F
    def autopad(k, p=None):  # kernel, padding
        """Return the convolution padding for kernel size ``k``.

        An explicitly given ``p`` is returned unchanged. Otherwise the padding
        is ``k // 2`` per dimension, which keeps the spatial size ('same') for
        odd kernels at stride 1. ``k`` may be an int or a sequence of ints; a
        sequence yields a list.
        """
        if p is not None:
            return p
        if isinstance(k, int):
            return k // 2
        return [size // 2 for size in k]  # auto-pad each dimension
    class Conv(nn.Module):
        """Standard convolution unit: Conv2d -> BatchNorm2d -> activation."""

        def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
            super().__init__()
            padding = autopad(k, p)
            # No conv bias: the BatchNorm that follows has its own learnable shift.
            self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, bias=False)
            self.bn = nn.BatchNorm2d(c2)
            # act=True -> SiLU; an nn.Module instance -> used as given; anything else -> no activation.
            if act is True:
                self.act = nn.SiLU()
            elif isinstance(act, nn.Module):
                self.act = act
            else:
                self.act = nn.Identity()

        def forward(self, x):
            y = self.conv(x)
            y = self.bn(y)
            return self.act(y)
    class Bottleneck(nn.Module):
        """Standard bottleneck: 1x1 Conv -> 3x3 Conv, with an optional residual add."""

        def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
            super().__init__()
            hidden = int(c2 * e)  # hidden channels
            self.cv1 = Conv(c1, hidden, 1, 1)
            self.cv2 = Conv(hidden, c2, 3, 1, g=g)
            # Adding x requires matching channel counts, so the shortcut is used only when c1 == c2.
            self.add = shortcut and c1 == c2

        def forward(self, x):
            out = self.cv2(self.cv1(x))
            if self.add:
                return x + out
            return out
    class C3(nn.Module):
        """CSP bottleneck with 3 convolutions.

        The input feeds two parallel 1x1 convs. The output of cv1 runs through
        a stack of ``n`` Bottlenecks, and cv2 acts as a bypass branch. The two
        results are concatenated along dim 1 and fused by the 1x1 conv cv3.
        """

        def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
            super().__init__()
            hidden = int(c2 * e)  # hidden channels
            self.cv1 = Conv(c1, hidden, 1, 1)
            self.cv2 = Conv(c1, hidden, 1, 1)
            self.cv3 = Conv(2 * hidden, c2, 1)  # fuses both branches back to c2 channels
            blocks = [Bottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)]
            self.m = nn.Sequential(*blocks)

        def forward(self, x):
            main_path = self.m(self.cv1(x))
            bypass = self.cv2(x)
            return self.cv3(torch.cat((main_path, bypass), dim=1))
    class SPPF(nn.Module):
        """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher.

        One k x k max-pool (stride 1, padding k // 2) is applied three times in
        a row, each time to the previous result. With k=5 this is equivalent
        to SPP(k=(5, 9, 13)). The reduced input and the three pooled maps are
        concatenated along dim 1 and fused by a 1x1 conv.
        """

        def __init__(self, c1, c2, k=5):
            super().__init__()
            hidden = c1 // 2  # hidden channels
            self.cv1 = Conv(c1, hidden, 1, 1)
            self.cv2 = Conv(hidden * 4, c2, 1, 1)
            self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

        def forward(self, x):
            x = self.cv1(x)
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
                pooled = [x]
                for _ in range(3):
                    pooled.append(self.m(pooled[-1]))
                return self.cv2(torch.cat(pooled, 1))
    52. """
    53. 这个是YOLOv5, 6.0版本的主干网络,这里进行复现
    54. (注:有部分删改,详细讲解将在后续进行展开)
    55. """
    56. class YOLOv5_backbone(nn.Module):
    57. def __init__(self):
    58. super(YOLOv5_backbone, self).__init__()
    59. self.Conv_1 = Conv(3, 64, 3, 2, 2)
    60. self.Conv_2 = Conv(64, 128, 3, 2)
    61. self.C3_3 = C3(128,128)
    62. self.Conv_4 = Conv(128, 256, 3, 2)
    63. self.C3_5 = C3(256,256)
    64. self.Conv_6 = Conv(256, 512, 3, 2)
    65. self.C3_7 = C3(512,512)
    66. self.Conv_8 = Conv(512, 1024, 3, 2)
    67. self.C3_9 = C3(1024, 1024)
    68. self.SPPF = SPPF(1024, 1024, 5)
    69. # 全连接网络层,用于分类
    70. self.classifier = nn.Sequential(
    71. nn.Linear(in_features=65536, out_features=100),
    72. nn.ReLU(),
    73. nn.Linear(in_features=100, out_features=4)
    74. )
    75. def forward(self, x):
    76. x = self.Conv_1(x)
    77. x = self.Conv_2(x)
    78. x = self.C3_3(x)
    79. x = self.Conv_4(x)
    80. x = self.C3_5(x)
    81. x = self.Conv_6(x)
    82. x = self.C3_7(x)
    83. x = self.Conv_8(x)
    84. x = self.C3_9(x)
    85. x = self.SPPF(x)
    86. x = torch.flatten(x, start_dim=1)
    87. x = self.classifier(x)
    88. return x
    # Pick the compute device as a plain string ("cuda" or "cpu").
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    print("Using {} device".format(device))
    # Build the backbone and move its parameters to the chosen device.
    model = YOLOv5_backbone().to(device)
    print(model)

    autopad(k, p) 函数:

    • 这个函数用于计算卷积操作中的自动填充大小。
    • 如果没有提供填充参数 p,则它会根据卷积核大小 k 自动计算填充大小(k // 2),以实现“same”填充:在步幅为1且卷积核大小为奇数时,输入和输出特征图具有相同的空间维度。

    Conv 类:

    • 这个类定义了标准的卷积层,包括卷积、批量归一化和激活函数等。
    • 构造函数接受参数 c1(输入通道数)、c2(输出通道数)、k(卷积核大小)、s(步幅)、p(填充)、g(分组卷积)、act(激活函数类型)等。
    • forward 方法执行卷积、批量归一化和激活函数操作,并返回输出。

    Bottleneck 类:

    • 这个类定义了标准的瓶颈块,用于深层网络中的特征提取。
    • 构造函数接受参数 c1(输入通道数)、c2(输出通道数)、shortcut(是否包含快捷连接)、g(分组卷积的组数)、e(隐藏层通道数的扩展系数)。其中 g 会作为 groups 传给卷积操作,通常为1,表示标准卷积;在深度可分离卷积等操作中,该值可能不为1。内部 Conv 层的 bias 设置为False,表示卷积操作不使用偏置项(其后的BatchNorm层本身带有可学习的偏移参数)。
    • forward 方法依次执行卷积核为1x1和3x3的卷积;当 shortcut 为 True 且输入输出通道数相同时,再将结果与输入相加(快捷连接)。

    C3 类:

    • 这个类定义了CSP(Cross Stage Partial) Bottleneck块,它包含3个卷积操作。
    • forward 方法先用两个1x1卷积(cv1、cv2)将输入分成两条分支,cv1 分支再经过 n 个 Bottleneck 块(其中可含快捷连接),然后将两条分支在通道维度上拼接,最后经第三个1x1卷积(cv3)融合输出。

    SPPF 类:

    SPPF层是指"Spatial Pyramid Pooling - Fast"层,它是YOLOv5中的一种特殊层,用于多尺度特征融合。SPPF层的作用是对输入特征图进行空间金字塔池化,以捕获不同尺度的特征信息,从而提高目标检测性能。

    self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    nn.MaxPool2d 是PyTorch中的最大池化层,用于进行池化操作。kernel_size 参数设定了池化核的大小,这里使用了 k 作为参数,表示池化核的大小为 k×k。stride 参数设置为1,表示池化操作的滑动步幅为1。padding 参数设置为 k // 2,表示对输入特征图进行填充,使(k 为奇数时)输出特征图的尺寸与输入相同。

    nn.MaxPool2d 层被用于执行空间金字塔池化。在原始的SPP中,kernel_size 的不同取值对应于不同尺度的池化区域;而在本文的SPPF中 k=5,同一个5×5最大池化层被串联地连续执行3次,每次都作用于上一次的输出,其效果分别等价于SPP中 k 取 5、9、13 的3个并行池化。这样能以更少的计算量在不同尺度上捕获特征信息,从而提高目标检测性能。

    • 这个类定义了SPPF(Spatial Pyramid Pooling - Fast)层,用于多尺度特征融合。
    • forward 方法执行空间金字塔池化操作,将不同尺度的特征图拼接在一起。

    YOLOv5_backbone 类:

    • 定义主干网络的各个层(Conv_1、Conv_2、C3_3、Conv_4、C3_5、Conv_6、C3_7、Conv_8、C3_9、SPPF),用于逐步提取输入图像的特征。
    • classifier 是一个全连接网络层,用于分类任务。
    • forward 方法定义了整个主干网络的前向传播过程,包括各个层次的卷积和特征融合。

    小结:        

    在上述代码中,forward 函数在不同类中的作用如下:

    1. forward 函数在 Conv 类中的作用:

      • Conv 类代表标准的卷积层,其 forward 函数用于执行卷积操作。
      • 输入 x 是卷积层的输入特征图。
      • 通过卷积操作、批量归一化(Batch Normalization)、激活函数等一系列操作,将输入特征图 x 转化为经过卷积层处理后的输出。
    2. forward 函数在 Bottleneck 类中的作用:

      • Bottleneck 类代表标准的瓶颈块,其 forward 函数用于执行瓶颈块的前向传播。
      • 输入 x 是瓶颈块的输入特征图。
      • 瓶颈块包括一系列卷积层和残差连接(仅当 shortcut 为 True 且输入输出通道数相同时启用),将输入特征图 x 转化为经过瓶颈块处理后的输出。
    3. forward 函数在 C3 类中的作用:

      • C3 类代表一种特殊的 CSP(Cross-Stage-Partial)瓶颈块,包括三个卷积操作。
      • 输入 x 是 C3 瓶颈块的输入特征图。
      • forward 函数将输入 x 分别送入两个不同的1×1卷积 self.cv1 和 self.cv2:self.cv1 的输出再经过一系列堆叠的 Bottleneck 块(self.m)处理,self.cv2 的输出则作为旁路分支。
      • 最后,将这两条分支的输出在通道维度上进行拼接(Concatenate),并传递给第三个卷积操作 self.cv3,生成经过 C3 块处理后的输出。
    4. forward 函数在 SPPF 类中的作用:

      • SPPF 类代表空间金字塔池化(Spatial Pyramid Pooling - Fast)层。
      • 输入 x 是SPPF层的输入特征图。
      • forward 函数先用 self.cv1 降低通道数,然后将特征图连续3次通过同一个最大池化层 self.m(每次作用于上一次的输出,相当于逐步扩大的池化尺度),再把池化前的特征图与3次池化的结果进行拼接,经 self.cv2 生成SPPF层的输出特征。
    5. forward 函数在 YOLOv5_backbone 类中的作用:

      • YOLOv5_backbone 类代表整个YOLOv5的主干网络,包括多个卷积层和瓶颈块,以及SPPF层。
      • forward 函数按照网络的顺序将输入特征图 x 通过各个组件,包括卷积层、瓶颈块、SPPF层,然后通过全连接网络层 self.classifier 进行分类。
      • 返回 self.classifier 输出的各类别得分(logits,形状为 [batch_size, 4],对应4个类别),用于本文的图像分类任务。

     运行结果:

    1. YOLOv5_backbone(
    2. (Conv_1): Conv(
    3. (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(2, 2), bias=False)
    4. (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    5. (act): SiLU()
    6. )
    7. (Conv_2): Conv(
    8. (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    9. (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    10. (act): SiLU()
    11. )
    12. (C3_3): C3(
    13. (cv1): Conv(
    14. (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    15. (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    16. (act): SiLU()
    17. )
    18. (cv2): Conv(
    19. (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    20. (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    21. (act): SiLU()
    22. )
    23. (cv3): Conv(
    24. (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    25. (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    26. (act): SiLU()
    27. )
    28. (m): Sequential(
    29. (0): Bottleneck(
    30. (cv1): Conv(
    31. (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    32. (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    33. (act): SiLU()
    34. )
    35. (cv2): Conv(
    36. (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    37. (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    38. (act): SiLU()
    39. )
    40. )
    41. )
    42. )
    43. (Conv_4): Conv(
    44. (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    45. (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    46. (act): SiLU()
    47. )
    48. (C3_5): C3(
    49. (cv1): Conv(
    50. (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    51. (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    52. (act): SiLU()
    53. )
    54. (cv2): Conv(
    55. (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    56. (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    57. (act): SiLU()
    58. )
    59. (cv3): Conv(
    60. (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    61. (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    62. (act): SiLU()
    63. )
    64. (m): Sequential(
    65. (0): Bottleneck(
    66. (cv1): Conv(
    67. (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    68. (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    69. (act): SiLU()
    70. )
    71. (cv2): Conv(
    72. (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    73. (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    74. (act): SiLU()
    75. )
    76. )
    77. )
    78. )
    79. (Conv_6): Conv(
    80. (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    81. (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    82. (act): SiLU()
    83. )
    84. (C3_7): C3(
    85. (cv1): Conv(
    86. (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    87. (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    88. (act): SiLU()
    89. )
    90. (cv2): Conv(
    91. (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    92. (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    93. (act): SiLU()
    94. )
    95. (cv3): Conv(
    96. (conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    97. (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    98. (act): SiLU()
    99. )
    100. (m): Sequential(
    101. (0): Bottleneck(
    102. (cv1): Conv(
    103. (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    104. (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    105. (act): SiLU()
    106. )
    107. (cv2): Conv(
    108. (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    109. (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    110. (act): SiLU()
    111. )
    112. )
    113. )
    114. )
    115. (Conv_8): Conv(
    116. (conv): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    117. (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    118. (act): SiLU()
    119. )
    120. (C3_9): C3(
    121. (cv1): Conv(
    122. (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    123. (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    124. (act): SiLU()
    125. )
    126. (cv2): Conv(
    127. (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    128. (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    129. (act): SiLU()
    130. )
    131. (cv3): Conv(
    132. (conv): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
    133. (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    134. (act): SiLU()
    135. )
    136. (m): Sequential(
    137. (0): Bottleneck(
    138. (cv1): Conv(
    139. (conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    140. (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    141. (act): SiLU()
    142. )
    143. (cv2): Conv(
    144. (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    145. (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    146. (act): SiLU()
    147. )
    148. )
    149. )
    150. )
    151. (SPPF): SPPF(
    152. (cv1): Conv(
    153. (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    154. (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    155. (act): SiLU()
    156. )
    157. (cv2): Conv(
    158. (conv): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
    159. (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    160. (act): SiLU()
    161. )
    162. (m): MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1, ceil_mode=False)
    163. )
    164. (classifier): Sequential(
    165. (0): Linear(in_features=65536, out_features=100, bias=True)
    166. (1): ReLU()
    167. (2): Linear(in_features=100, out_features=4, bias=True)
    168. )
    169. )

     查看模型详情:

    # Print per-layer output shapes and parameter counts for a 3x224x224 input.
    import torchsummary as summary
    summary.summary(model, input_size=(3, 224, 224))
    1. ----------------------------------------------------------------
    2. Layer (type) Output Shape Param #
    3. ================================================================
    4. Conv2d-1 [-1, 64, 113, 113] 1,728
    5. BatchNorm2d-2 [-1, 64, 113, 113] 128
    6. SiLU-3 [-1, 64, 113, 113] 0
    7. Conv-4 [-1, 64, 113, 113] 0
    8. Conv2d-5 [-1, 128, 57, 57] 73,728
    9. BatchNorm2d-6 [-1, 128, 57, 57] 256
    10. SiLU-7 [-1, 128, 57, 57] 0
    11. Conv-8 [-1, 128, 57, 57] 0
    12. Conv2d-9 [-1, 64, 57, 57] 8,192
    13. BatchNorm2d-10 [-1, 64, 57, 57] 128
    14. SiLU-11 [-1, 64, 57, 57] 0
    15. Conv-12 [-1, 64, 57, 57] 0
    16. Conv2d-13 [-1, 64, 57, 57] 4,096
    17. BatchNorm2d-14 [-1, 64, 57, 57] 128
    18. SiLU-15 [-1, 64, 57, 57] 0
    19. Conv-16 [-1, 64, 57, 57] 0
    20. Conv2d-17 [-1, 64, 57, 57] 36,864
    21. BatchNorm2d-18 [-1, 64, 57, 57] 128
    22. SiLU-19 [-1, 64, 57, 57] 0
    23. Conv-20 [-1, 64, 57, 57] 0
    24. Bottleneck-21 [-1, 64, 57, 57] 0
    25. Conv2d-22 [-1, 64, 57, 57] 8,192
    26. BatchNorm2d-23 [-1, 64, 57, 57] 128
    27. SiLU-24 [-1, 64, 57, 57] 0
    28. Conv-25 [-1, 64, 57, 57] 0
    29. Conv2d-26 [-1, 128, 57, 57] 16,384
    30. BatchNorm2d-27 [-1, 128, 57, 57] 256
    31. SiLU-28 [-1, 128, 57, 57] 0
    32. Conv-29 [-1, 128, 57, 57] 0
    33. C3-30 [-1, 128, 57, 57] 0
    34. Conv2d-31 [-1, 256, 29, 29] 294,912
    35. BatchNorm2d-32 [-1, 256, 29, 29] 512
    36. SiLU-33 [-1, 256, 29, 29] 0
    37. Conv-34 [-1, 256, 29, 29] 0
    38. Conv2d-35 [-1, 128, 29, 29] 32,768
    39. BatchNorm2d-36 [-1, 128, 29, 29] 256
    40. SiLU-37 [-1, 128, 29, 29] 0
    41. Conv-38 [-1, 128, 29, 29] 0
    42. Conv2d-39 [-1, 128, 29, 29] 16,384
    43. BatchNorm2d-40 [-1, 128, 29, 29] 256
    44. SiLU-41 [-1, 128, 29, 29] 0
    45. Conv-42 [-1, 128, 29, 29] 0
    46. Conv2d-43 [-1, 128, 29, 29] 147,456
    47. BatchNorm2d-44 [-1, 128, 29, 29] 256
    48. SiLU-45 [-1, 128, 29, 29] 0
    49. Conv-46 [-1, 128, 29, 29] 0
    50. Bottleneck-47 [-1, 128, 29, 29] 0
    51. Conv2d-48 [-1, 128, 29, 29] 32,768
    52. BatchNorm2d-49 [-1, 128, 29, 29] 256
    53. SiLU-50 [-1, 128, 29, 29] 0
    54. Conv-51 [-1, 128, 29, 29] 0
    55. Conv2d-52 [-1, 256, 29, 29] 65,536
    56. BatchNorm2d-53 [-1, 256, 29, 29] 512
    57. SiLU-54 [-1, 256, 29, 29] 0
    58. Conv-55 [-1, 256, 29, 29] 0
    59. C3-56 [-1, 256, 29, 29] 0
    60. Conv2d-57 [-1, 512, 15, 15] 1,179,648
    61. BatchNorm2d-58 [-1, 512, 15, 15] 1,024
    62. SiLU-59 [-1, 512, 15, 15] 0
    63. Conv-60 [-1, 512, 15, 15] 0
    64. Conv2d-61 [-1, 256, 15, 15] 131,072
    65. BatchNorm2d-62 [-1, 256, 15, 15] 512
    66. SiLU-63 [-1, 256, 15, 15] 0
    67. Conv-64 [-1, 256, 15, 15] 0
    68. Conv2d-65 [-1, 256, 15, 15] 65,536
    69. BatchNorm2d-66 [-1, 256, 15, 15] 512
    70. SiLU-67 [-1, 256, 15, 15] 0
    71. Conv-68 [-1, 256, 15, 15] 0
    72. Conv2d-69 [-1, 256, 15, 15] 589,824
    73. BatchNorm2d-70 [-1, 256, 15, 15] 512
    74. SiLU-71 [-1, 256, 15, 15] 0
    75. Conv-72 [-1, 256, 15, 15] 0
    76. Bottleneck-73 [-1, 256, 15, 15] 0
    77. Conv2d-74 [-1, 256, 15, 15] 131,072
    78. BatchNorm2d-75 [-1, 256, 15, 15] 512
    79. SiLU-76 [-1, 256, 15, 15] 0
    80. Conv-77 [-1, 256, 15, 15] 0
    81. Conv2d-78 [-1, 512, 15, 15] 262,144
    82. BatchNorm2d-79 [-1, 512, 15, 15] 1,024
    83. SiLU-80 [-1, 512, 15, 15] 0
    84. Conv-81 [-1, 512, 15, 15] 0
    85. C3-82 [-1, 512, 15, 15] 0
    86. Conv2d-83 [-1, 1024, 8, 8] 4,718,592
    87. BatchNorm2d-84 [-1, 1024, 8, 8] 2,048
    88. SiLU-85 [-1, 1024, 8, 8] 0
    89. Conv-86 [-1, 1024, 8, 8] 0
    90. Conv2d-87 [-1, 512, 8, 8] 524,288
    91. BatchNorm2d-88 [-1, 512, 8, 8] 1,024
    92. SiLU-89 [-1, 512, 8, 8] 0
    93. Conv-90 [-1, 512, 8, 8] 0
    94. Conv2d-91 [-1, 512, 8, 8] 262,144
    95. BatchNorm2d-92 [-1, 512, 8, 8] 1,024
    96. SiLU-93 [-1, 512, 8, 8] 0
    97. Conv-94 [-1, 512, 8, 8] 0
    98. Conv2d-95 [-1, 512, 8, 8] 2,359,296
    99. BatchNorm2d-96 [-1, 512, 8, 8] 1,024
    100. SiLU-97 [-1, 512, 8, 8] 0
    101. Conv-98 [-1, 512, 8, 8] 0
    102. Bottleneck-99 [-1, 512, 8, 8] 0
    103. Conv2d-100 [-1, 512, 8, 8] 524,288
    104. BatchNorm2d-101 [-1, 512, 8, 8] 1,024
    105. SiLU-102 [-1, 512, 8, 8] 0
    106. Conv-103 [-1, 512, 8, 8] 0
    107. Conv2d-104 [-1, 1024, 8, 8] 1,048,576
    108. BatchNorm2d-105 [-1, 1024, 8, 8] 2,048
    109. SiLU-106 [-1, 1024, 8, 8] 0
    110. Conv-107 [-1, 1024, 8, 8] 0
    111. C3-108 [-1, 1024, 8, 8] 0
    112. Conv2d-109 [-1, 512, 8, 8] 524,288
    113. BatchNorm2d-110 [-1, 512, 8, 8] 1,024
    114. SiLU-111 [-1, 512, 8, 8] 0
    115. Conv-112 [-1, 512, 8, 8] 0
    116. MaxPool2d-113 [-1, 512, 8, 8] 0
    117. MaxPool2d-114 [-1, 512, 8, 8] 0
    118. MaxPool2d-115 [-1, 512, 8, 8] 0
    119. Conv2d-116 [-1, 1024, 8, 8] 2,097,152
    120. BatchNorm2d-117 [-1, 1024, 8, 8] 2,048
    121. SiLU-118 [-1, 1024, 8, 8] 0
    122. Conv-119 [-1, 1024, 8, 8] 0
    123. SPPF-120 [-1, 1024, 8, 8] 0
    124. Linear-121 [-1, 100] 6,553,700
    125. ReLU-122 [-1, 100] 0
    126. Linear-123 [-1, 4] 404
    127. ================================================================
    128. Total params: 21,729,592
    129. Trainable params: 21,729,592
    130. Non-trainable params: 0
    131. ----------------------------------------------------------------
    132. Input size (MB): 0.57
    133. Forward/backward pass size (MB): 137.59
    134. Params size (MB): 82.89
    135. Estimated Total Size (MB): 221.06
    136. ----------------------------------------------------------------
    137. None

    三、训练函数

    3.1 编写训练函数

    # Training loop
    def train(dataloader, model, loss_fn, optimizer):
        """Run one training epoch; return (accuracy, mean per-batch loss).

        Accuracy is the count of correct argmax predictions divided by the
        dataset size. Loss is the sum of batch losses divided by the number of
        batches. Tensors are moved to the module-level ``device``.
        """
        n_samples = len(dataloader.dataset)  # size of the training set
        n_batches = len(dataloader)  # number of batches (size / batch_size, rounded up)
        correct, loss_sum = 0, 0
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass and loss against the ground-truth labels.
            logits = model(inputs)
            loss = loss_fn(logits, labels)

            # Backpropagate and update the weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate accuracy and loss statistics.
            correct += (logits.argmax(1) == labels).type(torch.float).sum().item()
            loss_sum += loss.item()
        return correct / n_samples, loss_sum / n_batches

    3.2 编写测试函数

    def test(dataloader, model, loss_fn):
        """Evaluate ``model`` without gradients; return (accuracy, mean per-batch loss).

        Tensors are moved to the module-level ``device``. The caller is
        responsible for putting the model in eval mode.
        """
        n_samples = len(dataloader.dataset)  # size of the test set
        n_batches = len(dataloader)  # number of batches (size / batch_size, rounded up)
        correct, loss_sum = 0, 0
        # No weights are updated here, so skip autograd bookkeeping to save memory.
        with torch.no_grad():
            for imgs, target in dataloader:
                imgs = imgs.to(device)
                target = target.to(device)
                logits = model(imgs)
                loss_sum += loss_fn(logits, target).item()
                correct += (logits.argmax(1) == target).type(torch.float).sum().item()
        return correct / n_samples, loss_sum / n_batches

    3.3 正式训练

    import copy

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = nn.CrossEntropyLoss()  # classification loss
    epochs = 60

    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []

    best_acc = 0  # best test accuracy so far; selects which model gets saved
    # Start from a snapshot of the initial model so best_model is always
    # defined. Otherwise torch.save below raises NameError if the test accuracy
    # never rises above 0.
    best_model = copy.deepcopy(model)

    for epoch in range(epochs):
        model.train()
        epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)

        model.eval()
        epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

        # Keep a copy of the model with the best test accuracy.
        if epoch_test_acc > best_acc:
            best_acc = epoch_test_acc
            best_model = copy.deepcopy(model)

        train_acc.append(epoch_train_acc)
        train_loss.append(epoch_train_loss)
        test_acc.append(epoch_test_acc)
        test_loss.append(epoch_test_loss)

        # Current learning rate, read directly from the optimizer's param groups.
        lr = optimizer.param_groups[0]['lr']

        template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
        print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,
                              epoch_test_acc*100, epoch_test_loss, lr))

    # Save the best model's parameters to a file.
    PATH = './best_model.pth'  # output file for the parameters
    torch.save(best_model.state_dict(), PATH)
    print('Done')

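    运行结果(每个 epoch 前重复出现的两行 cpu,很可能是因为在 Windows 上 num_workers=1 的 DataLoader 会以 spawn 方式为 train_dl 和 test_dl 各启动一个工作进程,工作进程重新导入脚本时再次执行了顶层的 print(device)):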

    1. cpu
    2. cpu
    3. Epoch: 1, Train_acc:53.0%, Train_loss:1.110, Test_acc:64.4%, Test_loss:0.690, Lr:1.00E-04
    4. cpu
    5. cpu
    6. Epoch: 2, Train_acc:60.8%, Train_loss:0.851, Test_acc:60.4%, Test_loss:0.783, Lr:1.00E-04
    7. cpu
    8. cpu
    9. Epoch: 3, Train_acc:68.3%, Train_loss:0.723, Test_acc:74.7%, Test_loss:0.628, Lr:1.00E-04
    10. cpu
    11. cpu
    12. Epoch: 4, Train_acc:73.4%, Train_loss:0.634, Test_acc:75.6%, Test_loss:0.455, Lr:1.00E-04
    13. cpu
    14. cpu
    15. Epoch: 5, Train_acc:74.4%, Train_loss:0.598, Test_acc:76.0%, Test_loss:0.554, Lr:1.00E-04
    16. cpu
    17. cpu
    18. Epoch: 6, Train_acc:76.8%, Train_loss:0.578, Test_acc:81.3%, Test_loss:0.403, Lr:1.00E-04
    19. cpu
    20. cpu
    21. Epoch: 7, Train_acc:80.4%, Train_loss:0.480, Test_acc:83.6%, Test_loss:0.359, Lr:1.00E-04
    22. cpu
    23. cpu
    24. Epoch: 8, Train_acc:82.4%, Train_loss:0.450, Test_acc:82.7%, Test_loss:0.423, Lr:1.00E-04
    25. cpu
    26. cpu
    27. Epoch: 9, Train_acc:82.6%, Train_loss:0.403, Test_acc:89.3%, Test_loss:0.275, Lr:1.00E-04
    28. cpu
    29. cpu
    30. Epoch:10, Train_acc:86.8%, Train_loss:0.345, Test_acc:87.6%, Test_loss:0.373, Lr:1.00E-04
    31. cpu
    32. cpu
    33. Epoch:11, Train_acc:87.1%, Train_loss:0.319, Test_acc:91.6%, Test_loss:0.240, Lr:1.00E-04
    34. cpu
    35. cpu
    36. Epoch:12, Train_acc:88.3%, Train_loss:0.296, Test_acc:84.0%, Test_loss:0.408, Lr:1.00E-04
    37. cpu
    38. cpu
    39. Epoch:13, Train_acc:88.6%, Train_loss:0.284, Test_acc:77.8%, Test_loss:0.519, Lr:1.00E-04
    40. cpu
    41. cpu
    42. Epoch:14, Train_acc:91.9%, Train_loss:0.242, Test_acc:89.8%, Test_loss:0.246, Lr:1.00E-04
    43. cpu
    44. cpu
    45. Epoch:15, Train_acc:93.0%, Train_loss:0.193, Test_acc:89.8%, Test_loss:0.316, Lr:1.00E-04
    46. cpu
    47. cpu
    48. Epoch:16, Train_acc:92.2%, Train_loss:0.201, Test_acc:85.3%, Test_loss:0.451, Lr:1.00E-04
    49. cpu
    50. cpu
    51. Epoch:17, Train_acc:90.6%, Train_loss:0.229, Test_acc:86.7%, Test_loss:0.535, Lr:1.00E-04
    52. cpu
    53. cpu
    54. Epoch:18, Train_acc:92.3%, Train_loss:0.198, Test_acc:80.4%, Test_loss:0.586, Lr:1.00E-04
    55. cpu
    56. cpu
    57. Epoch:19, Train_acc:92.6%, Train_loss:0.196, Test_acc:90.2%, Test_loss:0.251, Lr:1.00E-04
    58. cpu
    59. cpu
    60. Epoch:20, Train_acc:94.8%, Train_loss:0.155, Test_acc:90.7%, Test_loss:0.223, Lr:1.00E-04
    61. cpu
    62. cpu
    63. Epoch:21, Train_acc:95.2%, Train_loss:0.132, Test_acc:90.7%, Test_loss:0.282, Lr:1.00E-04
    64. cpu
    65. cpu
    66. Epoch:22, Train_acc:95.9%, Train_loss:0.121, Test_acc:79.6%, Test_loss:0.744, Lr:1.00E-04
    67. cpu
    68. cpu
    69. Epoch:23, Train_acc:96.8%, Train_loss:0.102, Test_acc:92.9%, Test_loss:0.183, Lr:1.00E-04
    70. cpu
    71. cpu
    72. Epoch:24, Train_acc:97.1%, Train_loss:0.083, Test_acc:87.6%, Test_loss:0.380, Lr:1.00E-04
    73. cpu
    74. cpu
    75. Epoch:25, Train_acc:95.8%, Train_loss:0.133, Test_acc:88.9%, Test_loss:0.350, Lr:1.00E-04
    76. cpu
    77. cpu
    78. Epoch:26, Train_acc:97.4%, Train_loss:0.090, Test_acc:91.1%, Test_loss:0.378, Lr:1.00E-04
    79. cpu
    80. cpu
    81. Epoch:27, Train_acc:96.1%, Train_loss:0.118, Test_acc:88.9%, Test_loss:0.420, Lr:1.00E-04
    82. cpu
    83. cpu
    84. Epoch:28, Train_acc:97.7%, Train_loss:0.075, Test_acc:88.4%, Test_loss:0.343, Lr:1.00E-04
    85. cpu
    86. cpu
    87. Epoch:29, Train_acc:97.1%, Train_loss:0.073, Test_acc:89.8%, Test_loss:0.308, Lr:1.00E-04
    88. cpu
    89. cpu
    90. Epoch:30, Train_acc:98.6%, Train_loss:0.048, Test_acc:90.7%, Test_loss:0.283, Lr:1.00E-04
    91. cpu
    92. cpu
    93. Epoch:31, Train_acc:98.4%, Train_loss:0.056, Test_acc:89.8%, Test_loss:0.340, Lr:1.00E-04
    94. cpu
    95. cpu
    96. Epoch:32, Train_acc:97.6%, Train_loss:0.077, Test_acc:90.7%, Test_loss:0.278, Lr:1.00E-04
    97. cpu
    98. cpu
    99. Epoch:33, Train_acc:98.3%, Train_loss:0.051, Test_acc:82.7%, Test_loss:0.511, Lr:1.00E-04
    100. cpu
    101. cpu
    102. Epoch:34, Train_acc:97.6%, Train_loss:0.069, Test_acc:91.6%, Test_loss:0.416, Lr:1.00E-04
    103. cpu
    104. cpu
    105. Epoch:35, Train_acc:97.1%, Train_loss:0.089, Test_acc:89.8%, Test_loss:0.352, Lr:1.00E-04
    106. cpu
    107. cpu
    108. Epoch:36, Train_acc:97.0%, Train_loss:0.078, Test_acc:90.2%, Test_loss:0.337, Lr:1.00E-04
    109. cpu
    110. cpu
    111. Epoch:37, Train_acc:98.2%, Train_loss:0.060, Test_acc:80.0%, Test_loss:0.723, Lr:1.00E-04
    112. cpu
    113. cpu
    114. Epoch:38, Train_acc:98.0%, Train_loss:0.060, Test_acc:91.6%, Test_loss:0.285, Lr:1.00E-04
    115. cpu
    116. cpu
    117. Epoch:39, Train_acc:99.1%, Train_loss:0.024, Test_acc:92.0%, Test_loss:0.376, Lr:1.00E-04
    118. cpu
    119. cpu
    120. Epoch:40, Train_acc:99.2%, Train_loss:0.032, Test_acc:80.0%, Test_loss:0.669, Lr:1.00E-04
    121. cpu
    122. cpu

    四、完整代码

    1. import torch
    2. import torch.nn as nn
    3. import torchvision.transforms as transforms
    4. import torchvision
    5. from torchvision import transforms, datasets
    6. import os,PIL,pathlib,warnings
    warnings.filterwarnings("ignore")  # ignore warning messages
    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(device)
    10. import os,PIL,random,pathlib
    def list_class_names(data_dir):
        """Return the names of the class sub-directories of *data_dir*, sorted.

        Uses ``Path.name`` instead of splitting ``str(path)`` on backslashes, so
        the result depends neither on the OS path separator nor on how deep
        *data_dir* is. Sorting matches the order in which
        ``datasets.ImageFolder`` assigns class indices.
        """
        root = pathlib.Path(data_dir)
        return sorted(p.name for p in root.iterdir() if p.is_dir())


    def autopad(k, p=None):  # kernel, padding
        """Return the padding that keeps a stride-1 convolution 'same'-sized.

        An explicit *p* is returned unchanged; otherwise ``k // 2`` is used
        (element-wise when *k* is a sequence of kernel sizes).
        """
        if p is None:
            p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
        return p


    class Conv(nn.Module):
        """Standard convolution: Conv2d (no bias) -> BatchNorm2d -> activation.

        Args: c1/c2 input/output channels, k kernel size, s stride, p padding
        (``autopad(k)`` when None), g groups. act: True selects SiLU, an
        ``nn.Module`` instance is used as-is, anything else means identity.
        """

        def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
            super().__init__()
            self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
            self.bn = nn.BatchNorm2d(c2)
            self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

        def forward(self, x):
            return self.act(self.bn(self.conv(x)))


    class Bottleneck(nn.Module):
        """Standard bottleneck: 1x1 Conv -> 3x3 Conv, with an optional residual add.

        Args: c1/c2 input/output channels, shortcut enables the residual
        connection (only applied when c1 == c2), g groups of the 3x3 conv,
        e expansion ratio of the hidden channels (c_ = int(c2 * e)).
        """

        def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):
            super().__init__()
            c_ = int(c2 * e)  # hidden channels
            self.cv1 = Conv(c1, c_, 1, 1)
            self.cv2 = Conv(c_, c2, 3, 1, g=g)
            self.add = shortcut and c1 == c2

        def forward(self, x):
            return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


    class C3(nn.Module):
        """CSP bottleneck with 3 convolutions.

        The input is projected by two parallel 1x1 convs; one branch goes
        through *n* Bottlenecks, the two branches are concatenated on the
        channel axis and fused by a final 1x1 conv (cv3).
        """

        def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
            super().__init__()
            c_ = int(c2 * e)  # hidden channels
            self.cv1 = Conv(c1, c_, 1, 1)
            self.cv2 = Conv(c1, c_, 1, 1)
            self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
            self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

        def forward(self, x):
            return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))


    class SPPF(nn.Module):
        """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher.

        Applies the same stride-1 max-pool three times in a row and
        concatenates the input with the three pooled maps; with k=5 this is
        equivalent to SPP(k=(5, 9, 13)). Spatial size is preserved.
        """

        def __init__(self, c1, c2, k=5):
            super().__init__()
            c_ = c1 // 2  # hidden channels
            self.cv1 = Conv(c1, c_, 1, 1)
            self.cv2 = Conv(c_ * 4, c2, 1, 1)
            self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

        def forward(self, x):
            x = self.cv1(x)
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
                y1 = self.m(x)
                y2 = self.m(y1)
                return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))


    class YOLOv5_backbone(nn.Module):
        """Reproduction of the YOLOv5 v6.0 backbone (with some modifications)
        followed by a small fully connected classification head.

        Args:
            num_classes: number of output logits; defaults to 4, the value the
                original code hard-coded.

        ``in_features=65536`` of the head assumes 3x224x224 input: the five
        stride-2 convolutions map 224 -> 113 -> 57 -> 29 -> 15 -> 8 (Conv_1
        uses an explicit padding of 2 instead of autopad's 1), C3/SPPF keep the
        spatial size, so the flattened feature has 1024 * 8 * 8 = 65536 values.
        """

        def __init__(self, num_classes=4):
            super().__init__()
            self.Conv_1 = Conv(3, 64, 3, 2, 2)  # NOTE(review): upstream v6.0 uses k=6, p=2 here
            self.Conv_2 = Conv(64, 128, 3, 2)
            self.C3_3 = C3(128, 128)
            self.Conv_4 = Conv(128, 256, 3, 2)
            self.C3_5 = C3(256, 256)
            self.Conv_6 = Conv(256, 512, 3, 2)
            self.C3_7 = C3(512, 512)
            self.Conv_8 = Conv(512, 1024, 3, 2)
            self.C3_9 = C3(1024, 1024)
            self.SPPF = SPPF(1024, 1024, 5)
            # Fully connected head used for classification.
            self.classifier = nn.Sequential(
                nn.Linear(in_features=65536, out_features=100),
                nn.ReLU(),
                nn.Linear(in_features=100, out_features=num_classes),
            )

        def forward(self, x):
            x = self.Conv_1(x)
            x = self.Conv_2(x)
            x = self.C3_3(x)
            x = self.Conv_4(x)
            x = self.C3_5(x)
            x = self.Conv_6(x)
            x = self.C3_7(x)
            x = self.Conv_8(x)
            x = self.C3_9(x)
            x = self.SPPF(x)
            x = torch.flatten(x, start_dim=1)
            x = self.classifier(x)
            return x


    def train(dataloader, model, loss_fn, optimizer, device=None):
        """Run one training epoch over *dataloader*.

        Returns ``(accuracy, mean_loss)``: accuracy is the number of correct
        argmax predictions divided by the dataset size; mean_loss is the loss
        averaged over batches. Batches are moved to *device*, which defaults
        to the device of the model's first parameter.
        """
        if device is None:
            device = next(model.parameters()).device
        size = len(dataloader.dataset)  # number of training samples
        num_batches = len(dataloader)  # number of batches
        train_loss, train_acc = 0, 0
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            # Forward pass and loss between prediction and target.
            pred = model(X)
            loss = loss_fn(pred, y)
            # Backward pass and parameter update.
            optimizer.zero_grad()  # clear gradients from the previous step
            loss.backward()
            optimizer.step()
            # Accumulate accuracy and loss.
            train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
            train_loss += loss.item()
        train_acc /= size
        train_loss /= num_batches
        return train_acc, train_loss


    def evaluate(dataloader, model, loss_fn, device=None):
        """Evaluate *model* on *dataloader* without tracking gradients.

        Returns ``(accuracy, mean_loss)`` computed the same way as ``train``.
        *device* defaults to the device of the model's first parameter.
        """
        if device is None:
            device = next(model.parameters()).device
        size = len(dataloader.dataset)  # number of test samples
        num_batches = len(dataloader)  # number of batches
        test_loss, test_acc = 0, 0
        # No training here: disable autograd to save memory and compute.
        with torch.no_grad():
            for imgs, target in dataloader:
                imgs, target = imgs.to(device), target.to(device)
                target_pred = model(imgs)
                loss = loss_fn(target_pred, target)
                test_loss += loss.item()
                test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()
        test_acc /= size
        test_loss /= num_batches
        return test_acc, test_loss


    def main(data_dir='D:/P8/weather_photos/'):
        """Train YOLOv5_backbone as an image classifier on an ImageFolder dataset.

        *data_dir* must contain one sub-directory per class. 80% of the images
        (random split) are used for training, the rest for testing. The
        weights of the epoch with the highest test accuracy are written to
        ./best_model.pth.
        """
        import copy

        import torchsummary as summary  # third-party; only used for the model summary

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        data_dir = pathlib.Path(data_dir)
        class_names = list_class_names(data_dir)
        print(class_names)

        # More on transforms.Compose: https://blog.csdn.net/qq_38251616/article/details/124878863
        train_transforms = transforms.Compose([
            transforms.Resize([224, 224]),  # resize every image to the same size
            # transforms.RandomHorizontalFlip(),  # random horizontal flip
            transforms.ToTensor(),  # PIL Image / ndarray -> tensor scaled to [0, 1]
            transforms.Normalize(  # per-channel standardization, helps convergence
                mean=[0.485, 0.456, 0.406],  # commonly used ImageNet channel statistics
                std=[0.229, 0.224, 0.225]),
        ])
        test_transform = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
        ])

        # Two views of the same folder so that the training subset goes
        # through train_transforms and the test subset through test_transform
        # (random_split on one dataset would apply train_transforms to both).
        total_data = datasets.ImageFolder(str(data_dir), transform=train_transforms)
        test_view = datasets.ImageFolder(str(data_dir), transform=test_transform)
        print(total_data)

        train_size = int(0.8 * len(total_data))
        perm = torch.randperm(len(total_data)).tolist()
        train_dataset = torch.utils.data.Subset(total_data, perm[:train_size])
        test_dataset = torch.utils.data.Subset(test_view, perm[train_size:])
        print(len(train_dataset), len(test_dataset))

        batch_size = 4
        train_dl = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=1)
        test_dl = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,  # order is irrelevant for evaluation
                                              num_workers=1)

        print("Using {} device".format(device))
        model = YOLOv5_backbone(num_classes=len(total_data.classes)).to(device)
        print(model)

        # Parameter count and per-layer output shapes. torchsummary defaults to
        # device="cuda", so pass the real device type to also work on CPU; it
        # prints the summary itself.
        summary.summary(model, (3, 224, 224), device=device.type)

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        loss_fn = nn.CrossEntropyLoss()
        epochs = 60

        train_loss = []
        train_acc = []
        test_loss = []
        test_acc = []
        best_acc = 0  # best test accuracy so far; decides which model is kept
        best_model = None
        for epoch in range(epochs):
            model.train()
            epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
            model.eval()
            epoch_test_acc, epoch_test_loss = evaluate(test_dl, model, loss_fn)

            # Keep an independent snapshot of the best model; a plain
            # reference would keep changing as training continues.
            if epoch_test_acc > best_acc:
                best_acc = epoch_test_acc
                best_model = copy.deepcopy(model)

            train_acc.append(epoch_train_acc)
            train_loss.append(epoch_train_loss)
            test_acc.append(epoch_test_acc)
            test_loss.append(epoch_test_loss)

            lr = optimizer.state_dict()['param_groups'][0]['lr']  # current learning rate
            template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
            print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
                                  epoch_test_acc * 100, epoch_test_loss, lr))

        # Save the best weights; if no epoch beat 0% test accuracy, best_model
        # was never set, so fall back to the final model instead of crashing.
        PATH = './best_model.pth'
        torch.save((model if best_model is None else best_model).state_dict(), PATH)
        print('Done')


    if __name__ == '__main__':
        main()

    torch.save(model.state_dict(), PATH) 与 torch.save(model, PATH) 之间有很大的区别:

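    1. torch.save(model.state_dict(), PATH) 保存的是模型的参数字典(state_dict),而不是整个模型对象。这意味着只有模型的权重和偏置等参数(以及 BatchNorm 的 running_mean/running_var 等缓冲区)会被保存,而模型的结构、图层和其他属性不会被保存。这种方式通常用于保存和加载模型的参数,而不包括模型的结构:加载时需要先用代码构建出相同结构的模型,再调用 model.load_state_dict(torch.load(PATH))。这也是 PyTorch 官方推荐的保存方式。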

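    2. torch.save(model, PATH) 保存的是整个模型对象(通过 Python 的 pickle 序列化),包括模型的图层、参数和其他属性。需要注意的是,pickle 只记录模型类的引用(模块路径和类名),并不保存类的源代码,因此用 torch.load 加载时必须能导入相同的模型类定义;一旦类名或代码目录结构发生变化,加载就可能失败。另外,在较新版本的 PyTorch 中 torch.load 默认 weights_only=True,加载整个模型对象时需要显式传入 weights_only=False(只应对可信文件这样做)。这种方式使用起来更方便,但可移植性不如只保存 state_dict。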

    copy.deepcopy(model) 

    copy.deepcopy(model): 这部分代码使用 Python 的 copy.deepcopy 函数创建了模型的深度拷贝。这意味着它会复制整个模型对象,包括模型的架构、权重、参数等。通常情况下,这用于创建一个独立的模型副本,以便进一步的处理或修改,而不会影响原始模型或其他引用。在上面的训练循环中,如果直接写 best_model = model,best_model 只是同一个模型对象的引用,会随后续训练继续被更新;只有使用深拷贝,才能真正保留测试准确率最高那一轮的权重。

  • 相关阅读:
    Springboot、Tomcat+skywalking 链路追踪、日志收集配置
    Go基础语法:指针和make和new
    交叉熵损失函数(CrossEntropy Loss)的原理理解
    VSCode自动分析代码的插件
    vscode 无法使用 compilerPath“D:.../bin/arm-none-eabi-g++.exe”解析配置。
    【Kettle实战】字符串处理及网络请求JSON格式处理
    ES6-扩展运算符“...“
    精通Spring Boot单元测试:构建健壮的Java应用
    实战演练 | 在 MySQL 中选择除了某一列以外的所有列
    第2关:节点删除与创建
  • 原文地址:https://blog.csdn.net/qq_60245590/article/details/133782541