码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • 损失函数loss和优化器optimizer


    损失函数与优化器的关联_criterion(outputs, labels)_写代码_不错哦的博客-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/shenjianhua005/article/details/123971915?ops_request_misc=&request_id=6583569ecbdc4daf89dbf2d43eac9242&biz_id=&utm_medium=distribute.pc_search_result.none-task-blog-2~all~koosearch~default-2-123971915-null-null.142^v93^koosearch_v1&utm_term=%E6%80%8E%E4%B9%88%E6%A0%B9%E6%8D%AE%E6%8D%9F%E5%A4%B1%E5%BA%A6%E4%BC%98%E5%8C%96&spm=1018.2226.3001.4187

    loss与optimizer没有任何关联(直观上),其实它们并不需要直接联系,它们是通过 Tensor 这个类来达到间接联系的。

    首先,net网络中的参数都是tensor,一个 tensor 里面有两个地址,一个是存放的这个tensor当前实实在在的值,比如赋值为10,还有一个存放的是10求导后的值(  .grad  ),就是导数。当然,如果没求导,另一个存放的是None。

    当我们进行计算loss.backward()的时候,其实就是进行反向链式求导,这个求导是对net中的参数进行求导的,这里面的参数就是tensor,其有两个地址,分别存放当前值和反向求导的值,loss.backward()后,这个时候就每个参数里面都有导数,然后optimizer其实就是根据net每个参数的导数进行优化(在最开始定义的时候就已经绑定optimizer与net的参数了),这也就关联了loss与optimizer了。

    optimizer.step()是更新参数

    刚刚写完这个,突然想到,loss是怎么跟net中参数联系起来的,其实可以这么来看:

    y=w1X1+w2X2+w3X3

    我们在计算 loss = criterion(out, input)时,这里的out就等于y就等于w1X1+w2X2+w3X3,(虽然y是一些具体的值,但是这些值是由w1X1+w2X2+w3X3构成的),所以 losss.backward()的时候就是更新w1,w2,w3,所以这就关联了。


    1. loss = softmax_entropy(outputs).mean(0)
    2. loss.backward()
    3. optimizer.step()
    4. optimizer.zero_grad()
    1. loss = softmax_entropy(outputs).mean(0): 这一行代码计算了模型输出的损失。首先,对输出进行softmax操作,将其转换为概率分布。接下来,使用交叉熵损失函数计算每个样本的损失。最后,通过mean(0)对样本的损失进行平均,得到一个标量的损失值。

    2. loss.backward(): 这行代码触发了反向传播过程。它根据计算图以及链式法则,计算了损失相对于模型参数的梯度。这个过程通过自动微分(autograd)机制来完成,梯度信息会被累积在每个参数的.grad属性中。

    3. optimizer.step(): 这一步用来更新模型的参数。优化器根据计算得到的梯度信息,根据所选的优化算法(如随机梯度下降法SGD、Adam等),更新模型中的可学习参数。这个过程会更新模型中的权重和偏置等参数,使其朝着减小损失的方向调整。

    4. optimizer.zero_grad(): 这一行代码将模型参数的梯度清零。在进行下一轮迭代之前,需要将之前一轮迭代中计算的梯度进行清除。它是必要的,因为PyTorch默认会在反向传播过程中累积梯度,如果不清零,梯度将会累积在后续迭代中,导致结果不正确的参数更新。

  • 相关阅读:
    Mysql 45讲学习笔记(二十五)MYSQL保证高可用
    Java&数组
    深入理解联邦学习——联邦学习的分类
    使用CMake创建CUDA工程
    java计算机毕业设计ssm+jsp计算机视频学习网站
    Kubernetes(k8s)的流量负载组件Service的ClusterIP类型讲解与使用
    8.4 数据结构——选择排序
    Hive Metastore源码新增Thrift API方法
    云服务器 通过docker安装配置Nacos 图文操作
    小公司招聘程序员要求985研究生,网友:这点钱,专科都不去
  • 原文地址:https://blog.csdn.net/zhu_ba/article/details/132861027
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号