简单来说,你可以把tensor
看作是一个通用的数据结构,而nn.Parameter
看作是一种特殊的tensor
,这种tensor
可以被优化以提高模型的性能。在创建模型参数时,你应该使用nn.Parameter
而不是直接使用tensor
,因为这样可以确保模型参数能够被正确地用于训练和优化。
具体而言,nn.Parameter类型,
nn.Parameter是一个Parameter类,会自动把它包含的Tensor标记为需要求梯度的参变量。
而普通Tensor默认是不求梯度的,需要使用requires_grad_()来手动指定需要求导。
使用nn.Parameter可以让shape等向量自动参与求导和回传过程,从而被优化器更新。
nn.Parameter可以像普通Module的参数一样被添加到nn.Module中。
而普通Tensor需要以字典的形式加入module.state_dict中才能被当作参数。
所以直接使用nn.Parameter可以更方便地将shape等向量作为模块可优化的参数。