• pytorch_trick(3): PyTorch中可微张量的in-place operation问题解决方法


    PyTorch中可微张量的in-place operation问题解决方法

    关于可微张量的in-place operation(对原对像修改操作)的相关讨论。

    • (1) 叶节点数值修改存在可微张量的in-place operation问题,会导致系统无法区分叶节点和其他节点的问题。
    # in-place operation问题报错
    # 但如果在计算过程中,我们使用in-place operation,让新生成的值替换w原始值,则会报错
    
    w = torch.tensor(2., requires_grad = True)
    w -= w * 2  
    
    '''
    RuntimeError                              Traceback (most recent call last)
     in 
          1 w = torch.tensor(2., requires_grad = True)
    ----> 2 w -= w * 2
    
    RuntimeError: a leaf Variable that requires grad is being used in an in-place operation."
    '''
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    从报错信息中可知,PyTorch中不允许叶节点使用in-place operation,根本原因是会造成叶节点和其他节点类型混乱。

    • 修改w值
      不过,虽然可微张量不允许in-place operation,但却可以通过其他方法进行对w进行修改。
    w = torch.tensor(2., requires_grad = True)
    w.is_leaf           # True,w是叶节点
    
    w = w * 2           # tensor(4., grad_fn=)
    w.is_leaf           # False,w不是叶节点
    
    # 无法通过反向传播求其导数
    w.backward()       
    w.grad
    
    '''
    :9: 
    UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. 
    Its .grad attribute won't be populated during autograd.backward(). 
    If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. 
    If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. 
    '''
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    但是该方法会导致叶节点丢失,无法反向传播求导。而在一张计算图中,缺少了对叶节点反向传播求导数的相关运算,计算图也就失去了核心价值。因此在实际操作过程中,应该尽量避免导致叶节点丢失的相关操作。

    (2) 叶节点数值修改方法

    • 使用with torch.no_grad()语句或者torch.detach()方法

      当然,如果出现了一定要修改叶节点的取值的情况,典型的如梯度下降过程中利用梯度值修改参数值时,可以使用此前介绍的暂停追踪的方法,如使用with torch.no_grad()语句或者torch.detach()方法,使得修改叶节点数值时暂停追踪,然后再生成新的叶节点带入计算,如:

    w = torch.tensor(2., requires_grad = True)     
    
    # 利用with torch.no_grad()暂停追踪
    with torch.no_grad():       
        w -= w * 2              # tensor(-2., requires_grad=True)
    
    w.is_leaf                   # True,w是叶节点
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    w = torch.tensor(2., requires_grad = True)
    
    # 利用detach生成新变量
    w.detach_()             # tensor(2.)
    w -= w * 2              # tensor(-2.)
    w.requires_grad = True  # tensor(-2., requires_grad=True)
    
    w.is_leaf               # True,w是叶节点
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 使用.data方法

    当然,此处我们介绍另一种方法,.data来返回可微张量的取值,从在避免在修改的过程中被追踪

    w = torch.tensor(2., requires_grad = True)
    w.data              # tensor(2.),查看张量的数值
    w                   # tensor(2., requires_grad=True),但不改变张量本身可微性
    
    # .data方法,对其数值进行修改
    w.data -= w * 2     
    w                   # tensor(-2., requires_grad=True)
    
    w.is_leaf           # True,w仍然是叶节点
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
  • 相关阅读:
    AI 音辨世界:艺术小白的我,靠这个AI模型,速识音乐流派选择音乐
    fastfds扩容全部操作过程-全是干货
    UEFI FD 文件分析
    css布局总体
    04 python的函数
    C#日志简单框架及实际测试
    binlog的三种格式
    77. 组合
    【luogu P2508】圆上的整点(高斯素数模板)
    css中BFC外边距塌陷解决办法
  • 原文地址:https://blog.csdn.net/weixin_45311418/article/details/138913721