码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • pytorch基本操作:使用神经网络进行分类任务


    1.读取Mnist数据

            首先,读取Mnist数据,在深度学习框架中,数据的基本结构是tensor,据需转换成tensor才能参与后续建模训练,可用map函数将数据转换为tensor格式

    1. import torch
    2. x_train, y_train, x_valid, y_valid = map(
    3. torch.tensor, (x_train, y_train, x_valid, y_valid)
    4. )
    5. n, c = x_train.shape
    6. x_train, x_train.shape, y_train.min(), y_train.max()
    7. print(x_train, y_train)
    8. print(x_train.shape)
    9. print(y_train.min(), y_train.max())

     

    2.torch.nn.functional 

            torch.nn.functional中有很多功能, 比如,常见的激活函数、损失函数,一般情况下,如果模型有可学习的参数,最好用nn.Module,其他情况nn.functional相对更简单一些

    3.创建一个model

    • 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数
    • 无需写反向传播函数,nn.Module能够利用autograd自动实现反向传播
    • Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器
    1. from torch import nn
    2. class Mnist_NN(nn.Module):
    3. def __init__(self):
    4. super().__init__()
    5. self.hidden1 = nn.Linear(784, 128)
    6. self.hidden2 = nn.Linear(128, 256)
    7. self.out = nn.Linear(256, 10)
    8. def forward(self, x):
    9. x = F.relu(self.hidden1(x))
    10. x = F.relu(self.hidden2(x))
    11. x = self.out(x)
    12. return x

    打印出来:

     

     通过named_parameters()或者parameters()返回迭代器

    4.使用TensorDataset和DataLoader加载数据 

            TensorDataset:将训练数据的特征和标签组合

            DataLoader:随机读取小批量

     

     5.训练模块

    梯度下降方法和损失函数 

     

    torch默认会叠加梯度,所以结束后需要将梯度置零

      

    • 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
    • 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
    1. import numpy as np
    2. def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    3. for step in range(steps):
    4. model.train()
    5. for xb, yb in train_dl:
    6. loss_batch(model, loss_func, xb, yb, opt)
    7. model.eval()
    8. with torch.no_grad(): # 验证时不进行梯度下降
    9. losses, nums = zip(
    10. *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
    11. )
    12. val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums) # 平均损失
    13. print('当前step:'+str(step), '验证集损失:'+str(val_loss))

     

     

     

     

     

     

     

     

  • 相关阅读:
    线性表
    Centos 7上安装Kubernetes 1.24集群
    “火焰杯”软件测试高校就业选拔赛获奖名单揭晓,河南工业大学人工智能与大数据学院两名学子上榜,奖金2万元!
    angular、 react、vue框架对比
    Windows系统利用cpolar内网穿透搭建Zblog博客网站并实现公网访问内网!
    学习加密(三)spring boot 使用RSA非对称加密,前后端传递参数加解密
    《Java基础知识》Java 内省(Introspector)详解2
    阿里云新加坡主机服务器选择
    flutter开发实战-自定义长按TextField输入框剪切、复制、选择全部菜单AdaptiveTextSelectionToolba样式UI效果
    linux:将进程切换到后台且不退出
  • 原文地址:https://blog.csdn.net/qq_52053775/article/details/126102940
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号