• Pytorch之GoogLeNet图像分类


    • 💂 个人主页:风间琉璃
    • 🤟 版权: 本文由【风间琉璃】原创、在CSDN首发、需要转载请联系博主
    • 💬 如果文章对你有帮助、欢迎关注、点赞、收藏(一键三连)订阅专栏

    目录

    前言

    一、GoogLeNet网络结构

    1.Inception 结构

    (1)Inception v1 

    (2)Inception v2

    (3)Inception v3

    (4)Inception v4

    2.网络模型分析

    (1)输入层

    (2)第一个模块

    (2)第二个模块

    (3)第三个模块 Inception 3a

    (4)第四个模块 Inception 3b 

    (5)输出层 

    (6)辅助分类器

    3.网络创新点

    (1)引入Inception

    (2)1x1卷积核

    (3)辅助分类器

    (4)平均池化层

    二、GoogLeNet实现

    1.构建GoogLeNet网络

    2.加载数据集

    3.训练和测试模型

    三、实现图像分类


    前言

    2014 年,GoogLeNet 和 VGG 是当年 ImageNet 挑战赛 (ILSVRC14) 的双雄,GoogLeNet 获得了第一名、VGG 获得了第二名,这两类模型结构的共同特点是层次更深了。VGG 继承了 LeNet 以及 AlexNet 的一些框架结构,而 GoogLeNet 则做了更加大胆的网络结构尝试,虽然深度只有 22 层,但大小却比 AlexNet 和 VGG 小很多,GoogleNet 参数为 500 万个,AlexNet 参数个数是 GoogleNet 的 12 倍,VGGNet 参数又是 AlexNet 的 3 倍,因此在内存或计算资源有限时,GoogleNet 是比较好的选择;从模型结果来看,GoogLeNet 的性能却更加优越

    一、GoogLeNet网络结构

    GoogLeNet是google推出的基于Inception模块的深度神经网络模型,在2014年的ImageNet竞赛中夺得了冠军。

    一般来说,提升网络性能最直接的办法就是增加网络深度和宽度,深度指网络层次数量、宽度指神经元数量。但这种方式存在以下问题:

    (1)参数太多,如果训练数据集有限,很容易产生过拟合;
    (2)网络越大、参数越多,计算复杂度越大,难以应用;
    (3)网络越深,容易出现梯度弥散问题(梯度越往后穿越容易消失),难以优化模型。

    解决方法是在增加网络深度和宽度的同时减少参数。为了减少参数,一般将全连接变成稀疏连接。但是在实现上,全连接变成稀疏连接后实际计算量并不会有质的提升,因为大部分硬件是针对密集矩阵计算优化的,稀疏矩阵虽然数据量少,但是计算所消耗的时间却很难减少。

    那么如何既能保持网络结构的稀疏性,又能利用密集矩阵的高计算性能。大量的文献表明可以将稀疏矩阵聚类为较为密集的子矩阵来提高计算性能,就如人类的大脑是可以看做是神经元的重复堆积,因此,GoogLeNet 团队提出了 Inception 网络结构,就是构造一种 “基础神经元” 结构,来搭建一个稀疏性、高计算性能的网络结构。

    它的主要特点是网络不仅有深度,还在横向上具有“宽度”。由于图像信息在空间尺寸上的巨大差异,如何选择合适的卷积核大小来提取特征就显得比较困难了。空间分布范围更广的图像信息适合用较大的卷积核来提取其特征,而空间分布范围较小的图像信息则适合用较小的卷积核来提取其特征。 

    在随后的两年中一直在改进,形成了Inception V2、Inception V3、Inception V4等版本。

     GoogLeNet网络(22层)结构如下:

    1.Inception 结构

    (1)Inception v1 

    通过设计一个稀疏网络结构,但是能够产生稠密的数据,既能增加神经网络表现,又能保证计算资源的使用效率。谷歌提出了最原始 Inception 的基本结构:其主要思想是利用不同大小的卷积核实现不同尺度的感知最后进行融合,可以得到图像更好的表征。

    Inception Module基本组成结构有四个成分:1*1卷积,3*3卷积,5*5卷积,3*3最大池化

    该结构将 CNN 中常用的卷积(1x1,3x3,5x5)、池化操作(3x3)堆叠在一起(卷积、池化后的尺寸相同,将通道相加),一方面增加了网络的宽度,另一方面也增加了网络对尺度的适应性。
    网络卷积层中的网络能够提取输入的每一个细节信息,同时 5x5 的滤波器也能够覆盖大部分接受层的的输入。还可以进行一个池化操作,以减少空间大小,降低过度拟合。在这些层之上,在每一个卷积层后都要做一个 ReLU 操作,以增加网络的非线性特征。

    原始Inception结构存在很严重的问题:

    1. 所有的卷积层(1×1、3×3、5×5)都是直接和输入对接的,因此卷积过程的参数计算量很大;

    2.并行池化层的输出与输入维度相同,在和其他卷积层的输出做连接时,特征图的深度会变得很深,一样会增加很大的计算量。

    为了避免这种情况,在 3x3 前、5x5 前、max pooling 后分别加上了 1x1 的卷积核,以起到了降低特征图厚度的作用,这也就形成了 Inception v1 的网络结构,如下图所示:

     

    1x1 的卷积核作用:

    1x1 卷积的主要目的是为了减少维度,还用于修正线性激活(ReLU)

    假定上一层的特征图尺度为:224×224×128,经过256个5×5卷积核输出后,输出尺寸为:224×224×256,卷积层参数为:128×5×5×256

    如果上一层先通过一个具有32个尺寸为1×1的卷积核后,再经过256个5×5卷积核输出,输出特征图尺寸仍为:224×224×256,但此时卷积层参数量变为了:128×1×1×32+32×5×5×256,大约减少了4倍。

    这就是 Pointwise Convolution,即 1x1 卷积,简写为 PW,主要用于数据降维,减少参数量。当然也有使用 PW 做升维的,在 MobileNet v2 中就使用 PW 将 3 个特征图变成 6 个特征图,丰富输入数据的特征

    (2)Inception v2

    GoogLeNet 凭借其优秀的表现,得到了很多研究人员的学习和使用,因此 GoogLeNet 团队又对其进行了进一步地发掘改进,产生了升级版本的 GoogLeNet。

    但是谷歌团队发现如果一味的堆叠Inception模块虽然对准确率有所提升,但对计算机效率并没有很好提升,反之会有明显下降,因此如何在不增加过多计算量的同时提高网络的表达能力就成为了一个问题。

    Inception V2 版本的解决方案就是修改 Inception 的内部计算逻辑,提出了比较特殊的 “卷积” 计算结构

    1.卷积分解

    大尺寸的卷积核可以带来更大的感受野,但也意味着会产生更多的参数。因此,GoogLeNet 团队提出可以用 2 个连续的 3x3 卷积层组成的小网络来代替单个的 5x5 卷积层,即在保持感受野范围的同时又减少了参数量,如下图:

    并进一步考虑了n×1卷积核,来取代3×3卷积核 。

    任意 nxn 的卷积都可以通过 1xn 卷积后接 nx1 卷积来替代。GoogLeNet 团队发现在网络的前期使用这种分解效果并不好,在中度大小的特征图(feature map)上使用效果才会更好(特征图大小建议在 12 到 20 之间)。 

    Inception模块优化过程:

     2.降低特征图大小

    一般情况下,如果想让图像缩小,可以有如下两种方式:

    方法一(左图):先池化再作 Inception 卷积,或者先作 Inception 卷积再作池化。但是方法一先作 pooling(池化)会导致特征表示遇到瓶颈(特征缺失)。

    方法二(右图)是正常的缩小,但计算量很大。

    为了同时保持特征表示且降低计算量,将网络结构改为下图,使用两个并行化的模块来降低计算量(卷积、池化并行执行,再进行合并) 。

    以上所有的方式方法的融合就得到了Inception v2。

    (3)Inception v3

    Inception V3结构较V2并没有太多改进,主要有一下几点:

    • 对7×7卷积层分解为两个一维卷积(1×7,7×1),3x3也一样
    • 对损失函数添加正则项,避免在分类网络中,神经网络对某一类别具有高度拟合性;
    • 辅助分类器中也使用了BN。

    分解既可以加速计算,又可以将 1 个卷积拆成 2 个卷积,使得网络深度进一步增加,增加了网络的非线性(每增加一层都要进行 ReLU)。 

    (4)Inception v4

    Inception V4 研究了 Inception 模块与残差连接的结合。ResNet 结构大大地加深了网络深度,还极大地提升了训练速度,同时性能也有提升。
    Inception V4 主要利用残差连接(Residual Connection)来改进 V3 结构,得到 Inception-ResNet-v1,Inception-ResNet-v2,Inception-v4 网络。

    ResNet 的残差结构和Inception-ResNet如下所示:

    通过 20 个类似的模块组合,Inception-ResNet 构建如下:

    2.网络模型分析

    基于 Inception 构建了 GoogLeNet 的网络结构如下(共 22 层):主要由9个 I n c e p t i o n InceptionInception 块、全局平均汇聚层、辅助分类器构成。
     

    1. GoogLeNet 采用了模块化的结构(Inception 结构),方便增添和修改。


    2.网络最后采用 average pooling(平均池化)来代替全连接层,在最后还是加了一个全连接层,主要是为了方便对输出进行灵活调整。


    3.虽然移除了全连接,但是网络中依然使用了 Dropout。


    4.为了避免梯度消失,网络额外增加了 2 个辅助的 softmax 用于向前传导梯度(辅助分类器)。

    辅助分类器是将中间某一层的输出用作分类,并按一个较小的权重(0.3)加到最终分类结果中,这样相当于做了模型融合,同时给网络增加了反向传播的梯度信号,也提供了额外的正则化,对于整个网络的训练很有裨益。而在实际测试的时候,这两个额外的 softmax 会被去掉。 

    GoogLeNet 的网络结构图细节如下: 

    列名
    type网络名称
    patch size/stride网络参数,卷积核大小/stride
    output size输出特征矩阵的大小
    depth对应该行结构的数量,如第三行卷积层,depth=2,表示经过两层卷积层,先是1x1,然后3x3
    后8列关于Inception结构的配置

    上表中的 “#3x3 reduce”,“#5x5 reduce” 表示在 3x3,5x5 卷积操作之前使用了 1x1 卷积的数量。"pool proj"表示在池化层后使用1x1卷积的数量。

    (1)输入层

    原始输入图像为 224x224x3,且都进行了零均值化的预处理操作(图像每个像素减去均值)。

    (2)第一个模块

    处理流程:卷积-->ReLU-->池化

    卷积层:卷积核大小7*7,步长为2,padding为3,输出通道数64,输出特征图尺寸为(224-7+3*2)/2+1=112.5(向下取整)=112,输出特征图维度为112x112x64,卷积后进行ReLU操作。

    池化层:窗口大小3*3,步长为2,输出特征图尺寸为((112 -3)/2)+1=55.5(向上取整)=56,输出特征图维度为56x56x64。

    (2)第二个模块

    处理流程:卷积-->卷积-->ReLU-->池化

    卷积层:先用64个1x1的卷积核(3x3卷积核之前的降维)将输入的特征图(56x56x64)变为56x56x64,然后进行ReLU操作。
    再用卷积核大小3*3,步长为1,padding为1,输出通道数192,进行卷积运算,输出特征图尺寸为(56-3+1*2)/1+1=56,输出特征图维度为56x56x192,然后进行ReLU操作。

    池化层: 窗口大小3*3,步长为2,输出通道数192,输出为((56 - 3)/2)+1=27.5(向上取整)=28,输出特征图维度为28x28x192。



    (3)第三个模块 Inception 3a


    Inception 3a层分为四个分支,采用不同尺度的卷积核来进行处理。


    (1)64 个 1x1 的卷积核,然后 RuLU,输出 28x28x64
    (2)96 个 1x1 的卷积核,作为 3x3 卷积核之前的降维,变成 28x28x96,然后进行 ReLU 计算,再进行 128 个 3x3 的卷积(padding 为 1),输出 28x28x128
    (3)16 个 1x1 的卷积核,作为 5x5 卷积核之前的降维,变成 28x28x16,进行 ReLU 计算后,再进行 32 个 5x5 的卷积(padding 为 2),输出 28x28x32
    (4)pool 层,使用 3x3 的核(padding 为 1),输出 28x28x192,然后进行 32 个 1x1 的卷积,输出 28x28x32。
    将四个结果进行连接,对这四部分输出结果的第三维并联,即 64+128+32+32=256,最终输出 28x28x256


    (4)第四个模块 Inception 3b 


    (1)128 个 1x1 的卷积核,然后 RuLU,输出 28x28x128
    (2)128 个 1x1 的卷积核,作为 3x3 卷积核之前的降维,变成 28x28x128,进行 ReLU,再进行 192 个 3x3 的卷积(padding 为 1),输出 28x28x192
    (3)32 个 1x1 的卷积核,作为 5x5 卷积核之前的降维,变成 28x28x32,进行 ReLU 计算后,再进行 96 个 5x5 的卷积(padding 为 2),输出 28x28x96
    (4)pool 层,使用 3x3 的核(padding 为 1),输出 28x28x256,然后进行 64 个 1x1 的卷积,输出 28x28x64。
    将四个结果进行连接,对这四部分输出结果的第三维并联,即 128+192+96+64=480,最终输出输出为 28x28x480

    第四层(4a,4b,4c,4d,4e)、第五层(5a,5b)……,与 3a、3b 类似,在此就不再重复。

    (5)输出层 

    在输出层GoogLeNet与AlexNet、VGG采用3个连续的全连接层不同,GoogLeNet采用的是全局平均池化层,得到的是高和宽均为1的卷积层,然后添加丢弃概率为40%的Dropout,输出层激活函数采用的是softmax。 

    (6)辅助分类器

    根据实验数据,发现神经网络的中间层也具有很强的识别能力,为了利用中间层抽象的特征,在某些中间层中添加含有多层的分类器

    如下图所示,红色边框内部代表添加的辅助分类器。GoogLeNet中共增加了两个辅助的softmax分支,作用有两点,一是为了避免梯度消失,用于向前传导梯度。反向传播时如果有一层求导为0,链式求导结果则为0二是将中间某一层输出用作分类,起到模型融合作用。最后的loss=loss_2 + 0.3 * loss_1 + 0.3 * loss_0。实际测试时,这两个辅助softmax分支会被去掉。

    3.网络创新点

    (1)引入Inception

    引入Inception结构,融合不同尺度的特征信息,能得到更好的特征表征。更意味着提高准确率,不一定需要堆叠更深的层或者增加神经元个数等,可以转向研究更稀疏但是更精密的结构同样可以达到很好的效果。

    (2)1x1卷积核

    使用1x1的卷积核进行降维以及映射处理。

    (3)辅助分类器

    添加两个辅助分类器帮助训练,在 GoogLeNet(Inception 网络)中,辅助分类器(Auxiliary Classifier)是一种用于训练过程中的辅助分类器,它有助于解决深度神经网络中的梯度消失问题(vanishing gradient problem)并加速训练。辅助分类器的作用如下:

    1. 缓解梯度消失问题:深度神经网络通常有很多层,而反向传播中的梯度在深度网络中可能会逐渐变得非常小,导致训练变得困难。辅助分类器通过在网络中间添加一个额外的分类器,可以提供额外的梯度信号,帮助在训练过程中传播梯度,从而缓解梯度消失问题。

    2. 正则化:辅助分类器可以看作是一种正则化技术。它强制网络中间的特征图具有一定的分类能力,因为这些特征图需要用于中间的分类任务。这有助于网络学习更具有区分性的特征。

    3. 多尺度特征:辅助分类器通常在网络的中间层添加,这使得它可以从中间层获取多尺度的特征表示。这些多尺度的特征可以对不同尺度的对象进行分类,有助于提高模型的分类性能。

    4. 减少过拟合:辅助分类器引入了额外的分类任务,可以视为一种正则化方法,有助于减少过拟合的风险,尤其是在训练数据较少的情况下。

    需要注意的是,辅助分类器通常在训练过程中使用,而在推断(inference)阶段时通常不使用它们。在推断阶段,主要的分类器负责最终的分类任务。在训练过程中,辅助分类器的预测结果与主分类器的结果一起被用于计算损失函数,以帮助网络更好地训练。 

    (4)平均池化层

    丢弃全连接层,使用平均池化层(大大减少模型参数)

    二、GoogLeNet实现

    1.构建GoogLeNet网络

    由于GoogLeNet网络中有大量的重复模块,我们可以将重复的模块单独定义,方便堆叠模块。

    首先是卷积层模块,一般处理流程:卷积-->ReLU

    1. # 卷积层基础模块:卷积 + ReLU
    2. class BasicConv2d(nn.Module):
    3. def __init__(self, in_channels, out_channels, **kwargs):
    4. super(BasicConv2d, self).__init__()
    5. self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **kwargs)
    6. self.relu = nn.ReLU(inplace=True)
    7. def forward(self, x):
    8. x = self.conv(x)
    9. x = self.relu(x)
    10. return x

    然后就是GoogLeNet的核心模块Inception模块,主要依据网络结构图搭建该模块,一个输入一个输出,中间含有4条分支,然后在维度上进行拼接,

    1. # Inception模块
    2. class Inception(nn.Module):
    3. def __init__(self, in_channels, ch1x1, ch3x3reduce, ch3x3, ch5x5reduce, ch5x5, pool_proj):
    4. super(Inception, self).__init__()
    5. # 分支1:1x1卷积
    6. self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
    7. # 分支2:1x1卷积 + 3x3卷积
    8. self.branch2 = nn.Sequential(
    9. BasicConv2d(in_channels, ch3x3reduce, kernel_size=1),
    10. BasicConv2d(ch3x3reduce, ch3x3, kernel_size=3, padding=1) # 保证输出大小等于输入大小
    11. )
    12. # 分支3:1x1卷积 + 5x5卷积
    13. self.branch3 = nn.Sequential(
    14. BasicConv2d(in_channels, ch5x5reduce, kernel_size=1),
    15. BasicConv2d(ch5x5reduce, ch5x5, kernel_size=5, padding=2) # 保证输出大小等于输入大小
    16. )
    17. # 分支4:池化 + 3x3卷积
    18. self.branch4 = nn.Sequential(
    19. nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
    20. BasicConv2d(in_channels, pool_proj, kernel_size=1)
    21. )
    22. def forward(self, x):
    23. branch1 = self.branch1(x)
    24. branch2 = self.branch2(x)
    25. branch3 = self.branch3(x)
    26. branch4 = self.branch4(x)
    27. outputs = [branch1, branch2, branch3, branch4]
    28. return torch.cat(outputs, 1) # 拼接

     最后还有两个辅助分类器,其输入层分别为4a,4d Inception模块的输出。

    1. # 辅助分类器
    2. class InceptionAux(nn.Module):
    3. def __init__(self, in_channels, num_classes):
    4. super(InceptionAux, self).__init__()
    5. self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
    6. self.conv = BasicConv2d(in_channels, 128, kernel_size=1) # output[batch, 128, 4, 4]
    7. self.fc1 = nn.Linear(2048, 1024)
    8. self.fc2 = nn.Linear(1024, num_classes)
    9. def forward(self, x):
    10. # 辅助分类器1:Nx512x14x14 辅助分类器2:Nx528x14x14
    11. x = self.averagePool(x)
    12. # 辅助分类器1:Nx512x4x4 辅助分类器2:Nx528x4x4
    13. x = self.conv(x)
    14. # Nx128x4x4
    15. x = torch.flatten(x, 1)
    16. x = F.dropout(x, p=0.5, training=self.training) # 训练模型:self.training=True, 测试模型:self.training=False
    17. # Nx2048
    18. x = F.relu(self.fc1(x), inplace=True)
    19. x = F.dropout(x, p=0.5, training=self.training)
    20. # Nx1024
    21. x = self.fc2(x)
    22. # N x num_classes
    23. return x

     根据以上模块搭建GoogLeNet网络模型,其中有些参数需要根据以下的表格获取。

    1. # GoogLeNet网络
    2. class GoogLeNet(nn.Module):
    3. def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
    4. super(GoogLeNet, self).__init__()
    5. self.aux_logits = aux_logits
    6. self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
    7. self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
    8. # 这里无nn.LocalResponseNorm(),可自行添加
    9. self.conv2 = BasicConv2d(64, 64, kernel_size=1)
    10. self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
    11. self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) # ceil_mode:向上取整
    12. # 查表可得inception的配置参数
    13. self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
    14. self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
    15. self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
    16. self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
    17. self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
    18. self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
    19. self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
    20. self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
    21. self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
    22. self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
    23. self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
    24. # 是否使用辅助分类器
    25. if self.aux_logits:
    26. self.aux1 = InceptionAux(512, num_classes)
    27. self.aux2 = InceptionAux(528, num_classes)
    28. self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    29. self.dropout = nn.Dropout(0.4)
    30. self.fc = nn.Linear(1024, num_classes)
    31. if init_weights:
    32. self._initialize_weights()
    33. def forward(self, x):
    34. # N x 3 x 224 x 224
    35. x = self.conv1(x)
    36. # N x 64 x 112 x 112
    37. x = self.maxpool1(x)
    38. # N x 64 x 56 x 56
    39. x = self.conv2(x)
    40. # N x 64 x 56 x 56
    41. x = self.conv3(x)
    42. # N x 192 x 56 x 56
    43. x = self.maxpool2(x)
    44. # N x 192 x 28 x 28
    45. x = self.inception3a(x)
    46. # N x 256 x 28 x 28
    47. x = self.inception3b(x)
    48. # N x 480 x 28 x 28
    49. x = self.maxpool3(x)
    50. # N x 480 x 14 x 14
    51. x = self.inception4a(x)
    52. # N x 512 x 14 x 14
    53. # 训练模型开启辅助分类器1,测试时不使用
    54. if self.training and self.aux_logits: # eval model lose this layer
    55. aux1 = self.aux1(x)
    56. x = self.inception4b(x)
    57. # N x 512 x 14 x 14
    58. x = self.inception4c(x)
    59. # N x 512 x 14 x 14
    60. x = self.inception4d(x)
    61. # N x 528 x 14 x 14
    62. # 训练模型开启辅助分类器2,测试时不使用
    63. if self.training and self.aux_logits: # eval model lose this layer
    64. aux2 = self.aux2(x)
    65. x = self.inception4e(x)
    66. # N x 832 x 14 x 14
    67. x = self.maxpool4(x)
    68. # N x 832 x 7 x 7
    69. x = self.inception5a(x)
    70. # N x 832 x 7 x 7
    71. x = self.inception5b(x)
    72. # N x 1024 x 7 x 7
    73. x = self.avgpool(x)
    74. # N x 1024 x 1 x 1
    75. x = torch.flatten(x, 1)
    76. # N x 1024
    77. x = self.dropout(x)
    78. x = self.fc(x)
    79. # N x 1000 (num_classes)
    80. # 训练模型返回三个值,加权作为最终结果,测试时不使用
    81. if self.training and self.aux_logits: # eval model lose this layer
    82. return x, aux2, aux1
    83. return x
    84. def _initialize_weights(self):
    85. for m in self.modules():
    86. if isinstance(m, nn.Conv2d):
    87. nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    88. if m.bias is not None:
    89. nn.init.constant_(m.bias, 0)
    90. elif isinstance(m, nn.Linear):
    91. nn.init.normal_(m.weight, 0, 0.01)
    92. nn.init.constant_(m.bias, 0)

    2.加载数据集

    这里使用花朵数据集,数据集制造和数据集使用的脚本的参考:Pytorch之AlexNet花朵分类_风间琉璃•的博客-CSDN博客

     加载数据集和测试集,并进行相应的预处理操作。

    1. data_transform = {
    2. "train": transforms.Compose([transforms.RandomResizedCrop(224),
    3. transforms.RandomHorizontalFlip(),
    4. transforms.ToTensor(),
    5. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    6. "val": transforms.Compose([transforms.Resize((224, 224)),
    7. transforms.ToTensor(),
    8. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
    9. # 数据集根目录
    10. data_root = os.path.abspath(os.getcwd())
    11. print(os.getcwd())
    12. # 图片目录
    13. image_path = os.path.join(data_root, "data_set", "flower_data")
    14. print(image_path)
    15. assert os.path.exists(image_path), "{} path does not exit.".format(image_path)
    16. # 准备数据集
    17. train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
    18. transform=data_transform["train"])
    19. train_num = len(train_dataset)
    20. validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
    21. transform=data_transform["val"])
    22. val_num = len(validate_dataset)
    23. # 定义一个包含花卉类别到索引的字典:雏菊,蒲公英,玫瑰,向日葵,郁金香
    24. # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    25. # 获取包含训练数据集类别名称到索引的字典,这通常用于数据加载器或数据集对象中。
    26. flower_list = train_dataset.class_to_idx
    27. # 创建一个反向字典,将索引映射回类别名称
    28. cla_dict = dict((val, key) for key, val in flower_list.items())
    29. # 将字典转换为格式化的JSON字符串,每行缩进4个空格
    30. json_str = json.dumps(cla_dict, indent=4)
    31. # 打开名为 'class_indices.json' 的JSON文件,并将JSON字符串写入其中
    32. with open('class_indices.json', 'w') as json_file:
    33. json_file.write(json_str)
    34. batch_size = 32
    35. # min: CPU 核心数量、批次大小(如果大于1),以及一个最大值8
    36. nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
    37. print("using {} dataloader workers every process".format(nw))
    38. # 加载数据集
    39. train_loader = torch.utils.data.DataLoader(train_dataset,
    40. batch_size=batch_size, shuffle=True,
    41. num_workers=nw)
    42. validate_loader = torch.utils.data.DataLoader(validate_dataset,
    43. batch_size=4, shuffle=False,
    44. num_workers=nw)
    45. print("using {} images for training, {} images for validation.".format(train_num, val_num))

    3.训练和测试模型

    数据集预处理完成后,就可以进行网络模型的训练和验证。

    1. net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
    2. # 如果要使用官方的预训练权重,注意是将权重载入官方的模型,不是我们自己实现的模型
    3. # 官方的模型中使用了bn层以及改了一些参数,不能混用
    4. # import torchvision
    5. # net = torchvision.models.googlenet(num_classes=5)
    6. # model_dict = net.state_dict()
    7. # # 预训练权重下载地址: https://download.pytorch.org/models/googlenet-1378be20.pth
    8. # pretrain_model = torch.load("googlenet.pth")
    9. # del_list = ["aux1.fc2.weight", "aux1.fc2.bias",
    10. # "aux2.fc2.weight", "aux2.fc2.bias",
    11. # "fc.weight", "fc.bias"]
    12. # pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}
    13. # model_dict.update(pretrain_dict)
    14. # net.load_state_dict(model_dict)
    15. net.to(device)
    16. loss_function = nn.CrossEntropyLoss()
    17. optimizer = optim.Adam(net.parameters(), lr=0.0003)
    18. epochs = 120
    19. best_acc = 0.0
    20. save_path = './GoogLeNet.pth'
    21. train_steps = len(train_loader)
    22. for epoch in range(epochs):
    23. # 设置为训练模式
    24. net.train()
    25. running_loss = 0.0
    26. train_bar = tqdm(train_loader, file=sys.stdout)
    27. for step, data in enumerate(train_bar):
    28. images, labels = data
    29. optimizer.zero_grad()
    30. logits, aux_logits2, aux_logits1 = net(images.to(device))
    31. # 训练时,损失为3个输出损失的加权
    32. loss0 = loss_function(logits, labels.to(device))
    33. loss1 = loss_function(aux_logits1, labels.to(device))
    34. loss2 = loss_function(aux_logits2, labels.to(device))
    35. loss = loss0 + loss1 * 0.3 + loss2 * 0.3
    36. loss.backward()
    37. optimizer.step()
    38. running_loss += loss.item()
    39. train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
    40. epochs,
    41. loss)
    42. # 设置为测试模式
    43. net.eval()
    44. acc = 0.0
    45. with torch.no_grad():
    46. val_bar = tqdm(validate_loader, file=sys.stdout)
    47. for val_data in val_bar:
    48. val_images, val_labels = val_data
    49. # 测试层仅有最后输出层
    50. outputs = net(val_images.to(device))
    51. predict_y = torch.max(outputs, dim=1)[1]
    52. acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
    53. val_accurate = acc / val_num
    54. print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
    55. (epoch + 1, running_loss / train_steps, val_accurate))
    56. if val_accurate > best_acc:
    57. best_acc = val_accurate
    58. torch.save(net.state_dict(), save_path)
    59. print('Finished Training')

    训练120epoch的准确率能到达80%左右。

    三、实现图像分类

    利用上述训练好的网络模型进行测试,验证是否能完成分类任务。

    报错:注意这里加载模型的时候只需要加载主干网络的权重文件,不需要辅助分类器的相关文件。

    加载模型文件如下:

    1. # 加载模型文件
    2. weights_path = "./GoogLeNet.pth"
    3. assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    4. # strict=False 表示在加载权重时允许不匹配的键,如果预训练权重文件中的一些权重参数与当前模型不完全匹配,也不会引发错误
    5. # missing_keys包含了在权重文件中存在但模型中不存在的键
    6. # unexpected_key包含了在模型中存在但权重文件中不存在的键
    7. missing_keys, unexpected_keys = model.load_state_dict(torch.load(weights_path, map_location=device), strict=False)
    8. # model.load_state_dict(torch.load(weights_path))

    RuntimeError: Error(s) in loading state_dict for GoogLeNet:
        Unexpected key(s) in state_dict: "aux1.conv.conv.weight", "aux1.conv.conv.bias", "aux1.fc1.weight", "aux1.fc1.bias", "aux1.fc2.weight", "aux1.fc2.bias", "aux2.conv.conv.weight", "aux2.conv.conv.bias", "aux2.fc1.weight", "aux2.fc1.bias", "aux2.fc2.weight", "aux2.fc2.bias". 

    1. import os
    2. import json
    3. import torch
    4. from PIL import Image, ImageDraw
    5. from torchvision import transforms
    6. from model import GoogLeNet
    7. def main():
    8. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    9. data_transform = transforms.Compose([
    10. transforms.Resize((224, 224)),
    11. transforms.ToTensor(),
    12. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    13. ])
    14. # 加载图片
    15. img_path = 'daisy.jpg'
    16. assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    17. image = Image.open(img_path)
    18. # img.show()
    19. image.show()
    20. # [N, C, H, W]
    21. img = data_transform(image)
    22. # 扩展维度
    23. img = torch.unsqueeze(img, dim=0)
    24. # 获取标签
    25. json_path = 'class_indices.json'
    26. assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    27. with open(json_path, 'r') as f:
    28. # 使用json.load()函数加载JSON文件的内容并将其存储在一个Python字典中
    29. class_indict = json.load(f)
    30. # 加载网络
    31. model = GoogLeNet(num_classes=5, aux_logits=False).to(device)
    32. # 加载模型文件
    33. weights_path = "./GoogLeNet.pth"
    34. assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    35. # strict=False 表示在加载权重时允许不匹配的键,如果预训练权重文件中的一些权重参数与当前模型不完全匹配,也不会引发错误
    36. # missing_keys包含了在权重文件中存在但模型中不存在的键
    37. # unexpected_key包含了在模型中存在但权重文件中不存在的键
    38. missing_keys, unexpected_keys = model.load_state_dict(torch.load(weights_path, map_location=device), strict=False)
    39. # model.load_state_dict(torch.load(weights_path))
    40. model.eval()
    41. with torch.no_grad():
    42. # 对输入图像进行预测
    43. output = torch.squeeze(model(img.to(device))).cpu()
    44. # 对模型的输出进行 softmax 操作,将输出转换为类别概率
    45. predict = torch.softmax(output, dim=0)
    46. # 得到高概率的类别的索引
    47. predict_cla = torch.argmax(predict).numpy()
    48. res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy())
    49. draw = ImageDraw.Draw(image)
    50. # 文本的左上角位置
    51. position = (10, 10)
    52. # fill 指定文本颜色
    53. draw.text(position, res, fill='red')
    54. image.show()
    55. for i in range(len(predict)):
    56. print("class: {:10} prob: {:.3}".format(class_indict[str(i)], predict[i].numpy()))
    57. if __name__ == '__main__':
    58. main()

    运行结果:

     

    结束语

    感谢阅读吾之文章,今已至此次旅程之终站 🛬。

    吾望斯文献能供尔以宝贵之信息与知识也 🎉。

    学习者之途,若藏于天际之星辰🍥,吾等皆当努力熠熠生辉,持续前行。

    然而,如若斯文献有益于尔,何不以三连为礼?点赞、留言、收藏 - 此等皆以证尔对作者之支持与鼓励也 💞。

  • 相关阅读:
    odoo 视图部分详解(四)
    图神经网络的基本知识
    隆云通吸顶多参数传感器
    Qt的WebEngineView加载网页时出现Error: WebGL is not supported
    Postgresql更改字段默认值、设置字段默认值、删除字段默认值
    git reset 和 git revert的使用
    ISO27001认证办理流程及2022年补贴政策汇总
    二十八、CANdelaStudio实践-10服务(SessionControl)
    前端构建工具总结
    【SQL】索引的创建与设计原则
  • 原文地址:https://blog.csdn.net/qq_53144843/article/details/133279746