• 【机器学习】PyTorch中 tensor.detach() 和 tensor.data 的区别


    tensor.data的用法举例
    import torch
    x = torch.ones(1,requires_grad=True)
    print(x)
    # tensor([1.], requires_grad=True)
    
    • 1
    • 2
    • 3
    • 4
    y = x * x
    z = x.data
    z *= 100
    print(x)
    # tensor([100.], requires_grad=True)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    这里可以看出,x 的数值被 z *= 100 修改了。

    接下来再对 x 求导,可以得到:

    y.backward()
    print(x.grad)
    # tensor([200.])
    
    • 1
    • 2
    • 3

    这是个很严重的错误,因为 x 已经改变了。虽然不会报错,但是结果却并不正确。

    tensor.detach()的用法举例
    import torch
    x = torch.ones(1,requires_grad=True)
    print(x)
    # tensor([1.], requires_grad=True)
    
    • 1
    • 2
    • 3
    • 4
    y = x * x
    z = x.detach()
    z *= 100
    print(x)
    # tensor([100.], requires_grad=True)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    这里可以看出,x 的数值同样被 z *= 100 修改了。

    接下来再对 x 求导,可以得到:

    y.backward()
    print(x.grad)
    # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1]] is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
    
    • 1
    • 2
    • 3

    这里报错是是因为 autograd 追踪求导的时候发现数据已经发生改变,被覆盖。

    结论

    .data是不安全的,.detach() 是安全的,推荐使用 .detach() 来实现数据的脱离。

  • 相关阅读:
    软件测试面试之问——角色扮演
    python细节随笔
    Ax=y,Ax=0以及非线性方程组的最小二乘解
    RestFul风格
    Tech Lead(技术经理) 带人之道
    1107 Social Clusters 甲级 xp_xht123
    .NET数据交互之生成和读取YAML文件
    TCP协议与UDP协议
    C语言第三十一弹---自定义类型:结构体(下)
    radiobutton的使用
  • 原文地址:https://blog.csdn.net/qq_29931565/article/details/126839404