• Pytorch学习:torch.max(input,dim,keepdim=False)


    torch.max()

    torch.max(input)Tensor:返回 input 张量中所有元素的最大值。
    注意输入的必须是张量形式,输出的也为张量形式
    在这里插入图片描述
    在这里插入图片描述
    当输入为tuple类型时,会报错,需要将输入改为tensor类型,输出也为tensor类型
    在这里插入图片描述在这里插入图片描述

    torch.max():官方文档
    torch.max(input,dim,keepdim=False,*,out=None)
    主要参数:

    • input(Tensor)-输入张量。
    • dim(int)-要减少的维度。
    • keepdim(bool)-输出张量是否保留了 dim 。默认值: False 。
      关键字参数:
    • out(tuple,optional)-两个输出张量的结果元组(max,max_indices)

    dim

    对于二维数组来说,dim=0为行,dim=1为列
    在torch.max()中代表要减少的维度(dimension)

    import torch
    
    a = torch.tensor([1, 2, 3, 4])
    max = torch.max(a, dim=0)
    print(max)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    对于以上程序,由于只存在行,所以torch.max(a, dim=0)只能减少的维度为行向量,即dim=0
    在这里插入图片描述
    如果 max = max = torch.max(a, dim=1),则会报错:维度错误
    在这里插入图片描述

    注:如果在减少的行中存在多个最大值,则返回第一个最大值的索引。

    import torch
    
    a = torch.tensor([4, 2, 3, 4])
    max = torch.max(a, dim=0)
    print(max)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    在这里插入图片描述

    keepdim

    输出张量是否保留了 dim,即设置是否保留torch.max(input, dim=0, keepdim=True) 中需要消去的dim。

    如果 keepdim 是 True ,则输出张量的大小与 input 相同,除了在维度 dim 中它们的大小为1。

    dim=0

    二维数组中dim=0代表行,torch.max(a, dim=0)代表消去行,求每列的最大值,keepdim=True则代表保留行

    import torch
    
    a = torch.tensor([[1, 2, 3, 4],
                      [4, 1, 2, 3],
                      [6, 2, 3, 4],
                      [3, 4, 5, 9]])
    
    # dim = 0
    max1_1 = torch.max(a, dim=0, keepdim=False)
    max1_2 = torch.max(a, dim=0, keepdim=True)
    print(max1_1)
    print(max1_2)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    在这里插入图片描述
    在这里插入图片描述dim=0,消去的维数为行,即求每列的最大值
    keepdim=False,vlaues=tensor([6, 4, 5, 9])有一个中括号
    keepdim=True,vlaues=tensor([[6, 4, 5, 9]])有两个中括号

    indices代表最大值所处的位置(第一列第三个:2,第一列第四个:3,第三列第四个:3,第四列第四个:3)

    dim=1

    二维数组中dim=1代表列,torch.max(a, dim=0)代表消去列,求每行的最大值,keepdim=True则代表保留列

    import torch
    
    a = torch.tensor([[1, 2, 3, 4],
                      [4, 1, 2, 3],
                      [6, 2, 3, 4],
                      [3, 4, 5, 9]])
    
    # dim = 1
    max2_1 = torch.max(a, dim=1, keepdim=False)
    max2_2 = torch.max(a, dim=1, keepdim=True)
    print(max2_1)
    print(max2_2)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    在这里插入图片描述
    在这里插入图片描述
    dim=1,消去的维数为列,即求每行的最大值
    keepdim=False,vlaues=tensor([4, 4, 6, 9])有一个中括号
    keepdim=True,vlaues=tensor([[4], [4], [6], [9]])有两个中括号

    indices代表最大值所处的位置(第一行第四个:3,第二行第一个:0,第三行第一个:0,第四行第四个:3)

    out:返回命名元组 (values, indices)

    values 是给定维度 dim 中 input 张量的每行的最大值。
    indices 是找到的每个最大值(argmax)的索引位置。

  • 相关阅读:
    浅谈Oracle数据库调优(2)
    批量单独下载package.json中的包
    微信小程序 | 动手实现双十一红包雨
    AutoSAR基础:Port与Dio
    【码蹄集新手村600题】判断一个数字是否为完全平方数
    mysql大数据量 分页查询优化
    常见的12种二次曲面方程及可视化
    怎么获取开源的商城源码
    跟着播客学英语-Why I use vim ? part two
    PyQt5快速开发与实战 4.1 QMainWindow
  • 原文地址:https://blog.csdn.net/qq_38473254/article/details/132840277