• mnist手写数字识别,dnn实现代码解读


    代码及注释?

    # coding: utf-8
    
    '''
    通过dnn识别手写数字集
    '''
    import os
    import sys
    sys.path.append(os.path.abspath(
        os.path.dirname(os.path.abspath(__file__)) + os.path.sep + ".."))
    
    import torch
    import torchvision
    import torch.nn as nn
    import torch.nn.functional as F
    from config import MINIST_DATASET
    
    #加载数据集--训练集
    data_train = torchvision.datasets.MNIST(root=MINIST_DATASET,
                                transform=torchvision.transforms.ToTensor(),
                                train=True,
                                download=True)
    #批量化处理数据集,一个批次64个数据--训练集
    loader_train = torch.utils.data.DataLoader(dataset=data_train,
                                                    batch_size=64,
                                                    shuffle=True)
    #加载数据集--测试集
    data_test = torchvision.datasets.MNIST(root=MINIST_DATASET,
                                transform=torchvision.transforms.ToTensor(),
                                train=False,
                                download=True)
    #批量化处理数据集,一个批次64个数据--测试集
    loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                                    batch_size=64,
                                                    shuffle=True)
    
    #定义一个网络,继承自pytorch.nn.module
    class Net(nn.Module):
        #继承父亲初始化的方法,并定义两个全链接层,分别是输入784维、输出100维;输入100维输出10维
        def __init__(self):
            super(Net, self).__init__()
            self.line1 = nn.Linear(784, 100)
            self.line2 = nn.Linear(100, 10)
        #定义网络向前传播的结构
        def forward(self, x):
            #将图像的二维数据变换为一 维数据
            x = x.reshape(-1, 784)
            #将数据输入第一层神经网络,并将其输出通过激活函数激活
            x = F.relu(self.line1(x))
            #dropout层
            x = F.dropout(x, 0.2, training=self.training)
            #将数据输入第二层神经网络
            x = self.line2(x)
            #再通过激活函数
            x = F.softmax(x, dim=1)
            #模型输出一个10维的张量
            return x
    
    #模型训练
    def train_model(num, save_model=True):
        #网络实例化
        net = Net()
        #定义优化器,参数调整的方式
        optim = torch.optim.Adam(params=net.parameters())
        #定义损失函数
        loss_function = nn.CrossEntropyLoss()
    
        #开始训练,
        for epoch in range(num):
            #模型训练模式
            net.train()
            running_loss = 0
            #批次训练,每次只有64个图片参与计算
            for data, label in loader_train:
                #计算模型结果
                output = net(data)
                #计算损失值
                loss = loss_function(output, label)
                #梯度清空
                optim.zero_grad()
                #反向传播,计算每个参数的梯度
                loss.backward()
                #使用优化器调整参数
                optim.step()
                #计算一个epoch的损失值
                running_loss += loss.item()
            print("loss:{}".format(running_loss))
    
            #模型验证模式
            net.eval()
            test_correct = 0
            #批量化验证
            for data, label in loader_test:
                #使用模型计算结果
                output = net(data)
                #取出10维数据中最大值的索引,即位预测结果
                _, output = output.max(dim=1)
                #计算准确率
                test_correct += (label == output).sum().item()
            print("正确率:{}%".format(round(test_correct/100.0, 2)))
    
        if save_model:
            # 保存模型
            torch.save(net, os.path.join(base_dir, 'test.pth.tar'))
    
    
    def use_save_model():
        # 加载模型
        net = torch.load(os.path.join(base_dir, 'test.pth.tar'))
        test_correct = 0
        for data, label in loader_test:
            # 使用模型计算结果
            output = net(data)
            # 取出10维数据中最大值的索引,即位预测结果
            _, output = output.max(dim=1)
            #计算预测正确率
            test_correct += (label == output).sum().item()
        print("存储的模型分类正确率:{}%".format(round(test_correct / 100.0, 2)))
    
    
    if __name__ == '__main__':
        train_model(20, False)
    

    模型结构

    在这里插入图片描述

    相关问题

    net.train() 和 net.eval()的作用?

    参考:net.train() 和 net.eval()的作用

    为什么是output.max(1)

    参考:output.max(1)

    optim.zero_grad()、pred=model(input)、loss=criterion(pred,tgt)、loss.backward()、optim.step()的作用

    参考:optim.zero_grad()、pred=model(input)、loss=criterion(pred,tgt)、loss.backward()、optim.step()的作用

  • 相关阅读:
    vue3 如何国际化
    21、wpf之绑定使用小记
    提交本地项目到GitHub
    leetcode 阶乘尾数
    vue3后台管理系统之路由守卫
    HTTP协议
    剑指 Offer 40. 最小的k个数【查找排序】
    python_爬虫
    【算法专题--链表】反转链表II--高频面试题(图文详解,小白一看就会!!!)
    【Unity】思考方式与构造 | 碰撞器/刚体/预设/组件
  • 原文地址:https://blog.csdn.net/siper12138/article/details/126954032