• Pytorch squeeze() unsqueeze() 用法


    简介

    torch.squeeze(input, dim=None, out=None):对数据的维度进行压缩,去掉维数为1的的维度。
    squeeze函数功能:去除size为1的维度,包括行和列。当维度大于等于2时,squeeze()无作用。
    squeeze(0):代表若第一维度值为1则去除第一维度,例如 a.squeeze(0),a 为 torch.tensor() 格式张量。
    squeeze(1):代表若第二维度值为1则去除第二维度
    squeeze(-1):去除最后维度值为1的维度

    torch.unsqueeze (input, dim=None, out=None):对数据的维度进行扩容,即升维。

    使用格式可以是torch.unsqueeze(x, 0),也可为是x.unsqueeze(0)

    实例代码

    a = torch.Tensor(1, 3)
    print(a)
    print(a.squeeze(0))
    print(a.squeeze(1))
    
    b = torch.Tensor(2, 3)
    print(b)
    print(b.squeeze(0))
    print(b.squeeze(1))
    
    c = torch.Tensor(3, 1)
    print(c)
    print(c.squeeze(0))
    print(c.squeeze(1))
    
    x = torch.tensor([1, 2, 3, 4])
    print(x)
    print(torch.unsqueeze(x, 0))
    print(torch.unsqueeze(x, 1))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    过程解析

    定义张量 a,为 2 维,第一维度有 1 个元素,第二维度有 3 个元素。
    输出:tensor([[2.6994e-30, 2.4164e-13, 1.8392e-13]])
    通过 a.squeeze(0) 对第一维度进行降维,此时第一维度有 1 个元素,可降维,第一维度消失,第二维度自动变成第一维度有三个元素,与 a 相比,即消失了一层 “[]”。
    输出:tensor([2.6994e-30, 2.4164e-13, 1.8392e-13])
    通过 a.squeeze(1) 对第二维度进行降维,此时第一维度有 3 个元素,不可降维,则不做操作,输出与 a 相同。
    输出:tensor([[2.6994e-30, 2.4164e-13, 1.8392e-13]])

    定义张量 b,为 2 维,第一维度有 2 个元素,第二维度有 3 个元素。
    第一、二维度均不可降维,因为三次输出相同。
    输出:

    tensor([[0., 0., 0.],
            [0., 0., 0.]])
    tensor([[0., 0., 0.],
            [0., 0., 0.]])
    tensor([[0., 0., 0.],
            [0., 0., 0.]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    定义张量 c,为 2 维,第一维度有 3 个元素,第二维度有 1 个元素。
    输出:

    tensor([[0.0000e+00],
            [       nan],
            [5.2781e-24]])
    
    • 1
    • 2
    • 3

    通过 c.squeeze(0) 对第一维度进行降维,此时第一维度有 3 个元素,不可降维,则不做操作,输出与 c 相同。
    输出:

    tensor([[0.0000e+00],
            [       nan],
            [5.2781e-24]])
    
    • 1
    • 2
    • 3

    通过 c.squeeze(1) 对第二维度进行降维,此时第二维度有 1 个元素,可降维,第二维度消失,第二维度数值自动进入第一维度中。
    输出:
    tensor([0.0000e+00, nan, 5.2781e-24])

    定义张量 x,为 1 维,其中数值依次为 1, 2, 3, 4。
    输出:tensor([1, 2, 3, 4])
    通过 x.unsqueeze(0) 于第一维度位置增加一个维度,使原张量变成 2 维,维度变为 (1, 4)。与 x 相比,即增加了一层 “[]”。
    输出:tensor([[1, 2, 3, 4]])
    通过 x.unsqueeze(1) 于第二维度位置增加一个维度,使原张量变成 2 维,维度变为 (4, 1)。
    输出:tensor([[1], [2], [3], [4]])

    运行结果

    tensor([[2.6994e-30, 2.4164e-13, 1.8392e-13]])
    tensor([2.6994e-30, 2.4164e-13, 1.8392e-13])
    tensor([[2.6994e-30, 2.4164e-13, 1.8392e-13]])
    tensor([[0., 0., 0.],
            [0., 0., 0.]])
    tensor([[0., 0., 0.],
            [0., 0., 0.]])
    tensor([[0., 0., 0.],
            [0., 0., 0.]])
    tensor([[0.0000e+00],
            [       nan],
            [5.2781e-24]])
    tensor([[0.0000e+00],
            [       nan],
            [5.2781e-24]])
    tensor([0.0000e+00,        nan, 5.2781e-24])
    tensor([1, 2, 3, 4])
    tensor([[1, 2, 3, 4]])
    tensor([[1],
            [2],
            [3],
            [4]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
  • 相关阅读:
    使用Spring AOP实现系统操作日志记录
    Dockerfile文件解释
    树莓派 交叉编译工具链的安装
    Vue源码cached解析
    工作中对InheritableThreadLocal使用的思考
    CASIO fx4850万能坐标计算程序
    (附源码)springboot青少年公共卫生教育平台 毕业设计 643214
    QQ2 微信红包
    洛谷-收集邮票-(期望dp+期望的平方+平方的期望)
    odoo 报表
  • 原文地址:https://blog.csdn.net/weixin_43820352/article/details/125995034