• 十八、完整神经网络模型验证步骤


    网络训练好了,需要提供输入进行验证网络模型训练的效果

    一、加载测试数据

    创建python测试文件,beyond_test.py
    保存在dataset文件夹下a文件夹里的1.jpg小狗图片
    在这里插入图片描述

    二、读取测试图片,重新设置模型所规定的大小(32,32),并转为tensor类型数据

    import torchvision
    from PIL import Image
    from torch import nn
    from torchvision import transforms
    
    • 1
    • 2
    • 3
    • 4
    img_path = "./dataset/a/1.jpg"#当前路径下的dataset文件夹下的a文件夹下的1.jpg文件
    img = Image.open(img_path)
    print(img)#
    img = img.convert('RGB')#png为四通道,jpg为三通道,这里只需要保存RGB通道,可以适应png和jpg图片
    
    
    #①剪切尺寸
    trans_resize = transforms.Resize((32,32))
    #②转为tensor类型
    trans_tensor = transforms.ToTensor()
    
    transform = torchvision.transforms.Compose([trans_resize,trans_tensor])
    #Compose参数都是transform对象,且第一个输出必须满足第二个输入
    #trans_resize为Resize对象,最后输出为PIL类型
    #trans_tensor为ToTensor对象,输入为PIL,输出为tensor
    
    img = transform(img)
    img = torch.reshape(img,(1,3,32,32))
    print(img.shape)#torch.Size([1, 3, 32, 32])
    #送入神经网络需要的格式为(batch_size,channel,H,W)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    三、加载模型

    十七、完整神经网络模型训练步骤博文所训练的模型存放在该路径目录下,这里只是训练3次而已,仅作为学习。
    在这里插入图片描述

    class Beyond(nn.Module):
        def __init__(self):
            super(Beyond,self).__init__()
            self.model = nn.Sequential(
                nn.Conv2d(in_channels=3,out_channels=32,kernel_size=(5,5),stride=1,padding=2),
                nn.MaxPool2d(2),
                nn.Conv2d(in_channels=32,out_channels=32,kernel_size=(5,5),stride=1,padding=2),
                nn.MaxPool2d(2),
                nn.Conv2d(in_channels=32,out_channels=64,kernel_size=(5,5),stride=1,padding=2),
                nn.MaxPool2d(2),
                nn.Flatten(),
                nn.Linear(in_features=1024,out_features=64),
                nn.Linear(in_features=64,out_features=10)
            )
    
        def forward(self,input):
            x = self.model(input)
            return x
    
    beyond = torch.load("./beyond/beyond_3.pth",map_location='cpu')
    # map_location='cpu'表示,不管模型是GPU还是CPU进行训练的,此台电脑使用CPU进行验证
    
    # 若模型是在GPU下训练,验证的时候电脑没有GPU,需要指明map_location参数
    # 若不设置map_location,电脑会根据模型来进行选择映射方式,例如模型beyond_3.pth在GPU下训练的,则验证的时候系统会自动调用GPU进行验证
    print(beyond)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25

    四、模型转换为测试类型

    eval()官网API
    在这里插入图片描述

    beyond.eval()
    
    • 1

    五、将测试图片送入模型

    beyond.eval()
    with torch.no_grad():
        output = beyond(img)
    print(output)
    #tensor([[ 0.7499,  0.5700, -1.3467,  0.4218,  0.0798, -0.1516, -1.3209,  0.1138,  1.2504, -0.6495]])
    
    print(output.argmax(1))#tensor([8])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    六、查看验证结果

    方法一

    CIFAR10官网给的数据可以看出
    在这里插入图片描述
    8对应的是ship
    dog对应的是5

    主要还是模型训练的次数太少了,主要的目的还是学习模型套路。

    方法二

    beyond_train.py文件下,(该文件来自博文十七、完整神经网络模型训练步骤)
    test_data = torchvision.datasets.CIFAR10("CIFAR_10",train=False,transform=torchvision.transforms.ToTensor(),download=True)
    打个断点,debug运行下
    train_data —> class_to_idx下就存放这类别编号索引
    在这里插入图片描述

    七、完整代码

    import torch
    import torchvision
    from PIL import Image
    from torch import nn
    from torchvision import transforms
    
    
    class Beyond(nn.Module):
        def __init__(self):
            super(Beyond,self).__init__()
            self.model = nn.Sequential(
                nn.Conv2d(in_channels=3,out_channels=32,kernel_size=(5,5),stride=1,padding=2),
                nn.MaxPool2d(2),
                nn.Conv2d(in_channels=32,out_channels=32,kernel_size=(5,5),stride=1,padding=2),
                nn.MaxPool2d(2),
                nn.Conv2d(in_channels=32,out_channels=64,kernel_size=(5,5),stride=1,padding=2),
                nn.MaxPool2d(2),
                nn.Flatten(),
                nn.Linear(in_features=1024,out_features=64),
                nn.Linear(in_features=64,out_features=10)
            )
    
        def forward(self,input):
            x = self.model(input)
            return x
    
    beyond = torch.load("./beyond/beyond_3.pth")
    print(beyond)
    
    
    img_path = "./dataset/a/1.jpg"#当前路径下的dataset文件夹下的a文件夹下的1.jpg文件
    img = Image.open(img_path)
    print(img)#
    img = img.convert('RGB')#png为四通道,jpg为三通道,这里只需要保存RGB通道,可以适应png和jpg图片
    
    
    #①剪切尺寸
    trans_resize = transforms.Resize((32,32))
    #②转为tensor类型
    trans_tensor = transforms.ToTensor()
    
    transform = torchvision.transforms.Compose([trans_resize,trans_tensor])
    #Compose参数都是transform对象,且第一个输出必须满足第二个输入
    #trans_resize为Resize对象,最后输出为PIL类型
    #trans_tensor为ToTensor对象,输入为PIL,输出为tensor
    
    img = transform(img)
    img = torch.reshape(img,(1,3,32,32))
    print(img.shape)#torch.Size([1, 3, 32, 32])
    #送入神经网络需要的格式为(batch_size,channel,H,W)
    
    beyond.eval()
    with torch.no_grad():
        output = beyond(img)
    print(output)
    #tensor([[ 0.7499,  0.5700, -1.3467,  0.4218,  0.0798, -0.1516, -1.3209,  0.1138,  1.2504, -0.6495]])
    
    print(output.argmax(1))#tensor([8])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
  • 相关阅读:
    单臂路由实现VLAN间路由
    计数的窗口函数应用(2)
    Carla学习笔记(二)服务器跑carla,本地运行carla-ros-bridge并用rviz显示
    Spring Data Envers:使用实体修订进行审计
    题目:2715.执行可取消的延迟函数
    【好文推荐】openGauss 5.0.0 数据库安全——全密态探究
    Nodejs搭建本地http服务器,通过【内网穿透】实现远程访问
    创新案例|实现YouTube超速增长的3大敏捷组织运营机制(上)
    浅谈新生代为什么要分三块区域并且比例为什么是8:1:1
    剑指Offer || :栈与队列(简单)
  • 原文地址:https://blog.csdn.net/qq_41264055/article/details/126499326