• pytorch中detach()函数以及data属性的区别+梯度求导计算


    关于data参数以及detach函数的区别

    在pytorch中tensor的detach和data的区别
    detach()将tensor的内容复制到新的tensor中,而data将tensor的内容复制到当前tensor中。这两种方式在修改tensor的时候,都会将原本的tensor进行修改。重点是detach在你想要进行autograd的时候就会提醒

    关于梯度计算

    常用的求导步骤

    要求导只需要做到两点:

    1. 变量tensor是float或者其他复杂类型;
    2. 将requires_grad指定为True;
    3. 设置Variable,因为只有Variable是可以变的,而tensor则是不可以变的。
      求导的步骤:
    4. 对需要求导的变量设置requires_grad设置以及Variable设置;
    5. 对结果公式进行反向求导,那么那些requires_grad为True的就都会被求导
    6. 导出结果,看你需要那个变量的求导结果,那么就可以直接用变量.grad

    注意:结果为标量才可以进行求导,所以如果所求表达式结果为矩阵,那么就需要把对应矩阵的写在backward里

    需要被计算梯度的变量:

    • 类型为叶子节点(一般自己定义的节点计算叶子节点,而计算的节点就是非叶子节点)
    • requires_grad=True
    • 依赖该tensor的所有tensor的requires_grad=True。
    import torch
    from torch.autograd import Variable
    
    #生成一个内容为[2,3]的张量,Varibale 默认是不要求梯度的,如果要求梯度,
    #需要加上requires_grad=True来说明
    a = Variable(torch.Tensor([[2,3]]),requires_grad=True)
    print(a.type())
    w = Variable(torch.ones(2,1),requires_grad=True)
    out = torch.mm(a,w)
    
    #括号里面的参数要传入和out维度一样的矩阵
    #这个矩阵里面的元素会作为最后加权输出的权重系数
    out.backward()
    # out.backward()
    print("gradients are:{}".format(w.grad.data))
    print("gradients are:{}".format(a.grad.data))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    求导

    更通用的方式

    那么就会有一种更为通用的求导方式:

    1. 设置变量Variable以及requires_grad
    2. 设置表达式以及表达式对应的变量;
    3. 对表达式变量进行求导:表达式变量.backward(对应矩阵表达)
      • 比如:out.backward(torch.FloatTensor([[1,1],[1,1]]))
    4. 使用对应变量的导数:变量.grad
      • 如果需要是requires_grad以及在表达式的变量
    import torch
    from torch.autograd import Variable
    
    #生成一个内容为[2,3]的张量,Varibale 默认是不要求梯度的,如果要求梯度,
    #需要加上requires_grad=True来说明
    a = Variable(torch.Tensor([[2,3],[1,2]]),requires_grad=True)
    print(a.type())
    w = Variable(torch.ones(2,2),requires_grad=True)
    out = torch.mm(a,w)
    
    #括号里面的参数要传入和out维度一样的矩阵
    #这个矩阵里面的元素会作为最后加权输出的权重系数
    out.backward(torch.FloatTensor([[1,1],[1,1]]))
    # out.backward()
    print("gradients are:{}".format(w.grad.data))
    print("gradients are:{}".format(a.grad.data))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    在这里插入图片描述

  • 相关阅读:
    lattice crosslink开发板mipi核心板csi测试dsi屏lif md6000 fpga
    AS3 event flow 事件冒泡机制
    vue.js具名插槽
    【数学】焦点弦定理(?)
    若依前后端分离版整合Mybatis-puls
    前后端分离时后端shiro权限认证
    [附源码]java毕业设计班级风采网站
    C语言经典例题-11
    检测Windows环境中的内部威胁
    mongodb备份还原指南
  • 原文地址:https://blog.csdn.net/weixin_42295969/article/details/126383682