码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • 【PyTorch】深度学习实践之线性模型Linear Model


    本文目录

    • 课堂练习
      • 实现代码:
      • 结果:
    • 课后练习
      • 代码:
      • 结果:
    • 系列文章索引


    课堂练习

    线性模型试图学得一个通过属性的线性组合来进行预测的函数,即:
    在这里插入图片描述
    我们的目标是让输出f(x)和真实值y相比尽可能的小,采用均方误差(Mean Square Error, MSE)作为loss函数,即:
    在这里插入图片描述

    课堂例子:
    拟合如下数据:
    在这里插入图片描述

    实现代码:

    import numpy as np
    import matplotlib.pyplot as plt
    
    # 数据集,相同索引x,y为一个样本
    x_data = [1.0, 2.0, 3.0]
    y_data = [2.0, 4.0, 6.0]
    
    # 模型的前馈,线性方程 y = x·w
    def forward(x): 
        return x * w
    
    # 损失计算
    def loss(x, y):
        y_pred = forward(x) # 根据前馈求y_pred
        return (y_pred - y) * (y_pred - y)
    
    w_list = [] # 权重
    mse_list = [] # 权重对应的损失
    for w in np.arange(0.0, 4.1, 0.1): # 穷举w
        print('w=', w)
        l_sum = 0
    
        # 从x_data,y_data去除x_val,y_val
        for x_val, y_val in zip(x_data, y_data):
            y_pred_val = forward(x_val)
            loss_val = loss(x_val, y_val)
            l_sum += loss_val
            print('\t', x_val, y_val, y_pred_val, loss_val)
        print('MSE=', l_sum / 3)
        w_list.append(w) 
        mse_list.append(l_sum / 3)
        
    # 画图
    plt.plot(w_list,mse_list)
    plt.ylabel('Loss')
    plt.xlabel('w')
    plt.show()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38

    结果:

    在这里插入图片描述
    在这里插入图片描述

    课后练习

    使用y ^ = x ∗ w + b拟合数据,并画出损失函数图

    参考资料:https://blog.csdn.net/qq_36271858/article/details/115868825?utm_medium=distribute.pc_relevant.none-task-blog-2defaultbaidujs_title~default-1-115868825-blog-126074917.pc_relevant_aa&spm=1001.2101.3001.4242.2&utm_relevant_index=4

    代码:

    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    from pylab import *
    # 解决图像无法现实中文问题
    mpl.rcParams['font.sans-serif'] = ['SimHei'] 
    
    #这里设模型函数为y=2x+2
    x_data = [1.0,2.0,3.0]
    y_data = [4.0,6.0,8.0]
    
    # 定义模型
    def forward(x):
        return x * w + b
    
    # 定义损失函数
    def loss(x,y):
        y_pred = forward(x)
        return (y_pred-y)*(y_pred-y)
    
    mse_list = []
    W=np.arange(0.0,4.1,0.1)
    B=np.arange(0.0,4.1,0.1)
    w,b=np.meshgrid(W,B)
    
    l_sum = 0
    for x_val, y_val in zip(x_data, y_data):
        y_pred_val = forward(x_val)
        loss_val = loss(x_val, y_val)
        print('x_val==', x_val,'\ny_val==', y_val,'\ny_pred_val==', y_pred_val, '\nloss_val==',loss_val)
        # 计算同一个w和b下的loss总和
        l_sum += loss_val
    
    # 查找loss最低的w,b取值
    target = {'loss': float('inf'), 'w':0, 'b':0}
    for i in range(l_sum.shape[0]):
        for j in range(l_sum.shape[1]):
           if l_sum[i][j] < target['loss']:
                target['loss'] = l_sum[i][j]
                target['w'] = w[i][j]
                target['b'] = b[i][j]
    
    print('target linear model is y = %.2f * x + %.2f' % (target['w'], target['b']))
    
    # 定义三维坐标轴
    fig = plt.figure()
    ax = Axes3D(fig,auto_add_to_figure=False)
    fig.add_axes(ax)
    # 作图
    ax.plot_surface(w, b, l_sum/3,rstride=1,cstride=1,cmap=plt.cm.coolwarm) #rstride:行之间的跨度;cstride:列之间的跨度;cmap:颜色映射表
    ax.set_xlabel("权重 W")
    ax.set_ylabel("偏置项 B")
    ax.set_zlabel("损失值")
    plt.show()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55

    结果:

    Target liner model为 y = 2.00*x+ 2.00

    在这里插入图片描述


    系列文章索引

    教程指路:【《PyTorch深度学习实践》完结合集】 https://www.bilibili.com/video/BV1Y7411d7Ys?share_source=copy_web&vd_source=3d4224b4fa4af57813fe954f52f8fbe7

    1. 线性模型 Linear Model
    2. 梯度下降 Gradient Descent
    3. 反向传播 Back Propagation
    4. 用PyTorch实现线性回归 Linear Regression with Pytorch
    5. 逻辑斯蒂回归 Logistic Regression
    6. 多维度输入 Multiple Dimension Input
    7. 加载数据集Dataset and Dataloader
    8. 用Softmax和CrossEntroyLoss解决多分类问题(Minst数据集)
    9. CNN基础篇——卷积神经网络跑Minst数据集
    10. CNN高级篇——实现复杂网络
    11. RNN基础篇——实现RNN
    12. RNN高级篇—实现分类
  • 相关阅读:
    应急响应之Windows主机入侵排查
    机器视觉公司怎么可能养我这闲人,连软件加密狗都用不起,项目都用盗版,为什么​?
    005. C++智能指针
    遍历数组的10个高阶函数
    Water 2.6.3 发布,一站式服务治理平台
    玉柴集团用USB Server对U盾远程安全管控
    Android之Gradle和Gradle插件区别及联系
    【MySQL备份】Percona XtraBackup全量备份实战篇
    MySQL进阶实战11,查询缓存
    “维护者都快累死了!”Linux 宣布:LTS 版本的维护期,将从 6 年变回 2 年
  • 原文地址:https://blog.csdn.net/qq_43800119/article/details/126405091
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号