• PyTorch - autograd自动微分


    论文:Automatic Differentiation in Machine Learning: a Survey,自动微分

    参考 PyTorch:AUTOMATIC DIFFERENTIATION WITH TORCH.AUTOGRAD

    image-20220705094013294

    Loss是标量,雅可比向量积,JVP,

    primitive operation(原始操作):

    在这里插入图片描述

    torch.autograd(),计算梯度

    import torch
    
    x = torch.ones(5)  # input tensor, 输入向量
    print(f"x: {x}")
    y = torch.zeros(3)  # expected output, 标签
    print(f"y: {y}")
    w = torch.randn(5, 3, requires_grad=True)  # 开启自动微分
    print(f"w: {w}")
    b = torch.randn(3, requires_grad=True)
    print(f"b: {b}")
    z = torch.matmul(x, w)+b
    
    loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    使用梯度反向传播算法,back propagation

    backward()是Tensor类的方法,loss是标量直接调用backward(),loss如果是张量,则backward()需要传入张量

    print(f"Gradient function for z = {z.grad_fn}")
    print(f"Gradient function for loss = {loss.grad_fn}")
    
    loss.backward()
    print(w.grad)
    print(b.grad)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    retain_graph=True,保留图,不保留的话,第2次调用会报错:RuntimeError: Trying to backward through the graph a second time

    torch.no_grad()关闭自动求导:

    z = torch.matmul(x, w)+b
    print(z.requires_grad)
    
    with torch.no_grad():
        z = torch.matmul(x, w)+b
    print(z.requires_grad)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    z = z.detach()之后,z.requires_grad是False

    z = z.detach()
    z.requires_grad
    
    • 1
    • 2

    DAG:directed acyclic graph,有向无环图

    张量loss,输入torch.ones_like(inp),反向传播

    retain_graph保留图,可以连续backward()

    梯度置0,np.grad.zero_()

    inp = torch.eye(5, requires_grad=True)
    out = (inp+1).pow(2)
    print(out)
    out.backward(torch.ones_like(inp), retain_graph=True)
    print(f"First call\n{inp.grad}")
    out.backward(torch.ones_like(inp), retain_graph=True)
    print(f"\nSecond call\n{inp.grad}")
    inp.grad.zero_()
    out.backward(torch.ones_like(inp), retain_graph=True)
    print(f"\nCall after zeroing gradients\n{inp.grad}")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
  • 相关阅读:
    设计模式-桥接模式
    87.(cesium之家)cesium热力图(贴地形)
    闲人闲谈PS之三十一——新收入准则中的合同损失计提
    基于图搜索的规划算法之A*家族(五):D* 算法
    java设计模式 - 建造者模式
    【前端Vue3】——JQuery知识点总结(超详细)
    晶振的等效电路模型
    [WPF] 如何实现文字描边
    基于FPGA开发板的按键消抖实验
    yolact 环境配置
  • 原文地址:https://blog.csdn.net/u012515223/article/details/125614220