码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • 《PyTorch深度学习实践》第三讲 梯度下降算法


    《PyTorch深度学习实践》第三讲 梯度下降算法

    • 问题描述
    • 梯度下降
      • 问题分析
      • 编程实现
        • 代码
        • 实现效果
    • 随机梯度下降
      • 问题分析
      • 编程实现
        • 代码
        • 实现效果
    • 参考资料

    问题描述

    在这里插入图片描述

    梯度下降

    问题分析

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    编程实现

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    代码

    import matplotlib.pyplot as plt
    
    # 训练集数据
    x_data = [1.0, 2.0, 3.0]
    y_data = [2.0, 4.0, 6.0]
    
    # 设置初始权重猜测
    w = 1.0
    
    # 前馈计算
    def forward(x):
        return x * w
    
    # 计算损失
    def cost(xs, ys):
        cost = 0
        for x, y in zip(xs, ys):
            y_pred = forward(x)
            cost += (y_pred - y) ** 2
        return cost / len(xs)
    
    # 计算梯度
    def gradien(xs, ys):
        grad = 0
        for x, y in zip(xs, ys):
            grad += 2 * x * (x * w - y)
        return grad / len(xs)
    
    print('Predict(before training)', 4, forward(4))
    
    # 存放每轮的数据
    cost_list = []
    epoch_list = []
    
    # 训练过程
    for epoch in range(100):  # 训练100轮
        cost_val = cost(x_data, y_data)
        grad_val = gradien(x_data, y_data)   # 更新梯度
        w -= 0.01 * grad_val    # 0.01 学习率
        print('Epoch:', epoch, 'w = ', w, 'loss = ', cost_val)
        cost_list.append(cost_val)
        epoch_list.append(epoch)
    
    print('Predict(after training)', 4, forward(4))
    
    # 绘图展示
    plt.plot(epoch_list, cost_list)
    plt.xlabel('Epoch')
    plt.ylabel('Cost')
    plt.show()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51

    实现效果

    在这里插入图片描述

    随机梯度下降

    使用随机梯度下降对上述问题进行求解,随机梯度下降法和梯度下降法的主要区别在于:
    1、损失函数由计算所有训练数据的损失,更改为计算一个训练数据的损失。
    2、梯度函数由计算所有训练数据的梯度,更改为计算一个训练数据的梯度。

    问题分析

    在这里插入图片描述

    编程实现

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    代码

    import matplotlib.pyplot as plt
    
    # 训练集数据
    x_data = [1.0, 2.0, 3.0]
    y_data = [2.0, 4.0, 6.0]
    
    # 设置初始权重猜测
    w = 1.0
    
    # 前馈计算
    def forward(x):
        return x * w
    
    # 计算损失
    def loss(x, y):
        y_pred = forward(x)
        return (y_pred - y) ** 2
    
    # 计算梯度
    def gradien(x, y):
        return 2 * x * (x * w - y)
    
    print('Predict(before training)', 4, forward(4))
    
    # 存放每轮的数据
    loss_list = []
    epoch_list = []
    
    # 训练过程
    for epoch in range(100):  # 训练100轮
        for x, y in zip(x_data, y_data):
            grad = gradien(x, y)
            w = w - 0.01 * grad
            print('\tgrad:', x, y, grad)
            l = loss(x, y)
        epoch_list.append(epoch)
        loss_list.append(l)
    
    print('Predict(after training)', 4, forward(4))
    
    # 绘图展示
    plt.plot(epoch_list, loss_list)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.show()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46

    实现效果

    在这里插入图片描述

    参考资料

    传送门梯度下降算法

  • 相关阅读:
    Electron + vue搭建项目
    API对接需求如何做需求调研,需要注意什么?
    notepad++中文出现异体汉字,怎么改正
    外包干了三年,我承认我确实废了……
    laravel高校毕业实习管理系统
    【Vue】模板语法,事件处理器及综合案例、自定义组件、组件通信
    品牌广告,如何规避风险
    新手看过来----代码自测通过但作业通不过
    JAVA动漫周边产品销售管理系统计算机毕业设计Mybatis+系统+数据库+调试部署
    mysql 与 Oracle 的区别,oracle 与 mysql分页查询的区别
  • 原文地址:https://blog.csdn.net/m0_46669450/article/details/133826082
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号