• pytorch中nn.Parameter()使用方法


    对于nn.Parameter()是pytorch中定义可学习参数的一种方法,因为我们在搭建网络时,网络中会存在一些矩阵,这些矩阵内部的参数是可学习的,也就是可梯度求导的。

    对于一些常用的网络层,例如nn.Conv2d()卷积层nn.LInear()线性层nn.LSTM()循环网络层等,这些网络层在pytorch中的nn模块中已经定义好,所以我们搭建模型时可以直接使用,但是有些自定义网络在pytorch中是没有实现的,我们就需要自定义可学习参数,那就用到了nn.Parameter()这个函数。

    该函数会为我们创建一个矩阵,该矩阵是默认可梯度求导的,之后我们就可以利用这个矩阵进行计算,该函数需要传入的参数是一个tensor,一般我们会传入一个初始化好的tensor。

    下面我们将使用一个简单的线性层作为实例,来理解如何使用nn.Parameter()。

    一、nn.Linear()定义参数

    在类中我们定义了一个线性层,输入维度是10,输出维度是3,对于nn.Linear()层内部已经封装好了nn.Parameter(),所以不需要我们自定义,直接使用即可。

    class Net1(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(10, 3)
        
        def forward(self, x):
            return F.sigmoid(self.linear(x))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    二、nn.Parameter()定义参数

    对于一个线性层,我们会需要两个矩阵,分别是权重W和偏置b,所以我们要用nn.Parameter()定义两个可学习参数,然后传入对应维度的tensor作为参数,之后就可以在forward中定义计算过程。

    class Net2(nn.Module):
        def __init__(self):
            super().__init__()
            self.W = nn.Parameter(torch.randn(10, 3))
            self.b = nn.Parameter(torch.randn(3))
        
        def forward(self, x):
            return F.sigmoid(self.W @ x + self.b)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    三、查看可学习参数

    利用下面代码就可以看定义好的模型中的参数

    model1 = Net1()
    model2 = Net2()
    
    for name, parameters in model1.named_parameters():
        print(name, ':', parameters.size())
        
    for name, parameters in model2.named_parameters():
        print(name, ':', parameters.size())
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    linear.weight : torch.Size([3, 10])
    linear.bias : torch.Size([3])
    W : torch.Size([10, 3])
    b : torch.Size([3])
    
    • 1
    • 2
    • 3
    • 4
  • 相关阅读:
    控制DAIKIN大金比例压力溢流阀放大器
    ORB-SLAM3论文概述
    如何在阿里云快速配置自动定时重启ECS云服务器?
    Python基础语法:数据分析利器
    RabbitMQ(三)持久化与发布确认
    韦东山老师 RTOS 入门课程(二)理解任务的创建,切换过程
    RSA的一些数论知识
    LayaBox---TypeScript---Symbols
    Error: Node Sass version 7.0.1 is incompatible with ^4.0.0.
    java毕业设计房产置购门户网站Mybatis+系统+数据库+调试部署
  • 原文地址:https://blog.csdn.net/m0_47256162/article/details/127822519