• 深度学习实践3:多层感知机


     算法流程

    1. 首先导入了需要使用的库和模块:

      1. import torch
      2. from torch import nn
      3. from main import load_data_fashion_mnist, train_ch3

      这些库和模块包含了构建和训练模型所需的功能。load_data_fashion_mnist, train_ch3两个函数具体可看    深度学习实践2

    2. 定义了一个包含两个全连接层的神经网络模型net

      1. net = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(),
      2. nn.Linear(256, 10))

      这个模型包含一个将输入展平的Flatten层,一个输入维度为784、输出维度为256的全连接层,一个ReLU激活函数,以及一个输入维度为256、输出维度为10的全连接层。

    3. 定义了一个函数init_weights用于初始化模型的权重:

      1. def init_weights(m):
      2. if type(m) == nn.Linear:
      3. nn.init.normal_(m.weight, std=0.01)

      这个函数接收一个模块m,如果模块是nn.Linear类型的,则对其权重进行正态分布初始化。

    4. 使用apply方法将初始化权重的操作应用到模型net的所有模块上:

      net.apply(init_weights)
      

      这样可以确保模型的权重被正确初始化。

    5. 设置了一些训练的超参数:

      batch_size, lr, num_epochs = 256, 0.1, 10
      

      这里设置了批次大小为256,学习率为0.1,迭代周期数为10。

    6. 定义了损失函数loss为交叉熵损失:

      loss = nn.CrossEntropyLoss()
      

      这个损失函数用于计算模型预测结果与真实标签之间的交叉熵损失。

    7. 定义了优化器trainer为随机梯度下降(SGD)优化器:

      trainer = torch.optim.SGD(net.parameters(), lr=lr)
      

      这个优化器用于更新模型的参数,其中net.parameters()返回模型中所有需要学习的参数。

    8. 使用load_data_fashion_mnist函数加载Fashion-MNIST数据集:

      train_iter, test_iter = load_data_fashion_mnist(batch_size)
      

      这里将训练集和测试集的数据加载器分别赋值给train_itertest_iter

    9. 调用train_ch3函数进行模型训练:

      train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
      

      这里传入模型net、训练数据加载器train_iter、测试数据加载器test_iter、损失函数loss、迭代周期数num_epochs和优化器trainer进行训练。

            结果

            

    1. 1 (0.7864819016138712, 0.7459166666666667, 0.7754)
    2. 2 (0.5714084996541341, 0.8120833333333334, 0.7976)
    3. 3 (0.5254668966929118, 0.82495, 0.8052)
    4. 4 (0.501056636873881, 0.8320666666666666, 0.8246)
    5. 5 (0.4861722059249878, 0.8368333333333333, 0.8247)
    6. 6 (0.4742358523050944, 0.8391666666666666, 0.8264)
    7. 7 (0.46462928047180174, 0.84315, 0.8117)
    8. 8 (0.4579755872090658, 0.8445166666666667, 0.8314)
    9. 9 (0.45267214221954344, 0.8464166666666667, 0.8326)
    10. 10 (0.44778603076934814, 0.8480833333333333, 0.8019)

  • 相关阅读:
    C++编程语言的深度解析: 从零开始的学习路线
    idea 中Maven项目转Gradle项目
    Unity程序在VR一体机(Android)上卡死(闪退)后怎么办?——用adb查看android上某Unity app的debug信息
    Django(18):中间件原理和使用
    计算机毕业设计【HTML+CSS+JavaScript服装购物商城】毕业论文源码
    vue小案列(hello world)
    Spring Boot构建框架中尝试连接到表格时出错
    JavaWeb开发——文件上传
    c++ 常见类内的关键字
    LeetCode 热题 HOT 100:回溯专题
  • 原文地址:https://blog.csdn.net/white_0629/article/details/132618896