• pytorch中detach()函数以及data属性的区别+梯度求导计算


    关于data参数以及detach函数的区别

    在pytorch中tensor的detach和data的区别
    detach()将tensor的内容复制到新的tensor中,而data将tensor的内容复制到当前tensor中。这两种方式在修改tensor的时候,都会将原本的tensor进行修改。重点是detach在你想要进行autograd的时候就会提醒

    关于梯度计算

    常用的求导步骤

    要求导只需要做到两点:

    1. 变量tensor是float或者其他复杂类型;
    2. 将requires_grad指定为True;
    3. 设置Variable,因为只有Variable是可以变的,而tensor则是不可以变的。
      求导的步骤:
    4. 对需要求导的变量设置requires_grad设置以及Variable设置;
    5. 对结果公式进行反向求导,那么那些requires_grad为True的就都会被求导
    6. 导出结果,看你需要那个变量的求导结果,那么就可以直接用变量.grad

    注意:结果为标量才可以进行求导,所以如果所求表达式结果为矩阵,那么就需要把对应矩阵的写在backward里

    需要被计算梯度的变量:

    • 类型为叶子节点(一般自己定义的节点计算叶子节点,而计算的节点就是非叶子节点)
    • requires_grad=True
    • 依赖该tensor的所有tensor的requires_grad=True。
    import torch
    from torch.autograd import Variable
    
    #生成一个内容为[2,3]的张量,Varibale 默认是不要求梯度的,如果要求梯度,
    #需要加上requires_grad=True来说明
    a = Variable(torch.Tensor([[2,3]]),requires_grad=True)
    print(a.type())
    w = Variable(torch.ones(2,1),requires_grad=True)
    out = torch.mm(a,w)
    
    #括号里面的参数要传入和out维度一样的矩阵
    #这个矩阵里面的元素会作为最后加权输出的权重系数
    out.backward()
    # out.backward()
    print("gradients are:{}".format(w.grad.data))
    print("gradients are:{}".format(a.grad.data))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    求导

    更通用的方式

    那么就会有一种更为通用的求导方式:

    1. 设置变量Variable以及requires_grad
    2. 设置表达式以及表达式对应的变量;
    3. 对表达式变量进行求导:表达式变量.backward(对应矩阵表达)
      • 比如:out.backward(torch.FloatTensor([[1,1],[1,1]]))
    4. 使用对应变量的导数:变量.grad
      • 如果需要是requires_grad以及在表达式的变量
    import torch
    from torch.autograd import Variable
    
    #生成一个内容为[2,3]的张量,Varibale 默认是不要求梯度的,如果要求梯度,
    #需要加上requires_grad=True来说明
    a = Variable(torch.Tensor([[2,3],[1,2]]),requires_grad=True)
    print(a.type())
    w = Variable(torch.ones(2,2),requires_grad=True)
    out = torch.mm(a,w)
    
    #括号里面的参数要传入和out维度一样的矩阵
    #这个矩阵里面的元素会作为最后加权输出的权重系数
    out.backward(torch.FloatTensor([[1,1],[1,1]]))
    # out.backward()
    print("gradients are:{}".format(w.grad.data))
    print("gradients are:{}".format(a.grad.data))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    在这里插入图片描述

  • 相关阅读:
    C#实现二叉排序树定义、插入、构造
    订单超时未支付自动取消8种实现方案
    MySQL基础篇【第六篇】| 存储引擎、事务、索引、视图、DBA命令、数据库设计三范式
    1314. 矩阵区域和-矩阵前缀和算法
    使用 SQL 加密函数实现数据列的加解密
    【***二叉树***】
    【浅学Java】Spring的创建和使用
    顺序表和链表
    PTA题目 寻找250
    常见国密算法简介
  • 原文地址:https://blog.csdn.net/weixin_42295969/article/details/126383682