• pytorch基本操作:使用神经网络进行分类任务


    1.读取Mnist数据

            首先,读取Mnist数据,在深度学习框架中,数据的基本结构是tensor,据需转换成tensor才能参与后续建模训练,可用map函数将数据转换为tensor格式

    1. import torch
    2. x_train, y_train, x_valid, y_valid = map(
    3. torch.tensor, (x_train, y_train, x_valid, y_valid)
    4. )
    5. n, c = x_train.shape
    6. x_train, x_train.shape, y_train.min(), y_train.max()
    7. print(x_train, y_train)
    8. print(x_train.shape)
    9. print(y_train.min(), y_train.max())

     

    2.torch.nn.functional 

            torch.nn.functional中有很多功能, 比如,常见的激活函数、损失函数,一般情况下,如果模型有可学习的参数,最好用nn.Module,其他情况nn.functional相对更简单一些

    3.创建一个model

    • 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数
    • 无需写反向传播函数,nn.Module能够利用autograd自动实现反向传播
    • Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器
    1. from torch import nn
    2. class Mnist_NN(nn.Module):
    3. def __init__(self):
    4. super().__init__()
    5. self.hidden1 = nn.Linear(784, 128)
    6. self.hidden2 = nn.Linear(128, 256)
    7. self.out = nn.Linear(256, 10)
    8. def forward(self, x):
    9. x = F.relu(self.hidden1(x))
    10. x = F.relu(self.hidden2(x))
    11. x = self.out(x)
    12. return x

    打印出来:

     

     通过named_parameters()或者parameters()返回迭代器

    4.使用TensorDataset和DataLoader加载数据 

            TensorDataset:将训练数据的特征和标签组合

            DataLoader:随机读取小批量

     

     5.训练模块

    梯度下降方法和损失函数 

     

    torch默认会叠加梯度,所以结束后需要将梯度置零

      

    • 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
    • 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
    1. import numpy as np
    2. def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    3. for step in range(steps):
    4. model.train()
    5. for xb, yb in train_dl:
    6. loss_batch(model, loss_func, xb, yb, opt)
    7. model.eval()
    8. with torch.no_grad(): # 验证时不进行梯度下降
    9. losses, nums = zip(
    10. *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
    11. )
    12. val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums) # 平均损失
    13. print('当前step:'+str(step), '验证集损失:'+str(val_loss))

     

     

     

     

     

     

     

     

  • 相关阅读:
    【160】相交链表
    【java】【重构二】分模块开发版本锁定以及耦合(打包)实战
    redis教程
    温故而知新——vue常用语法(三)页面 loading&过滤器&列表过渡
    学会开会|成为有连接感组织的重要技能
    【IEEE】CoEx:通过引导成本体积激励的实时立体匹配模型
    OpenMP的调度-以泊松方程求解为例子
    【大数据采集技术与应用】【第一章】【大数据采集技术与应用概述】
    这份Java面试八股文堪称2022最强,让329人成功进入大厂
    kong安装与配置
  • 原文地址:https://blog.csdn.net/qq_52053775/article/details/126102940