torch.min()接受的参数如下:
不指定维度时, torch.min() 输出整个张量中所有元素的最小值
import torch
# 创建一个张量
x = torch.tensor([1, 2, 3, 4, 5])
# 计算最小值
min_value = torch.min(x)
print(min_value) # output: tensor(1)
当指定 dim 参数时,torch.min() 会返回沿指定维度的最小值以及对应的索引。
import torch
# 创建一个 2D 张量
x = torch.tensor([[1, 2, 3],
[4, 0, 6]])
# 沿每列计算最小值
min_values, min_indices = torch.min(x, dim=0)
print("Min values along columns:", min_values)
print("Indices of min values along columns:", min_indices)
# 沿每行计算最小值
min_values, min_indices = torch.min(x, dim=1)
print("Min values along rows:", min_values)
print("Indices of min values along rows:", min_indices)
输出的结果为:
Min values along columns: tensor([1, 0, 3])
Indices of min values along columns: tensor([0, 1, 0])
Min values along rows: tensor([1, 0])
Indices of min values along rows: tensor([0, 1])
当传入两个张量时,torch.min() 会比较两个张量中的每个位置的元素,并返回对应位置的最小值。
例如:
import torch
# 创建两个张量
a = torch.tensor([1, 2, 3])
b = torch.tensor([3, 1, 2])
# 比较两个张量并返回最小值
min_values = torch.min(a, b)
print(min_values) # output: tensor([1, 1, 2])