• Pytorch linear 多维 输入的参数


    问题: 由于 在输入lstm 层 每个batch 做了根据输入序列最大长度做了padding,导致每个 batch 的 length 不同。 导致输出 长度不同 。如:(batch, length, output_dim): (12,128,10),(12,111,10). 但是输入 linear 层的时候没有出现问题。

    网站解释:

    官网 pytorch  linear:

     

    • Input: (*, H_{in})(∗,Hin​) where *∗ means any number of dimensions including none and H_{in} = \text{in\_features}Hin​=in_features.  任意维度 number 理解有歧义 (a) number. k可以理解三维,四维。。。 (b) 可以理解 为  某一维度的数 。

    • Output: (*, H_{out})(∗,Hout​) where all but the last dimension are the same shape as the input and H_{out} = \text{out\_features}Hout​=out_features.

    代码解释:

    分别 用三维 和二维输入数组,查看他们参数数目是否一样。

    1. import torch
    2. x = torch.randn(128, 20) # 输入的维度是(12820
    3. m = torch.nn.Linear(20, 30) # 20,30是指维度
    4. output = m(x)
    5. print('m.weight.shape:\n ', m.weight.shape)
    6. print('m.bias.shape:\n', m.bias.shape)
    7. print('output.shape:\n', output.shape)
    8. # ans = torch.mm(input,torch.t(m.weight))+m.bias 等价于下面的
    9. ans = torch.mm(x, m.weight.t()) + m.bias
    10. print('ans.shape:\n', ans.shape)
    11. print(torch.equal(ans, output))

    output:

    m.weight.shape:
      torch.Size([30, 20])
    m.bias.shape:
     torch.Size([30])
    output.shape:
     torch.Size([128, 30])
    ans.shape:
     torch.Size([128, 30])
    True
    1. x = torch.randn(128, 30,20) # 输入的维度是(128,3020
    2. m = torch.nn.Linear(20, 30) # 20,30是指维度
    3. output = m(x)
    4. print('m.weight.shape:\n ', m.weight.shape)
    5. print('m.bias.shape:\n', m.bias.shape)
    6. print('output.shape:\n', output.shape)
    ouput:
    m.weight.shape:
      torch.Size([30, 20])
    m.bias.shape:
     torch.Size([30])
    output.shape:
     torch.Size([128, 30, 30])

    结果:

    (128,30,20),和 (128,20) 分别是如  nn.linear(30,20) 层。

    weight.shape 均为: (30,20)

    linear() 参数数目只和 input_dim ,output_dim 有关。

    weight 在源码的定义, 没找到如何计算多维input的代码。

     

  • 相关阅读:
    FLUX.1 实测,堪比 Midjourney 的开源 AI 绘画模型,无需本地显卡,带你免费实战
    采用创新的FPGA 器件来实现更经济且更高能效的大模型推理解决方案
    网络安全-黑客攻击
    SmartInitializingSingleton接口
    信息安全建设之开源安全产品
    深入I/O挖矿
    秒验丨Android客户端集成指南
    Python初识(Python背景知识,安装Python,PyCharm环境搭建)
    Linux系统安装Tomcat一条龙服务
    计算机设计大赛 深度学习的智能中文对话问答机器人
  • 原文地址:https://blog.csdn.net/u013996948/article/details/126406694