码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • Pytorch代码入门学习之分类任务(三):定义损失函数与优化器


    目录

    一、定义损失函数

    1.1 代码

    1.2 损失函数简介

    1.3 交叉熵误差(cross entropy error)

    二、定义优化器

    2.1 代码

    2.2 构造优化器

    2.3 随机梯度下降法(SGD)


    一、定义损失函数

    1.1 代码

    criterion = nn.CrossEntropyLoss()

    1.2 损失函数简介

            神经网络的学习通过某个指标表示目前的状态,然后以这个指标为基准,寻找最优的权重参数。神经网络以某个指标为线索寻找最优权重参数,该指标称为损失函数(loss function)。这个损失函数可以使用任意函数, 但一般用均方误差和交叉熵误差等。损失函数是表示神经网络性能的“恶劣程度”的指标,即当前的神经网络对监督数据在多大程度上不拟合、不一致。这个值越低,表示网络的学习效果越好。

            但是,如果loss很低的话,可能出现过拟合现象。

            过拟合是指训练出来的模型在训练集上表现得很好,但是在测试集上表现的较差,模型训练的误差远小于它在测试集上的误差。

    1.3 交叉熵误差(cross entropy error)

            交叉熵误差如下式所示:

    E = -\sum_k{}t_{k} logy_{k}

             其中,log表示以e为底数的自然对数(log e );yk指神经网络的输出,tk是正确解标签。并且,tk中只有正确解标签的索引为1,其他均为0(one-hot表示)。 因此,上式实际上只计算对应正确解标签的输出的自然对数。比如,假设正确解标签的索引是“2”,与之对应的神经网络的输出是0.6,则交叉熵误差 是−log 0.6 = 0.51;若“2”对应的输出是0.1,则交叉熵误差为−log 0.1 = 2.30。因此,交叉熵误差的值是由正确解标签所对应的输出结果决定的。

            交叉熵误差函数需要两个参数,第一个是输入参数(预测值),第二个是正确值。

    二、定义优化器

    2.1 代码

    1. import torch.optim as optim
    2. optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)

    2.2 构造优化器

            optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9):第一个参数是需要更新的参数,第二个参数是指学习率(指每次更新学习率下降的大小),第三个参数为动量;

    2.3 随机梯度下降法(SGD)

            用数学式子可以把SGD写为如下的式:

            其中,W记为需要更新的权重参数,\frac{\partial L}{\partial W}是指损失函数关于W的梯度,\eta表示学习率,一般情况下会取为0.01或0.001这类事先决定好的值。式子中的“箭头”表示用右边的值更新左边的值。

            SGD较为简单,且容易实现,但是在解决某些问题时可能没有效率。SGD是朝着梯度方向只前进一定距离的简单方法,且梯度的方法并没有指向最小值的方向。

            参考:004 第一个分类任务2_哔哩哔哩_bilibili

  • 相关阅读:
    Hazelcast系列(五):Multicast发现机制
    6 种创新的人工智能在牙科领域的应用
    马化腾去年年薪同比下降 25%,腾讯的下一步怎么走?
    vue实现鼠标经过显示悬浮框效果,使用Vue的v-show指令和CSS的:hover伪类,利用Vue的数据绑定
    【C++】STL 标准模板库 ③ ( STL 容器简介 | STL 容器区别 | STL 容器分类 | 常用的 STL 容器 )
    MongoDB综合实战篇(超容易)
    chrome 插件开发指南
    179.Hive(一):hive的基础概念,hive的安装和启动
    JavaScript入门——(5)函数
    LQ0141 纸张尺寸【水题】
  • 原文地址:https://blog.csdn.net/m0_53096519/article/details/134062960
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号