• Torch截断一部分后是否能梯度回传


    1. import torch
    2. from torch import optim
    3. import torch.nn as nn
    4. class g(nn.Module):
    5. def __init__(self):
    6. super(g, self).__init__()
    7. self.k = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=1, padding=0, bias=False)
    8. def forward(self, z):
    9. return self.k(z)
    10. c = 2
    11. h = 5
    12. w = 5
    13. z = torch.rand( (1,c , h , w)).float().view(1, c, h, w)*100
    14. z.requires_grad = True
    15. k = g()
    16. optim = optim.Adam(k.parameters(), lr=1)
    17. optim.zero_grad()
    18. r = k(z)
    19. r= r[:,:,:3,:3]
    20. r = r.sum()
    21. loss = (r - 1) * (r - 1)
    22. for name,v in k.named_parameters():
    23. print(name,v)
    24. print(z)
    25. print("*********************")
    26. loss.backward()
    27. optim.step()
    28. for name,v in k.named_parameters():
    29. print(name,v)
    30. print(z)

    输出:


    tensor([[[[-0.0464]],

             [[ 0.4256]]]], requires_grad=True)
    tensor([[[[65.6508, 65.0099, 38.5205, 78.4769, 31.6377],
              [27.1530,  5.7923, 23.9614, 59.5419,  3.5597],
              [69.9373, 29.7657, 91.4004, 85.5130, 65.2210],
              [62.6357, 23.9004, 95.3394, 59.5155, 48.1762],
              [98.7728, 97.2193, 66.3625, 65.0421, 22.0612]],

             [[19.3582,  2.4226, 47.2068, 20.1124, 31.9324],
              [23.4966,  5.0654, 12.4682, 35.3092, 90.3394],
              [ 8.4709, 91.5994, 79.7592, 93.8652, 92.6337],
              [49.0805, 63.9460, 81.2459, 63.4729, 77.1670],
              [17.8333, 18.6162, 44.9271, 44.8790,  3.6609]]]], requires_grad=True)
    *********************
    k.weight Parameter containing:
    tensor([[[[-1.0464]],

             [[-0.5744]]]], requires_grad=True)
    tensor([[[[65.6508, 65.0099, 38.5205, 78.4769, 31.6377],
              [27.1530,  5.7923, 23.9614, 59.5419,  3.5597],
              [69.9373, 29.7657, 91.4004, 85.5130, 65.2210],
              [62.6357, 23.9004, 95.3394, 59.5155, 48.1762],
              [98.7728, 97.2193, 66.3625, 65.0421, 22.0612]],

             [[19.3582,  2.4226, 47.2068, 20.1124, 31.9324],
              [23.4966,  5.0654, 12.4682, 35.3092, 90.3394],
              [ 8.4709, 91.5994, 79.7592, 93.8652, 92.6337],
              [49.0805, 63.9460, 81.2459, 63.4729, 77.1670],
              [17.8333, 18.6162, 44.9271, 44.8790,  3.6609]]]], requires_grad=True)

  • 相关阅读:
    初阶数据结构学习记录——아홉 二叉树和堆(2)
    python系列教程193——参数传递
    【JAVA】String类
    Matlab:设置命令历史记录预设项
    电影评分数据分析案例-Spark SQL
    Spring Boot 之配置文件
    typescript核心
    【主流技术】日常工作中关于 JSON 转换的经验大全(Java)
    SOCKS5代理与网络安全:如何安全地进行爬虫操作
    Nginx(一)介绍Nginx、正向代理和实现反向代理的两个实例
  • 原文地址:https://blog.csdn.net/qq_39861441/article/details/133610441