• 【python】pytorch包(第四章)手写数字图像识别


    问题描述:

    给定手写字体的图片,人工智能自动判断这是数字几

    数据来源:

    MNIST数据集

    代码实战:

    Part 1. 准备数据集

    该模块内容完成的功能:

    1. 下载MNIST数据集;
    2. 转换数据格式,使适用于pytorch;
    3. 数据分批;
    4. 将上述功能 API化
    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    from torchvision.transforms import Compose,ToTensor,Normalize
    def get_data(Batch_Size,train=True):
        #train = True则是训练集,否则是测试集
        #以每批Batch_size大小进行数据分批
        transform_fn = Compose([
            ToTensor(), #转张量
            Normalize(mean=(0.1307,),std=(0.3081))#正则化
            #mean和std的形状和通道数相同
        ])
        dataset = MNIST( #数据集类别
            root=r'E:\MNIST数字识别\training', #将数据集存储在路径/files/内
            train =True, #True表示获取的是训练s集,否则获取的是训练集
            download=False,#如果没有下载过数据集,则需要标True下载
            transform= transform_fn #图片处理函数
        )
        data_loader = DataLoader( #分批次的数据集类别
            dataset, #将dataset分批
            batch_size = Batch_Size, #每批包含Batch_size个数据
            shuffle=True #随机打乱
        )
        return data_loader
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23

    Part 2. 构建模型

    1. 构建模型逻辑

    神经网络结构:
    四层神经网络,输入层->全连接层1->全连接层2->输出层

    参数设定:

    激活函数: relu()

    损失函数: 交叉熵损失函数

    数据形状:

    原始数据:[batch_size,1,28,28] # 原始数据形状

    输入层:[batch_size,1x28x28] # 摊开

    第一层输出:[batch_size,28] # 参数28可以自行修改

    第二层输出:[batch_size,10] # 十个数字,十个类别

    优化器: Adam()

    import torch
    from torch import nn
    import torch.nn.functional as F
    class Number_Identify(nn.Module):
    	def __init__(self):
    		super(Number_Identify,self).__init__() #继承父类init的参数
    		self.fc1 = nn.Linear(1*28*28,28,bias=True) 
    		#第一层神经网络,输入维度为1*28*28,输出维度为28
    		self.fc2 = nn.Linear(28,10,bias=True)
    		#第二层神经网络,输入维度为28,输出维度为10
    	def forward(self,x): #模型输入x,输出out
    		x = x.view(-1,1*28*28) 
    		#view()相当于reshape(),参数为-1表示 根据情况自适应调整
    		x = self.fc1(x) #经过第一层神经网络计算
    		x = F.relu(x) #经过激活函数
    		out = self.fc2(x) #经过第二层神经网络计算
    		return F.log_softmax(out,dim=-1)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    2. 模型实例化

    from torch import optim
    model = Number_Identify()#调用模型基底、
    #criterion = nn.CrossEntropyLoss() #损失函数
    optimizer = optim.Adam(model.parameters(),lr=1e-3) #优化器
    
    • 1
    • 2
    • 3
    • 4

    Part 3. 训练模型

    该模块内容完成的功能:

    1. 从MNIST数据集导入训练数据集
    2. 实现训练逻辑
    3. 将上述功能 API化

    构建模型训练的逻辑过程

    #训练函数
    def train(epoch):
    	model.train(mode=True) #当前模型设定为训练模式
    	data_train = get_data(2,train=True) 
    	#获取训练数据,数据按每两个一组分批
    	for idx,(data,target) in enumerate(data_train):
    		optimizer.zero_grad() #清零梯度
    		out = model(data) #向前计算:预测当前数据的结果
    		loss = F.nll_loss(out,target) #计算带权损失
    		loss.backward() #反向传播
    		optimizer.step() #参数更新
    		#训练进度展示:
    		if idx%10000 == 0:
    			print('\t batches[%d/%d],loss:%.6f' % (
    				idx,len(data_train),loss.data
    			)) 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    模型训练

    #训练
    training_times = int(input("输入训练次数:"))
    for epoch in range(training_times): #多次训练
        print("Train_epoch[%d/%d]:" % (epoch+1,training_times))
        train(epoch)
        print("=========Train_epoch[%d/%d] finished======" % (epoch+1,training_times))
    print("======训练完成======")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    Part 4. 保存模型

    torch.save(
        model.state_dict(),
        r'E:\AI_Model_save\Number_Identify\model_net.pt'
    )#保存模型
    torch.save(
        optimizer.state_dict(),
        r'E:\AI_Model_save\Number_Identify\model_optimiter.pt'
    )#保存优化器
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    Part 5. 模型使用

    1. 加载模型

    import os
    import torch
    if os.path.exists(r'E:\AI_Model_save\Number_Identify'):       
        model.load_state_dict(torch.load(
            r'E:\AI_Model_save\Number_Identify\model_net.pt'
        )) #加载模型
        optimizer.load_state_dict(torch.load(
            r'E:\AI_Model_save\Number_Identify\model_optimiter.pt'
        ))#加载优化器
        print("======成功调用=======")
    else: 
        print("路径错误")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    2. 模型评估

    编写test函数,实现模型评估的API
    实现功能:批量预测+计算正确率
    API部分

    import numpy as np
    import torch
    def test(model,data_test):
        loss_list = []
        accuracy_list = []
        for idx,(Input,target) in enumerate(data_test):
            with torch.no_grad():#预测状态,不改变梯度参数
                output = model(Input) #批量预测
                cur_loss = F.nll_loss(output,target) #计算损失
                loss_list.append(cur_loss)
                pred = output.max(dim=-1)[-1] #批量预测结果
                accuracy = pred.eq(target).float().mean() #计算准确率
                accuracy_list.append(accuracy)
                #这里计算的是每个batch的正确率与损失
        return np.mean(accuracy_list),np.mean(loss_list)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    导入测试数据 并测试

    data_test = get_data(100,train=False) 
    #读取测试数据集,每100个为一组进行测试
    accuracy,loss = test(model,data_test) #测试数据
    print("准确率:%.2f" % (accuracy*100),"%")
    print("Loss:",loss)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    3. 单图预测【待填】

    After all:

    用pytorch做了这个实战,个人感受是:
    pytorch的数据预处理更加傻瓜式,但伴随的参数也更多,需要熟悉的API也更多,学起来更麻烦一些,成本更高,但使用起来的便利性更好;
    相比之下,keras的数据预处理需要我们自己完成,学起来很简单,但用起来很麻烦(每次都要自己手写数据的预处理)

  • 相关阅读:
    电脑小白快来!这有电脑常见故障解决方法
    【Redis】Zset 有序集合命令
    Java-集合类
    OpenCV实现图像傅里叶变换
    design compiler之设计环境
    基于长短期记忆神经网络的锂电池寿命预测
    HDU 1009 FatMouse‘ Trade (贪心算法)
    分布式架构 --- 分布式锁
    向毕业妥协系列之机器学习笔记:构建ML系统(三)
    conda虚拟环境安装pytorch(gpu版本)纪实
  • 原文地址:https://blog.csdn.net/l961983207/article/details/130895512