torch.argmax(input) → LongTensor
参数:
input (Tensor) – 输入的Tensor矩阵
dim (int) – dim表示不同维度。特别的在dim=0表示二维矩阵中的列,dim=1在二维矩阵中的行。广泛的来说,我们不管一个矩阵是几维的,比如一个矩阵维度如下:(d0,d1,…,dn−1) ,那么dim=0就表示对应到d0 也就是第一个维度,dim=1表示对应到也就是第二个维度,以此类推。
举一些例子说明:
import torch
x = torch.asarray([3, 2, 5, 1])
y = torch.argmax(x) # 对应于x中最大元素的索引值
print(x, y)
返回最大值索引,也就是5的索引位置2.
import torch
x = torch.asarray([[3, 2, 5, 1], [3, 11, 6, 2]])
y = torch.argmax(x) # 对应于x中最大元素的索引值
print(x, y)
该函数默认将输入矩阵排变成一个一维向量,然后找出这个一维向量里面最大值的索引。
对于dim
这个参数可以这样理解:
下边代码例子输入x为torch.Size([2, 4])
,dim=0
时把2变成1,返回每列最大索引,dim=1
时把4变为1,返回每行最大索引。
函数返回其他所有维在这个维度上面张量最大值的索引。
import torch
x = torch.asarray([[3, 2, 5, 1], [3, 11, 6, 2]])
y = torch.argmax(x, dim=0) # 对应于x中最大元素的索引值
print(y)
import torch
x = torch.asarray([[3, 2, 5, 1], [3, 11, 6, 2]])
y = torch.argmax(x, dim=1) # 对应于x中最大元素的索引值
print(y)