• torch.nn.parameter详解


    :---------------------------------------------------------------------------------------------------------------------:

    目录:

    参考:

    Parameter — PyTorch 1.12 documentation

    1、parameter基本解释:

    CLASS torch.nn.parameter.Parameter(data=None, requires_grad=True)
    """
    A kind of Tensor that is to be considered a module parameter.
    
    Parameters are Tensor subclasses, that have a very special property when used with Module s - when they’re assigned as Module attributes they are automatically added to the list of its parameters, and will appear e.g. in parameters() iterator. Assigning a Tensor doesn’t have such effect. This is because one might want to cache some temporary state, like last hidden state of the RNN, in the model. If there was no such class as Parameter, these temporaries would get registered too.
    
    data (Tensor) – parameter tensor.
    
    requires_grad (bool, optional) – if the parameter requires gradient. See Locally disabling gradient computation for more details. Default: True
    """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    torch.nn.parameter.Parameter 类用于Module里面自定义参数,当其作为Module的属性时,会自动添加到模型的参数列表中,可以通过parameters()迭代器读取:例如RNN的最后一个隐藏状态,Transfermor、VIT、GNN都会用到的。

    参数data:指的是Tensor

    参数requires_grad:指的是是否需要自动计算梯度(根据实际情况来定,如果需要学习的权重,需要自动计算梯度,如果不参与学习,只是作为保存变量则不需要自动计算梯度)

    2、参数requires_grad的深入理解:

    2.1 Parameter级别的requires_grad

    Autograd mechanics — PyTorch 1.12 documentation

    requires_grad参数和pytorch的自动计算梯度的机制有关。requires_grad是一个决定是否需要反向传播时候计算梯度的标志,如果True,则在前向传递期间,将节点记录在后向图中。在后向传递 (.backward()) 时只有 requires_grad=True 的叶张量才会将梯度累积到它们的 .grad 字段中。 注意:即使每个张量都有这个标志,设置它只对leaf tensors(没有 grad_fn 的张量,例如,nn.Module 的参数)有意义。很明显所有no leaf tensors(具有 grad_fn 的张量,与leaf tensors有关的后向图的张量)都会自动具有 require_grad=True,no leaf tensors计算梯度作为中间结果来计算叶tensors的 grad 。 设置 requires_grad 可以控制模型的哪些部分需要梯度计算。举个例子:

    例如,如果需要在模型微调期间冻结部分预训练模型。 要冻结模型的某些部分,只需将 .requires_grad(False) 应用于应用于不想更新的参数。如上所述,由于使用这些参数作为输入的计算不会记录在前向传递中,因此它们不会在后向传递中更新其 .grad 字段,因为它们不会成为第一个后向图的一部分节点。

    2.2Module级别的requires_grad标志

    根据需要,也可以使nn.Module.requires_grad() 在模块级别设置requires_grad。当应用于模块时, .requires_grad_() 会影响模块的所有参数(默认情况下 requires_grad=True )。

  • 相关阅读:
    0基础学Java(30)
    5.什么是Spring的依赖注入(DI)?IOC和DI的区别是什么
    lv11 嵌入式开发 ARM指令集中(汇编指令集) 6
    按照Mybatis的反射和自动代理,使用JDBC进行模拟
    执行jar包中指定main方法
    2021年ICPC国际大学生程序设计竞赛暨陕西省第九届大学生程序设计竞赛 L:String Games
    conda环境下XZ_5.1.2alpha not found解决方案
    二十七、Java 枚举(enum)
    软件测试简历项目经验怎么写?大厂面试手拿把掐
    JSONObject和JSONArray的基本使用
  • 原文地址:https://blog.csdn.net/KPer_Yang/article/details/126293681