码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • 【PyTorch】深度学习实践之反向传播 Back Propagation


    本文目录

    • 前馈计算
    • 反向传播过程
    • Tensor in PyTorch
    • 课堂练习:线性模型 Linear Model
      • 实现代码
      • 结果
    • 课后练习
    • 学习资料
    • 系列文章索引

    前馈计算

    权重维度增加,层数增加,模型变得复杂

    在这里插入图片描述

    但是化简后仍是线性,因此增加层数意义不大

    [图片]

    引入激活函数,从而增加非线性

    [图片]

    反向传播计算梯度,使用链式法则
    [图片]

    [图片]

    反向传播过程

    [图片]

    Tensor in PyTorch

    Tenso(张量):PyTorch中存储数据的基本元素。
    Tensor两个重要的成员,data和grad。(grad也是个张量)

    课堂练习:线性模型 Linear Model

    实现代码

    import torch
    
    # 已知数据:
    x_data = [1.0,2.0,3.0]
    y_data = [2.0,4.0,6.0]
    # 线性模型为y = wx, 预测x = 4时, y的值
    
    # 假设 w = 1
    w = torch.Tensor([1.0])
    w.requires_grad = True
    
    # 定义模型:
    def forward(x):
            return x*w
    
    # 定义损失函数:
    def loss(x,y):
            y_pred = forward(x)
            return (y_pred - y)**2
    
    print("Prediction before training:",4,'%.2f'%(forward(4)))
    
    for epoch in range(100):
            for x, y in zip(x_data,y_data):
                    l = loss(x,y)
                    l.backward() # 对requires_grad = True的Tensor(w)计算其梯度并进行反向传播,并且会释放计算图进行下一次计算
                    print("\tgrad:%.1f %.1f %.2f" % (x,y,w.grad.item()))
                    w.data = w.data - 0.01 * w.grad.data # 通过梯度对w进行更新
                    w.grad.data.zero_() #梯度清零
            print("Epoch:%d, w = %.2f, loss = %.2f" % (epoch,w,l.item()))
    
    print("Prediction after training:",4,'%.2f'%(forward(4)))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 本算法中反向传播主要体现在,l.backward()。调用该方法后w.grad由None更新为Tensor类型,且w.grad.data的值用于后续w.data的更新。
    • l.backward()会把计算图中所有需要梯度(grad)的地方都会求出来,然后把梯度都存在对应的待求的参数中,最终计算图被释放。
    • 取tensor中的data是不会构建计算图的。

    结果

    在这里插入图片描述

    课后练习

    1. 计算y=xw的梯度

    在这里插入图片描述

    2. 计算仿射模型y=xw+b的梯度

    在这里插入图片描述

    3. 使用计算图计算y=w1x^2+w2x+b的梯度

    在这里插入图片描述

    4. 使用Pytorch计算y=w1x^2+w2x+b的梯度

    二次模型 Quadratic Model

    在这里插入图片描述

    代码如下:

    import torch
    
    # 已知数据:
    x_data = [1.0,2.0,3.0]
    y_data = [6.0,11.0,18.0]
    # 线性模型为y = w1x²+w2x+b时, 预测x = 4时, y的值
    
    # 假设 w = 1, b = 1
    w1 = torch.Tensor([1.0])
    w1.requires_grad = True
    w2 = torch.Tensor([1.0])
    w2.requires_grad = True
    b = torch.Tensor([1.0])
    b.requires_grad = True
    
    # 定义模型:
    def forward(x):
            return x*x*w1+x*w2+b
    
    # 定义损失函数:
    def loss(x,y):
            y_pred = forward(x)
            return (y_pred - y)**2
    
    print("Prediction before training:",4,'%.2f'%(forward(4)))
    
    for epoch in range(1000):
            for x, y in zip(x_data,y_data):
                    l = loss(x,y)
                    l.backward() # 对requires_grad = True的Tensor(w)计算其梯度并进行反向传播,并且会释放计算图进行下一次计算
                    w1.data = w1.data - 0.02 * w1.grad.data # 通过梯度对w进行更新
                    w2.data = w2.data - 0.02 * w2.grad.data
                    b.data = b.data - 0.02 * b.grad.data
                    # 梯度清零
                    w1.grad.data.zero_()
                    w2.grad.data.zero_() 
                    b.grad.data.zero_()
            print("Epoch:%d, w1 = %.4f,w2 = %.4f,b = %.4f, loss = %.4f" % (epoch,w1,w2,b,l.item()))
    
    print("Prediction after training:",4,'%.4f'%(forward(4)))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40

    结果:

    在这里插入图片描述


    学习资料

    • https://blog.csdn.net/weixin_43786637/article/details/126117060
    • https://blog.csdn.net/Lilo_/article/details/113522485?utm_medium=distribute.pc_relevant.none-task-blog-2defaultbaidujs_title~default-9-113522485-blog-126117060.pc_relevant_aa&spm=1001.2101.3001.4242.6&utm_relevant_index=12
    • https://blog.csdn.net/lizhuangabby/article/details/125548170?app_version=5.7.0&code=app_1562916241&csdn_share_tail=%7B%22type%22%3A%22blog%22%2C%22rType%22%3A%22article%22%2C%22rId%22%3A%22125548170%22%2C%22source%22%3A%22qq_43800119%22%7D&ctrtid=0pZiz&uLinkId=usr1mkqgl919blen&utm_source=app

    系列文章索引

    教程指路:【《PyTorch深度学习实践》完结合集】 https://www.bilibili.com/video/BV1Y7411d7Ys?share_source=copy_web&vd_source=3d4224b4fa4af57813fe954f52f8fbe7

    1. 线性模型 Linear Model
    2. 梯度下降 Gradient Descent
    3. 反向传播 Back Propagation
    4. 用PyTorch实现线性回归 Linear Regression with Pytorch
    5. 逻辑斯蒂回归 Logistic Regression
    6. 多维度输入 Multiple Dimension Input
    7. 加载数据集Dataset and Dataloader
    8. 用Softmax和CrossEntroyLoss解决多分类问题(Minst数据集)
    9. CNN基础篇——卷积神经网络跑Minst数据集
    10. CNN高级篇——实现复杂网络
    11. RNN基础篇——实现RNN
    12. RNN高级篇—实现分类
  • 相关阅读:
    PyTorch基础知识学习
    Mac上的utools无法找到本地搜索插件
    Vue3 源码阅读(5):响应式系统 —— Vue2 中的 watch 和 computed
    重装系统会影响到电脑的正常使用吗
    USDR脱锚事件:稳定币碰上房地产,双重buff想不崩都难!
    探索Java世界中的七大排序算法(上)
    C#,机器学习的KNN(K Nearest Neighbour)算法与源代码
    Django学习第一天
    【电驱动】驱动电机系统讲解
    Win10配置Maven环境
  • 原文地址:https://blog.csdn.net/qq_43800119/article/details/126415332
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号