• 【神经网络】【GoogleNet】


    1、引言

    卷积神经网络是当前最热门的技术,我想深入地学习这门技术,从他的发展历史开始,了解神经网络算法的兴衰起伏;同时了解他在发展过程中的**里程碑式算法**,能更好的把握神经网络发展的未来趋势,了解神经网络的特征。
    之前的LeNet为以后的神经网络模型打下了一个基础的框架,真正让神经网络模型在外界广泛引起关注的还是AlexNet,在AlexNet之后也出现了相应对他的改进,或多或少会有一些效果。但是ZFNet是在AlexNet上的改进,他的论文对神经网络的各个层级的作用,做了十分详细的阐述,为如何优化模型,怎样“有理有据”地调节参数指出了一个方向,这是他最大的一个贡献。VggNet也是在2014年的ImageNet的定位赛和分类赛上获得了第一名和第二名,他在原来卷积神经网络结构的基础上,大大增加了网络的深度,最后取得了不错的成绩。GoogleNet同样是在2014年诞生的,他在ImageNet大规模视觉识别挑战赛(ILSVRC14)上提出了一种代号为Inception的深度卷积神经网络结构,并在分类和检测上取得了新的最好结果。
    GoogleNet论文原文下载地址
    GoogleNet论文中文详解

    2、GoogleNet提出背景

    在LeNet之后,随着计算能力的提升,研究者不断改进模型的表达能力,最明显的是卷积层数的增加,每一层的内部通道数目也在增加。其中最具代表性的就是AlexNet模型。
    AlexNet由Hinton和他的学生Alex Krizhevsky设计,模型名字来源于论文第一作者的姓名Alex。该模型以很大的优势获得了2012年ISLVRC竞赛的冠军网络,分类准确率由传统的 70%+提升到 80%+,自那年之后,深度学习开始迅速发展。GoogLeNet是2014年Christian Szegedy提出的一种全新的深度学习结构,在这之前的AlexNet、VGG等结构都是通过增大网络的深度(层数)来获得更好的训练效果,但层数的增加会带来很多负作用,比如overfit、梯度消失、梯度爆炸等。inception的提出则从另一种角度来提升训练结果:能更高效的利用计算资源,在相同的计算量下能提取到更多的特征,从而提升训练结果。值得注意的是GoogLeNet的参数为500w个(5M),VGG16的参数是138M,在表现接近的情况下,GoogLeNet的参数量有明显的优势。
    注:

    ImageNet是一个在2009年创建的图像数据集,从2010年开始到2017年举办了七届的ImageNet 挑战赛——ImageNet
    Large Scale Visual Recognition ChallengeI (LSVRC),在这个挑战赛上诞生了AlexNet、ZFNet、OverFeat、VGG、Inception、ResNet、WideResNet、FractalNet、DenseNet、ResNeXt、DPN、SENet 等经典模型。

    一般来说,提升网络性能最直接的办法就是增加网络深度和宽度,深度指网络层次数量、宽度指神经元数量。但这种方式存在以下问题:

    • (1)参数太多,如果训练数据集有限,很容易产生过拟合;
    • (2)网络越大、参数越多,计算复杂度越大,难以应用;
    • (3)网络越深,容易出现梯度弥散问题(梯度越往后穿越容易消失),难以优化模型。

    所以,有人调侃“深度学习”其实是“深度调参”。
    解决这些问题的方法当然就是在增加网络深度和宽度的同时减少参数,为了减少参数,自然就想到将全连接变成稀疏连接。但是在实现上,全连接变成稀疏连接后实际计算量并不会有质的提升,因为大部分硬件是针对密集矩阵计算优化的,稀疏矩阵虽然数据量少,但是计算所消耗的时间却很难减少。
    那么,有没有一种方法既能保持网络结构的稀疏性,又能利用密集矩阵的高计算性能。大量的文献表明可以将稀疏矩阵聚类为较为密集的子矩阵来提高计算性能,就如人类的大脑是可以看做是神经元的重复堆积,因此,GoogLeNet团队提出了Inception网络结构,就是构造一种“基础神经元”结构,来搭建一个稀疏性、高计算性能的网络结构。

    3、GoogleNet的模型详解

    在这里插入图片描述
    GoogLeNet的网络结构图细节如下:
    在这里插入图片描述GoogLeNet网络结构明细表解析如下:

    • 0、输入

    原始输入图像为224x224x3,且都进行了零均值化的预处理操作(图像每个像素减去均值)。

    • 1、第一模块

    第一模块采用的是一个单纯的卷积层紧跟一个最大池化层。
    卷积层:卷积核大小77,步长为2,padding为3,输出通道数64,输出特征图尺寸为(224-7+32)/2+1=112.5(向下取整)=112,输出特征图维度为112x112x64,卷积后进行ReLU操作。
    池化层:窗口大小33,步长为2,输出特征图尺寸为((112 -3)/2)+1=55.5(向上取整)=56,输出特征图维度为56x56x64。

    • 2、第二模块

    第二模块采用2个卷积层,后面跟一个最大池化层。

    在这里插入图片描述卷积层:(1)先用64个1x1的卷积核(3x3卷积核之前的降维)将输入的特征图(56x56x64)变为56x56x64,然后进行ReLU操作。参数量是116464=4096。
    (2)再用卷积核大小33,步长为1,padding为1,输出通道数192,进行卷积运算,输出特征图尺寸为(56-3+12)/1+1=56,输出特征图维度为56x56x192,然后进行ReLU操作。参数量是3364192=110592。第二模块卷积运算总的参数量是110592+4096=114688,即114688/1024=112K。
    池化层: 窗口大小33,步长为2,输出通道数192,输出为((56 - 3)/2)+1=27.5(向上取整)=28,输出特征图维度为28x28x192。

    • 第三模块(Inception 3a层)

    Inception 3a层,分为四个分支,采用不同尺度,图示如下:
    在这里插入图片描述
    再看下表格结构,来分析和计算吧:
    在这里插入图片描述
    (1)使用64个1x1的卷积核,运算后特征图输出为28x28x64,然后RuLU操作。参数量1119264=12288。
    (2)96个1x1的卷积核(3x3卷积核之前的降维)运算后特征图输出为28x28x96,进行ReLU计算,再进行128个3x3的卷积,输出28x28x128。参数量1119296+3396128=129024。
    (3)16个1x1的卷积核(5x5卷积核之前的降维)将特征图变成28x28x16,进行ReLU计算,再进行32个5x5的卷积,输出28x28x32。参数量1119216+551632=15872。
    (4)pool层,使用3x3的核,输出28x28x192,然后进行32个1x1的卷积,输出28x28x32.。总参数量1119232=6144。
    将四个结果进行连接,对这四部分输出结果的第三维并联,即64+128+32+32=256,最终输出28x28x256。总的参数量是12288+129024+15872+6144=163328,即163328/1024=159.5K,约等于159K。

    • 第三模块(Inception 3b层)

    在这里插入图片描述

    Inception 3b层,分为四个分支,采用不同尺度。
    (1)128个1x1的卷积核,然后RuLU,输出28x28x128。
    (2)128个1x1的卷积核(3x3卷积核之前的降维)变成28x28x128,进行ReLU,再进行192个3x3的卷积,输出28x28x192。
    (3)32个1x1的卷积核(5x5卷积核之前的降维)变成28x28x32,进行ReLU,再进行96个5x5的卷积,输出28x28x96。
    (4)pool层,使用3x3的核,输出28x28x256,然后进行64个1x1的卷积,输出28x28x64。
    将四个结果进行连接,对这四部分输出结果的第三维并联,即128+192+96+64=480,最终输出输出为28x28x480。
    Inception 3b和Inception 4a之间有一个最大池化下采样层 窗口大小33,步长为2,输出特征图维度为14x14x480。
    在这里插入图片描述

    • 第四模块(Inception 4a、4b、4c、4e)

    与Inception3a,3b类似
    在这里插入图片描述
    Inception 4e和Inception 5a之间有一个最大池化下采样层 窗口大小33,步长为2,输出特征图维度为7x7x832。
    在这里插入图片描述

    • 第五模块(Inception 5a、5b)

    与Inception3a,3b类似
    在这里插入图片描述

    • 输出层

    在输出层GoogLeNet与AlexNet、VGG采用3个连续的全连接层不同,GoogLeNet采用的是全局平均池化层,得到的是高和宽均为1的卷积层,然后添加丢弃概率为40%的Dropout,输出层激活函数采用的是softmax。

    • 激活函数

    GoogLeNet每层使用的激活函数为ReLU激活函数。

    • 辅助分类器(分别来自于Inception 4a和Inception 4d的输出)

    根据实验数据,发现神经网络的中间层也具有很强的识别能力,为了利用中间层抽象的特征,在某些中间层中添加含有多层的分类器。如下图所示,红色边框内部代表添加的辅助分类器。GoogLeNet中共增加了两个辅助的softmax分支,作用有两点,一是为了避免梯度消失,用于向前传导梯度。反向传播时如果有一层求导为0,链式求导结果则为0。二是将中间某一层输出用作分类,起到模型融合作用。最后的loss=loss_2 + 0.3 * loss_1 + 0.3 * loss_0。实际测试时,这两个辅助softmax分支会被去掉。
    (1)辅助分类器的第一层是一个平均池化下采样层,池化核大小为5x5,stride=3。使得(4a)阶段的输出为4×4×512,(4d)的输出为4×4×528。
    (2)第二层是卷积层,卷积核大小为1x1,stride=1,卷积核个数是128。
    (3)第三层是全连接层,节点个数是1024。
    (4)丢弃70%输出的丢弃层。
    (5)第四层是全连接层,节点个数是1000(对应分类的类别个数)。
    在这里插入图片描述

    4、GoogleNet的创新之处

    4.1、提出Inception结构

    Inception的设计原则

    • 逐层构造网络:如果数据集的概率分布能够被一个神经网络所表达,那么构造这个网络的最佳方法是逐层构筑网络,即将上一层高度相关的节点连接在一起。几乎所有效果好的深度网络都具有这一点,不管AlexNet
      VGG堆叠多个卷积,googleNet堆叠多个inception模块,还是ResNet堆叠多个resblock。
    • 稀疏的结构:人脑的神经元连接就是稀疏的,因此大型神经网络的合理连接方式也应该是稀疏的。稀疏的结构对于大型神经网络至关重要,可以减轻计算量并减少过拟合。
      卷积操作(局部连接,权值共享)本身就是一种稀疏的结构,相比于全连接网络结构是很稀疏的。
    • 符合Hebbian原理: Cells that fire together, wire together. 一起发射的神经元会连在一起。
      相关性高的节点应该被连接而在一起。
      inception中 1×1的卷积恰好可以融合三者。我们一层可能会有多个卷积核,在同一个位置但在不同通道的卷积核输出结果相关性极高。一个1×1的卷积核可以很自然的把这些相关性很高,在同一个空间位置,但不同通道的特征结合起来。而其它尺寸的卷积核(3×3,5×5)可以保证特征的多样性,因此也可以适量使用。于是,这就完成了inception module下图的设计初衷:4个分支:
      在这里插入图片描述

    4.2、1*1卷积核降维

    左上图是GoogleNet作者设计的初始inception结构(native inception),其想法是用多个不同类型的卷积核(1 × 1 1\times11×1,3 × 3 3\times33×3,5 × 5 5\times55×5,3 × 3 P o o l 3\times3Pool3×3Pool)堆叠在一起(卷积、池化后的尺寸相同,将通道相加)代替一个3x3的小卷积核,好处是可以使提取出来的特征具有多样化,并且特征之间的co-relationship不会很大,最后用把feature map都concatenate起来使网络做得很宽,然后堆叠Inception Module将网络变深。但仅仅简单这么做会使一层的计算量爆炸式增长
    native inception中所有的卷积核都在上一层的所有输出上来做,而那个5x5的卷积核所需的计算量就太大了,造成了特征图的厚度很大,为了避免这种情况,在3x3前、5x5前、max pooling后分别加上了1x1的卷积核,以起到了降低特征图厚度的作用,这也就形成了Inception v1的网络结构(右上图)。
    假设input feature map的size为28 × 28 × 256 28\times28\times25628×28×256,output feature map的size为28 × 28 × 480 28\times28\times48028×28×480,则native Inception Module的计算量有854M。计算过程如下
    在这里插入图片描述从上图可以看出,计算量主要来自高维卷积核的卷积操作,因而在每一个卷积前先使用1 × 1 1\times11×1卷积核将输入图片的feature map维度先降低,进行信息压缩,在使用3x3卷积核进行特征提取运算,相同情况下,Inception v1的计算量仅为358M。
    在这里插入图片描述
    Inception结构总共有4个分支,输入的feature map并行的通过这四个分支得到四个输出,然后在在将这四个输出在深度维度(channel维度)进行拼接(concate)得到我们的最终输出(注意,为了让四个分支的输出能够在深度方向进行拼接,必须保证四个分支输出的特征矩阵高度和宽度都相同),因此inception结构的参数为:

    • branch1: C o n v 1 × 1 Conv 1\times1Conv1×1, stride=1
    • branch2: C o n v 3 × 3 Conv 3\times3Conv3×3, stride=1, padding=1
    • branch3: C o n v 5 × 5 Conv 5\times5Conv5×5, stride=1, padding=2
    • branch4: M a x P o o l 3 × 3 MaxPool 3\times3MaxPool3×3, stride=1,padding=1
      GoogLeNet中使用了9个Inception v1 module,分别被命名为inception(3a)、inception(3b)、inception(4a)、inception(4b)、inception(4c)、inception(4d)、inception(4e)、inception(5a)、inception(5b)。

    4.3、两个辅助分类器帮助训练

    作用有两点,一是为了避免梯度消失,用于向前传导梯度。反向传播时如果有一层求导为0,链式求导结果则为0。二是将中间某一层输出用作分类,起到模型融合作用。
      辅助函数Axuiliary Function:从信息流动的角度看梯度消失,因为是梯度信息在BP过程中能量衰减,无法到达浅层区域,因此在中间开个口子,加个辅助损失函数直接为浅层。
      GoogLeNet网络结构中有深层和浅层2个分类器,为了避免梯度消失,两个辅助分类器结构是一模一样的,其组成如下图所示,这两个辅助分类器的输入分别来自Inception(4a)和Inception(4d)。

    辅助分类器的第一层是一个平均池化下采样层,池化核大小为5x5,stride=3;第二层是卷积层,卷积核大小为1x1,stride=1,卷积核个数是128;第三层是全连接层,节点个数是1024;第四层是全连接层,节点个数是1000(对应分类的类别个数)。
      辅助分类器只是在训练时使用,在正常预测时会被去掉。辅助分类器促进了更稳定的学习和更好的收敛,往往在接近训练结束时,辅助分支网络开始超越没有任何分支的网络的准确性,达到了更高的水平。

    4.4、使用平均化吃层(减少模型参数)

    网络最后采用了average pooling(平均池化)来代替全连接层,该想法来自NIN(Network in Network),事实证明这样可以将准确率提高0.6%。

    5、GoogleNet的代码实现

    #class_indices.json
    {
        "0": "daisy",
        "1": "dandelion",
        "2": "roses",
        "3": "sunflowers",
        "4": "tulips"
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    #model.py
    import torch.nn as nn
    import torch
    import torch.nn.functional as F
    
    
    class GoogLeNet(nn.Module):
        def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
            super(GoogLeNet, self).__init__()
            self.aux_logits = aux_logits
    
            self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
            self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
    
            self.conv2 = BasicConv2d(64, 64, kernel_size=1)
            self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
            self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
    
            self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
            self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
            self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
    
            self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
            self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
            self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
            self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
            self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
            self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
    
            self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
            self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
    
            if self.aux_logits:
                self.aux1 = InceptionAux(512, num_classes)
                self.aux2 = InceptionAux(528, num_classes)
    
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.dropout = nn.Dropout(0.4)
            self.fc = nn.Linear(1024, num_classes)
            if init_weights:
                self._initialize_weights()
    
        def forward(self, x):
            # N x 3 x 224 x 224
            x = self.conv1(x)
            # N x 64 x 112 x 112
            x = self.maxpool1(x)
            # N x 64 x 56 x 56
            x = self.conv2(x)
            # N x 64 x 56 x 56
            x = self.conv3(x)
            # N x 192 x 56 x 56
            x = self.maxpool2(x)
    
            # N x 192 x 28 x 28
            x = self.inception3a(x)
            # N x 256 x 28 x 28
            x = self.inception3b(x)
            # N x 480 x 28 x 28
            x = self.maxpool3(x)
            # N x 480 x 14 x 14
            x = self.inception4a(x)
            # N x 512 x 14 x 14
            if self.training and self.aux_logits:    # eval model lose this layer
                aux1 = self.aux1(x)
    
            x = self.inception4b(x)
            # N x 512 x 14 x 14
            x = self.inception4c(x)
            # N x 512 x 14 x 14
            x = self.inception4d(x)
            # N x 528 x 14 x 14
            if self.training and self.aux_logits:    # eval model lose this layer
                aux2 = self.aux2(x)
    
            x = self.inception4e(x)
            # N x 832 x 14 x 14
            x = self.maxpool4(x)
            # N x 832 x 7 x 7
            x = self.inception5a(x)
            # N x 832 x 7 x 7
            x = self.inception5b(x)
            # N x 1024 x 7 x 7
    
            x = self.avgpool(x)
            # N x 1024 x 1 x 1
            x = torch.flatten(x, 1)
            # N x 1024
            x = self.dropout(x)
            x = self.fc(x)
            # N x 1000 (num_classes)
            if self.training and self.aux_logits:   # eval model lose this layer
                return x, aux2, aux1
            return x
    
        def _initialize_weights(self):
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.constant_(m.bias, 0)
    
    
    class Inception(nn.Module):
        def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
            super(Inception, self).__init__()
    
            self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
    
            self.branch2 = nn.Sequential(
                BasicConv2d(in_channels, ch3x3red, kernel_size=1),
                BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)   # 保证输出大小等于输入大小
            )
    
            self.branch3 = nn.Sequential(
                BasicConv2d(in_channels, ch5x5red, kernel_size=1),
                # 在官方的实现中,其实是3x3的kernel并不是5x5,这里我也懒得改了,具体可以参考下面的issue
                # Please see https://github.com/pytorch/vision/issues/906 for details.
                BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)   # 保证输出大小等于输入大小
            )
    
            self.branch4 = nn.Sequential(
                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
                BasicConv2d(in_channels, pool_proj, kernel_size=1)
            )
    
        def forward(self, x):
            branch1 = self.branch1(x)
            branch2 = self.branch2(x)
            branch3 = self.branch3(x)
            branch4 = self.branch4(x)
    
            outputs = [branch1, branch2, branch3, branch4]
            return torch.cat(outputs, 1)
    
    
    class InceptionAux(nn.Module):
        def __init__(self, in_channels, num_classes):
            super(InceptionAux, self).__init__()
            self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
            self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]
    
            self.fc1 = nn.Linear(2048, 1024)
            self.fc2 = nn.Linear(1024, num_classes)
    
        def forward(self, x):
            # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
            x = self.averagePool(x)
            # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
            x = self.conv(x)
            # N x 128 x 4 x 4
            x = torch.flatten(x, 1)
            x = F.dropout(x, 0.5, training=self.training)
            # N x 2048
            x = F.relu(self.fc1(x), inplace=True)
            x = F.dropout(x, 0.5, training=self.training)
            # N x 1024
            x = self.fc2(x)
            # N x num_classes
            return x
    
    
    class BasicConv2d(nn.Module):
        def __init__(self, in_channels, out_channels, **kwargs):
            super(BasicConv2d, self).__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
            self.relu = nn.ReLU(inplace=True)
    
        def forward(self, x):
            x = self.conv(x)
            x = self.relu(x)
            return x
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    #predict.py
    import os
    import json
    
    import torch
    from PIL import Image
    from torchvision import transforms
    import matplotlib.pyplot as plt
    
    from model import GoogLeNet
    
    
    def main():
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
        data_transform = transforms.Compose(
            [transforms.Resize((224, 224)),
             transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
        # load image
        img_path = "../tulip.jpg"
        assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
        img = Image.open(img_path)
        plt.imshow(img)
        # [N, C, H, W]
        img = data_transform(img)
        # expand batch dimension
        img = torch.unsqueeze(img, dim=0)
    
        # read class_indict
        json_path = './class_indices.json'
        assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
    
        with open(json_path, "r") as f:
            class_indict = json.load(f)
    
        # create model
        model = GoogLeNet(num_classes=5, aux_logits=False).to(device)
    
        # load model weights
        weights_path = "./googleNet.pth"
        assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
        missing_keys, unexpected_keys = model.load_state_dict(torch.load(weights_path, map_location=device),
                                                              strict=False)
    
        model.eval()
        with torch.no_grad():
            # predict class
            output = torch.squeeze(model(img.to(device))).cpu()
            predict = torch.softmax(output, dim=0)
            predict_cla = torch.argmax(predict).numpy()
    
        print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                     predict[predict_cla].numpy())
        plt.title(print_res)
        for i in range(len(predict)):
            print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                      predict[i].numpy()))
        plt.show()
    
    
    if __name__ == '__main__':
        main()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    #train.py
    import os
    import sys
    import json
    
    import torch
    import torch.nn as nn
    from torchvision import transforms, datasets
    import torch.optim as optim
    from tqdm import tqdm
    
    from model import GoogLeNet
    
    
    def main():
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print("using {} device.".format(device))
    
        data_transform = {
            "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
            "val": transforms.Compose([transforms.Resize((224, 224)),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
    
        data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
        image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
        assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
        train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                             transform=data_transform["train"])
        train_num = len(train_dataset)
    
        # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
        flower_list = train_dataset.class_to_idx
        cla_dict = dict((val, key) for key, val in flower_list.items())
        # write dict into json file
        json_str = json.dumps(cla_dict, indent=4)
        with open('class_indices.json', 'w') as json_file:
            json_file.write(json_str)
    
        batch_size = 32
        nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
        print('Using {} dataloader workers every process'.format(nw))
    
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size, shuffle=True,
                                                   num_workers=nw)
    
        validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                                transform=data_transform["val"])
        val_num = len(validate_dataset)
        validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                      batch_size=batch_size, shuffle=False,
                                                      num_workers=nw)
    
        print("using {} images for training, {} images for validation.".format(train_num,
                                                                               val_num))
    
        # test_data_iter = iter(validate_loader)
        # test_image, test_label = test_data_iter.next()
    
        net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
        # 如果要使用官方的预训练权重,注意是将权重载入官方的模型,不是我们自己实现的模型
        # 官方的模型中使用了bn层以及改了一些参数,不能混用
        # import torchvision
        # net = torchvision.models.googlenet(num_classes=5)
        # model_dict = net.state_dict()
        # # 预训练权重下载地址: https://download.pytorch.org/models/googlenet-1378be20.pth
        # pretrain_model = torch.load("googlenet.pth")
        # del_list = ["aux1.fc2.weight", "aux1.fc2.bias",
        #             "aux2.fc2.weight", "aux2.fc2.bias",
        #             "fc.weight", "fc.bias"]
        # pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}
        # model_dict.update(pretrain_dict)
        # net.load_state_dict(model_dict)
        net.to(device)
        loss_function = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=0.0003)
    
        epochs = 30
        best_acc = 0.0
        save_path = './googleNet.pth'
        train_steps = len(train_loader)
        for epoch in range(epochs):
            # train
            net.train()
            running_loss = 0.0
            train_bar = tqdm(train_loader, file=sys.stdout)
            for step, data in enumerate(train_bar):
                images, labels = data
                optimizer.zero_grad()
                logits, aux_logits2, aux_logits1 = net(images.to(device))
                loss0 = loss_function(logits, labels.to(device))
                loss1 = loss_function(aux_logits1, labels.to(device))
                loss2 = loss_function(aux_logits2, labels.to(device))
                loss = loss0 + loss1 * 0.3 + loss2 * 0.3
                loss.backward()
                optimizer.step()
    
                # print statistics
                running_loss += loss.item()
    
                train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                         epochs,
                                                                         loss)
    
            # validate
            net.eval()
            acc = 0.0  # accumulate accurate number / epoch
            with torch.no_grad():
                val_bar = tqdm(validate_loader, file=sys.stdout)
                for val_data in val_bar:
                    val_images, val_labels = val_data
                    outputs = net(val_images.to(device))  # eval model only have last output layer
                    predict_y = torch.max(outputs, dim=1)[1]
                    acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
    
            val_accurate = acc / val_num
            print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
                  (epoch + 1, running_loss / train_steps, val_accurate))
    
            if val_accurate > best_acc:
                best_acc = val_accurate
                torch.save(net.state_dict(), save_path)
    
        print('Finished Training')
    
    
    if __name__ == '__main__':
        main()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132

    6、总结

    书山有路勤为径,学海无涯苦作舟。

    7、参考文章

    7.1(四)卷积神经网络模型之——GoogLeNet
    7.2 GoogLeNet网络详解与模型搭建
    7.3 Google Inception Net论文细读
    7.4 深度学习入门笔记之GoogLeNet网络

  • 相关阅读:
    图论+博弈论上dp:CF536D
    react使用脚手架搭建
    江湖再见,机器视觉兄弟们,我已经提离职了,聪明的机器视觉工程师,离职不亏本!
    OpenSSL加解密算法使用方法
    【深度学习】常见开源框架介绍 || 主流深度学习框架 || Tensorflow || Pytorch
    Git基础(21):GitLab创建组、用户、项目
    PHP_EOL不起作用或者无效的原因
    STM32F4 外部中断的时钟SYSCFG
    常用设计模式
    error while loading shared libraries: libc.so.6 误删除libc.so.6急救办法,
  • 原文地址:https://blog.csdn.net/zwh1298454060/article/details/134087505