本篇记录pytorch中梯度相关内容,叶子节点定义与detach的使用。
神经网络的参数更新依赖于梯度计算与反向传播。pytorch中的tensor张量存在requires_grad属性,requires_grad=True时,计算图的末端节点进行backward()会自动计算该张量的梯度。
举例:
import torch
a = torch.tensor([1, 1, 1], requires_grad=True, dtype=torch.float32)
b = a.mean()
b.backward()
print(a.grad) # db/da
即使张量的requires_grad属性为真时,该张量也不一定能够直接通过grad属性访问计算图backward累计的梯度。只有叶子节点同时requires_grad=True的张量,才能通过.grad访问累计梯度;非叶子节点的计算图张量,其梯度在计算后马上被删除以节省内存。
一个经典的梯度计算与参数更新的顺序是:
叶子节点的判断:用户自行创建的requires_grad=True的张量,比如各种网络层nn.Linear, nn.Conv2d的神经元weight, bias或者卷积核参数;requires_grad=False的张量也是叶子节点,但因为它不需要计算梯度,因此认为是游离在计算图之外的。
通过叶子节点运算得来的节点,都是非叶子节点。非叶子节点添加了一个grad_fn属性,记录了该节点产生时使用的运算函数,用于反向计算其梯度(这个梯度只能通过retain_grad函数或者hook机制获得,因为它计算后即被释放)。
上面提到,requires_grad=False的张量梯度不会累计,如果要使得其requires_grad=True,可以使用requires_grad_(True);
requires_grad=True的非叶子节点不能直接通过requires_grad_(False)修改其梯度属性,可以使用detach剥离其梯度属性,剥离后的张量与原张量内存相同,但不会累计梯度,其属性里也没有了grad_fn。
import torch
import torch.nn as nn
import torch.nn.functional as F
a = torch.tensor([1, 1, 1], requires_grad=True, dtype=torch.float32)
b = torch.tensor([1, 1, 1], requires_grad=False, dtype=torch.float32)
b.requires_grad_(True)
c = (a + 2 * b.detach()).sum()
c.backward()
print(a.grad, b.grad) # tensor([1., 1., 1.]) None