• NNAPI ARGMAX 定义


    Android NNAPI - Paddle - TensorFlow - PyTorch ArgMax and ArgMin 的定义与计算过程

    1. Android NNAPI

    Android NDK
    https://developer.android.com/ndk

    Neural Networks API
    https://developer.android.com/ndk/guides/neuralnetworks

    Android NDK API Reference
    https://developer.android.com/ndk/reference

    NeuralNetworks
    https://developer.android.com/ndk/reference/group/neural-networks

    1.1 ANEURALNETWORKS_ARGMAX

    Returns the index of the largest element along an axis.
    返回沿轴的最大元素的索引。

    Supported tensor OperandCode:

    ANEURALNETWORKS_TENSOR_FLOAT16
    ANEURALNETWORKS_TENSOR_FLOAT32
    ANEURALNETWORKS_TENSOR_INT32
    ANEURALNETWORKS_TENSOR_QUANT8_ASYMM
    ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED (since NNAPI feature level 4)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    Supported tensor rank: from 1

    Inputs:

    • 0: An n-D tensor specifying the input. Must be non-empty.
    • 1: An ANEURALNETWORKS_INT32 scalar specifying the axis to reduce across. Negative index is used to specify axis from the end (e.g. -1 for the last axis). Must be in the range [-n, n).

    Outputs:

    • 0: An (n - 1)-D ANEURALNETWORKS_TENSOR_INT32 tensor. If input is 1-dimensional, the output shape is [1].

    Note: 输入是 n-D tensor,输出是 (n - 1)-D ANEURALNETWORKS_TENSOR_INT32 tensor。

    Available since NNAPI feature level 3.

    1.2 ANEURALNETWORKS_ARGMIN

    Returns the index of the smallest element along an axis.
    返回沿轴的最小元素的索引。

    Supported tensor OperandCode:

    ANEURALNETWORKS_TENSOR_FLOAT16
    ANEURALNETWORKS_TENSOR_FLOAT32
    ANEURALNETWORKS_TENSOR_INT32
    ANEURALNETWORKS_TENSOR_QUANT8_ASYMM
    ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED (since NNAPI feature level 4)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    Supported tensor rank: from 1

    Inputs:

    • 0: An n-D tensor specifying the input. Must be non-empty.
    • 1: An ANEURALNETWORKS_INT32 scalar specifying the axis to reduce across. Negative index is used to specify axis from the end (e.g. -1 for the last axis). Must be in the range [-n, n).

    Outputs:

    • 0: An (n - 1)-D ANEURALNETWORKS_TENSOR_INT32 tensor. If input is 1-dimensional, the output shape is [1].

    Note: 输入是 n-D tensor,输出是 (n - 1)-D ANEURALNETWORKS_TENSOR_INT32 tensor。

    Available since NNAPI feature level 3.

    2. Paddle

    2.1 argmax

    https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/argmax_cn.html

    This OP computes the indices of the max elements of the input tensor’s element along the provided axis.
    该 OP 沿 axis 计算输入 x 的最大元素的索引。

    Args:
    x (Variable): An input N-D Tensor with type float32, float64, int16, int32, int64, uint8.

    axis (int, optional): Axis to compute indices along. The effective range is [-R, R), where R is Rank(x). when axis < 0, it works the same way as axis + R. Default is 0.
    axis 的有效范围是 [-R, R)R 是输入 xRankaxis 为负时与 axis + R 等价。默认值为 0。

    Returns:
    Variable: A Tensor with data type int64.

    import paddle.fluid as fluid
    import numpy as np
    
    in1 = np.array([[[5,8,9,5],
                    [0,0,1,7],
                    [6,9,2,4]],
                    [[5,2,4,2],
                    [4,7,7,9],
                    [1,7,0,6]]])
    
    with fluid.dygraph.guard():
        x = fluid.dygraph.to_variable(in1)
        out1 = fluid.layers.argmax(x=x, axis=-1)
        out2 = fluid.layers.argmax(x=x, axis=0)
        out3 = fluid.layers.argmax(x=x, axis=1)
        out4 = fluid.layers.argmax(x=x, axis=2)
        print(out1.numpy())
        # [[2 3 1]
        #  [0 3 1]]
        print(out2.numpy())
        # [[0 0 0 0]
        #  [1 1 1 1]
        #  [0 0 0 1]]
        print(out3.numpy())
        # [[2 2 0 1]
        #  [0 1 1 1]]
        print(out4.numpy())
        # [[2 3 1]
        #  [0 3 1]]
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    1. axis = -1

    axis = axis + 3 = -1 + 3 = 2
    input.shape = (2, 3, 4),axis = 2,将 input 的 axis = 2 上的维度删除,则 ouput.shape = (2, 3)。

    input 沿着 axis = 2 上最大值的索引。

    [m, n, :] m = 0, 1, ..., M - 1; n = 0, 1, ..., N - 1
    
    [0, 0, :] (= [0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]) 系列坐标上最大值为 9,索引为 2。
    [0, 1, :] (= [0, 1, 0], [0, 1, 1], [0, 1, 2], [0, 1, 3]) 系列坐标上最大值为 7,索引为 3。
    [0, 2, :] (= [0, 2, 0], [0, 2, 1], [0, 2, 2], [0, 2, 3]) 系列坐标上最大值为 9,索引为 1。
    [1, 0, :] (= [1, 0, 0], [1, 0, 1], [1, 0, 2], [1, 0, 3]) 系列坐标上最大值为 5,索引为 0。
    [1, 1, :] (= [1, 1, 0], [1, 1, 1], [1, 1, 2], [1, 1, 3]) 系列坐标上最大值为 9,索引为 3。
    [1, 2, :] (= [1, 2, 0], [1, 2, 1], [1, 2, 2], [1, 2, 3]) 系列坐标上最大值为 7,索引为 1。
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    1. axis = 0

    input.shape = (2, 3, 4),axis = 0,将 input 的 axis = 0 上的维度删除,则 ouput.shape = (3, 4)。

    input 沿着 axis = 0 上最大值的索引。

    [:, n, k] n = 0, 1, ..., N - 1; k = 0, 1, ..., K - 1
    
    [:, 0, 0] (= [0, 0, 0], [1, 0, 0]) 系列坐标上最大值为 5,索引为 0。
    [:, 0, 1] (= [0, 0, 1], [1, 0, 1]) 系列坐标上最大值为 8,索引为 0。
    [:, 0, 2] (= [0, 0, 2], [1, 0, 2]) 系列坐标上最大值为 9,索引为 0。
    [:, 0, 3] (= [0, 0, 3], [1, 0, 3]) 系列坐标上最大值为 5,索引为 0。
    [:, 1, 0] (= [0, 1, 0], [1, 1, 0]) 系列坐标上最大值为 4,索引为 1。
    [:, 1, 1] (= [0, 1, 1], [1, 1, 1]) 系列坐标上最大值为 7,索引为 1。
    [:, 1, 2] (= [0, 1, 2], [1, 1, 2]) 系列坐标上最大值为 7,索引为 1。
    [:, 1, 3] (= [0, 1, 3], [1, 1, 3]) 系列坐标上最大值为 9,索引为 1。
    [:, 2, 0] (= [0, 2, 0], [1, 2, 0]) 系列坐标上最大值为 6,索引为 0。
    [:, 2, 1] (= [0, 2, 1], [1, 2, 1]) 系列坐标上最大值为 9,索引为 0。
    [:, 2, 2] (= [0, 2, 2], [1, 2, 2]) 系列坐标上最大值为 2,索引为 0。
    [:, 2, 3] (= [0, 2, 3], [1, 2, 3]) 系列坐标上最大值为 6,索引为 1。
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    1. axis = 1

    input.shape = (2, 3, 4),axis = 1,将 input 的 axis = 1 上的维度删除,则 ouput.shape = (2, 4)。

    input 沿着 axis = 1 上最大值的索引。

    [m, :, k] m = 0, 1, ..., M - 1; k = 0, 1, ..., K - 1
    
    [0, :, 0] (= [0, 0, 0], [0, 1, 0], [0, 2, 0]) 系列坐标上最大值为 6,索引为 2。
    [0, :, 1] (= [0, 0, 1], [0, 1, 1], [0, 2, 1]) 系列坐标上最大值为 9,索引为 2。
    [0, :, 2] (= [0, 0, 2], [0, 1, 2], [0, 2, 2]) 系列坐标上最大值为 9,索引为 0。
    [0, :, 3] (= [0, 0, 3], [0, 1, 3], [0, 2, 3]) 系列坐标上最大值为 7,索引为 1。
    [1, :, 0] (= [1, 0, 0], [1, 1, 0], [1, 2, 0]) 系列坐标上最大值为 5,索引为 0。
    [1, :, 1] (= [1, 0, 1], [1, 1, 1], [1, 2, 1]) 系列坐标上最大值为 7,索引为 1。
    [1, :, 2] (= [1, 0, 2], [1, 1, 2], [1, 2, 2]) 系列坐标上最大值为 7,索引为 1。
    [1, :, 3] (= [1, 0, 3], [1, 1, 3], [1, 2, 3]) 系列坐标上最大值为 9,索引为 1。
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    1. axis = 2

    input.shape = (2, 3, 4),axis = 2,将 input 的 axis = 2 上的维度删除,则 ouput.shape = (2, 3)。

    input 沿着 axis = 2 上最大值的索引。

    [m, n, :] m = 0, 1, ..., M - 1; n = 0, 1, ..., N - 1
    
    [0, 0, :] (= [0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]) 系列坐标上最大值为 9,索引为 2。
    [0, 1, :] (= [0, 1, 0], [0, 1, 1], [0, 1, 2], [0, 1, 3]) 系列坐标上最大值为 7,索引为 3。
    [0, 2, :] (= [0, 2, 0], [0, 2, 1], [0, 2, 2], [0, 2, 3]) 系列坐标上最大值为 9,索引为 1。
    [1, 0, :] (= [1, 0, 0], [1, 0, 1], [1, 0, 2], [1, 0, 3]) 系列坐标上最大值为 5,索引为 0。
    [1, 1, :] (= [1, 1, 0], [1, 1, 1], [1, 1, 2], [1, 1, 3]) 系列坐标上最大值为 9,索引为 3。
    [1, 2, :] (= [1, 2, 0], [1, 2, 1], [1, 2, 2], [1, 2, 3]) 系列坐标上最大值为 7,索引为 1。
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    2.2 argmin

    https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/argmin_cn.html

    This OP computes the indices of the min elements of the input tensor’s element along the provided axis.
    该 OP 沿 axis 计算输入 x 的最小元素的索引。

    Args:
    x (Variable): An input N-D Tensor with type float32, float64, int16, int32, int64, uint8.

    axis (int, optional): Axis to compute indices along. The effective range is [-R, R), where R is Rank(x). when axis < 0, it works the same way as axis + R. Default is 0.
    axis 的有效范围是 [-R, R)R 是输入 xRankaxis 为负时与 axis + R 等价。默认值为 0。

    Returns:
    Variable: A Tensor with data type int64.

    import paddle.fluid as fluid
    import numpy as np
    
    in1 = np.array([[[5,8,9,5],
                    [0,0,1,7],
                    [6,9,2,4]],
                    [[5,2,4,2],
                    [4,7,7,9],
                    [1,7,0,6]]])
    with fluid.dygraph.guard():
        x = fluid.dygraph.to_variable(in1)
        out1 = fluid.layers.argmin(x=x, axis=-1)
        out2 = fluid.layers.argmin(x=x, axis=0)
        out3 = fluid.layers.argmin(x=x, axis=1)
        out4 = fluid.layers.argmin(x=x, axis=2)
        print(out1.numpy())
        # [[0 0 2]
        #  [1 0 2]]
        print(out2.numpy())
        # [[0 1 1 1]
        #  [0 0 0 0]
        #  [1 1 1 0]]
        print(out3.numpy())
        # [[1 1 1 2]
        #  [2 0 2 0]]
        print(out4.numpy())
        # [[0 0 2]
        #  [1 0 2]]
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    1. axis = -1

    axis = axis + 3 = -1 + 3 = 2
    input.shape = (2, 3, 4),axis = 2,将 input 的 axis = 2 上的维度删除,则 ouput.shape = (2, 3)。

    input 沿着 axis = 2 上最小值的索引。

    [m, n, :] m = 0, 1, ..., M - 1; n = 0, 1, ..., N - 1
    
    [0, 0, :] (= [0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]) 系列坐标上最小值为 5,索引为 0。
    [0, 1, :] (= [0, 1, 0], [0, 1, 1], [0, 1, 2], [0, 1, 3]) 系列坐标上最小值为 0,索引为 0。
    [0, 2, :] (= [0, 2, 0], [0, 2, 1], [0, 2, 2], [0, 2, 3]) 系列坐标上最小值为 2,索引为 2。
    [1, 0, :] (= [1, 0, 0], [1, 0, 1], [1, 0, 2], [1, 0, 3]) 系列坐标上最小值为 2,索引为 1。
    [1, 1, :] (= [1, 1, 0], [1, 1, 1], [1, 1, 2], [1, 1, 3]) 系列坐标上最小值为 4,索引为 0。
    [1, 2, :] (= [1, 2, 0], [1, 2, 1], [1, 2, 2], [1, 2, 3]) 系列坐标上最小值为 0,索引为 2。
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    1. axis = 0

    input.shape = (2, 3, 4),axis = 0,将 input 的 axis = 0 上的维度删除,则 ouput.shape = (3, 4)。

    input 沿着 axis = 0 上最小值的索引。

    [:, n, k] n = 0, 1, ..., N - 1; k = 0, 1, ..., K - 1
    
    [:, 0, 0] (= [0, 0, 0], [1, 0, 0]) 系列坐标上最小值为 5,索引为 0。
    [:, 0, 1] (= [0, 0, 1], [1, 0, 1]) 系列坐标上最小值为 2,索引为 1。
    [:, 0, 2] (= [0, 0, 2], [1, 0, 2]) 系列坐标上最小值为 4,索引为 1。
    [:, 0, 3] (= [0, 0, 3], [1, 0, 3]) 系列坐标上最小值为 2,索引为 1。
    [:, 1, 0] (= [0, 1, 0], [1, 1, 0]) 系列坐标上最小值为 0,索引为 0。
    [:, 1, 1] (= [0, 1, 1], [1, 1, 1]) 系列坐标上最小值为 0,索引为 0。
    [:, 1, 2] (= [0, 1, 2], [1, 1, 2]) 系列坐标上最小值为 1,索引为 0。
    [:, 1, 3] (= [0, 1, 3], [1, 1, 3]) 系列坐标上最小值为 7,索引为 0。
    [:, 2, 0] (= [0, 2, 0], [1, 2, 0]) 系列坐标上最小值为 1,索引为 1。
    [:, 2, 1] (= [0, 2, 1], [1, 2, 1]) 系列坐标上最小值为 7,索引为 1。
    [:, 2, 2] (= [0, 2, 2], [1, 2, 2]) 系列坐标上最小值为 0,索引为 1。
    [:, 2, 3] (= [0, 2, 3], [1, 2, 3]) 系列坐标上最小值为 4,索引为 0。
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    1. axis = 1

    input.shape = (2, 3, 4),axis = 1,将 input 的 axis = 1 上的维度删除,则 ouput.shape = (2, 4)。

    input 沿着 axis = 1 上最小值的索引。

    [m, :, k] m = 0, 1, ..., M - 1; k = 0, 1, ..., K - 1
    
    [0, :, 0] (= [0, 0, 0], [0, 1, 0], [0, 2, 0]) 系列坐标上最小值为 0,索引为 1。
    [0, :, 1] (= [0, 0, 1], [0, 1, 1], [0, 2, 1]) 系列坐标上最小值为 0,索引为 1。
    [0, :, 2] (= [0, 0, 2], [0, 1, 2], [0, 2, 2]) 系列坐标上最小值为 1,索引为 1。
    [0, :, 3] (= [0, 0, 3], [0, 1, 3], [0, 2, 3]) 系列坐标上最小值为 4,索引为 2。
    [1, :, 0] (= [1, 0, 0], [1, 1, 0], [1, 2, 0]) 系列坐标上最小值为 1,索引为 2。
    [1, :, 1] (= [1, 0, 1], [1, 1, 1], [1, 2, 1]) 系列坐标上最小值为 2,索引为 0。
    [1, :, 2] (= [1, 0, 2], [1, 1, 2], [1, 2, 2]) 系列坐标上最小值为 0,索引为 2。
    [1, :, 3] (= [1, 0, 3], [1, 1, 3], [1, 2, 3]) 系列坐标上最小值为 2,索引为 0。
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    1. axis = 2

    input.shape = (2, 3, 4),axis = 2,将 input 的 axis = 2 上的维度删除,则 ouput.shape = (2, 3)。

    input 沿着 axis = 2 上最小值的索引。

    [m, n, :] m = 0, 1, ..., M - 1; n = 0, 1, ..., N - 1
    
    [0, 0, :] (= [0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]) 系列坐标上最小值为 5,索引为 0。
    [0, 1, :] (= [0, 1, 0], [0, 1, 1], [0, 1, 2], [0, 1, 3]) 系列坐标上最小值为 0,索引为 0。
    [0, 2, :] (= [0, 2, 0], [0, 2, 1], [0, 2, 2], [0, 2, 3]) 系列坐标上最小值为 2,索引为 2。
    [1, 0, :] (= [1, 0, 0], [1, 0, 1], [1, 0, 2], [1, 0, 3]) 系列坐标上最小值为 2,索引为 1。
    [1, 1, :] (= [1, 1, 0], [1, 1, 1], [1, 1, 2], [1, 1, 3]) 系列坐标上最小值为 4,索引为 0。
    [1, 2, :] (= [1, 2, 0], [1, 2, 1], [1, 2, 2], [1, 2, 3]) 系列坐标上最小值为 0,索引为 2。
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    3. TensorFlow

    3.1 tf.math.argmax

    https://www.tensorflow.org/api_docs/python/tf/math/argmax

    Returns the index with the largest value across axes of a tensor.
    计算 tensor 沿着某一维度的最大值的索引。

    In case of identity returns the smallest index.
    在相等的情况下返回最小的索引。

    tf.math.argmax(
        input,
        axis=None,
        output_type=tf.dtypes.int64,
        name=None
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    Args:
    input: A Tensor.
    axis: An integer, the axis to reduce across. Default to 0.
    output_type: An optional output dtype (tf.int32 or tf.int64). Defaults to tf.int64.
    name: An optional name for the operation.

    Returns:
    A Tensor of type output_type.

    For example:

      >>> A = tf.constant([2, 20, 30, 3, 6])
      >>> tf.math.argmax(A)  # A[2] is maximum in tensor A
      
    
      >>> B = tf.constant([[2, 20, 30, 3, 6], [3, 11, 16, 1, 8],
      ...                  [14, 45, 23, 5, 27]])
      >>> tf.math.argmax(B, 0)
      
      >>> tf.math.argmax(B, 1)
      
    
      >>> C = tf.constant([0, 0, 0, 0])
      >>> tf.math.argmax(C)  # Returns smallest index in case of ties
      
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    3.2 tf.math.argmin

    https://www.tensorflow.org/api_docs/python/tf/math/argmin

    Returns the index with the smallest value across axes of a tensor.
    计算 tensor 沿着某一维度的最小值的索引。

    Returns the smallest index in case of ties.
    在相等的情况下返回最小的索引。

    tie [taɪ]:n. 领带,联系,纽带,绳子,金属丝,线,关系,束缚,平局,淘汰赛,延音线 v. (用线、绳等) 系,拴,绑,捆,束,将 ... 系在 ... 上,束紧,系牢,捆绑,(在线、绳上) 打结,系扣,打结系牢,连接,束缚,打成平局,用连接线连接
    
    • 1
    tf.math.argmin(
        input,
        axis=None,
        output_type=tf.dtypes.int64,
        name=None
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    Args:
    input: A Tensor. Must be one of the following types: float32, float64, int32, uint8, int16, int8, complex64, int64, qint8, quint8, qint32, bfloat16, uint16, complex128, half, uint32, uint64.
    axis: A Tensor. Must be one of the following types: int32, int64. int32 or int64, must be in the range -rank(input), rank(input)). Describes which axis of the input Tensor to reduce across. For vectors, use axis = 0.
    output_type: An optional tf.DType from: tf.int32, tf.int64. Defaults to tf.int64.
    name: A name for the operation (optional).

    Returns:
    A Tensor of type output_type.

    For example:

      import tensorflow as tf
      a = [1, 10, 26.9, 2.8, 166.32, 62.3]
      b = tf.math.argmin(input = a)
      c = tf.keras.backend.eval(b)
      # c = 0
      # here a[0] = 1 which is the smallest element of a across axis 0
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    4. PyTorch

    4.1 TORCH.ARGMAX

    https://pytorch.org/docs/stable/generated/torch.argmax.html

    • torch.argmax(input) → LongTensor

    Returns the indices of the maximum value of all elements in the input tensor.
    返回输入 tensor 中所有元素最大值对应的索引。

    NOTE
    If there are multiple maximal values then the indices of the first maximal value are returned.
    如果有多个最大值,则返回第一个最大值的索引。

    Parameters:
    input (Tensor) - the input tensor.

    Example:

    >>> a = torch.randn(4, 4)
    >>> a
    tensor([[ 1.3398,  0.2663, -0.2686,  0.2450],
            [-0.7401, -0.8805, -0.3402, -1.1936],
            [ 0.4907, -1.3948, -1.0691, -0.3132],
            [-1.6092,  0.5419, -0.2993,  0.3195]])
    >>> torch.argmax(a)
    tensor(0)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • torch.argmax(input, dim, keepdim=False) → LongTensor

    Returns the indices of the maximum values of a tensor across a dimension.

    Parameters:
    input (Tensor) - the input tensor.
    dim (int) - the dimension to reduce. If None, the argmax of the flattened input is returned.
    keepdim (bool) - whether the output tensor has dim retained or not. Ignored if dim=None.

    Example:

    >>> a = torch.randn(4, 4)
    >>> a
    tensor([[ 1.3398,  0.2663, -0.2686,  0.2450],
            [-0.7401, -0.8805, -0.3402, -1.1936],
            [ 0.4907, -1.3948, -1.0691, -0.3132],
            [-1.6092,  0.5419, -0.2993,  0.3195]])
    >>> torch.argmax(a, dim=1)
    tensor([ 0,  2,  0,  1])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    4.2 TORCH.ARGMIN

    https://pytorch.org/docs/stable/generated/torch.argmin.html

    • torch.argmin(input, dim=None, keepdim=False) → LongTensor

    Returns the indices of the minimum value(s) of the flattened tensor or along a dimension

    NOTE
    If there are multiple minimal values then the indices of the first minimal value are returned.
    如果有多个最小值,则返回第一个最小值的索引。

    Parameters:
    input (Tensor) - the input tensor.
    dim (int) - the dimension to reduce. If None, the argmin of the flattened input is returned.
    keepdim (bool) - whether the output tensor has dim retained or not…

    Example:

    >>> a = torch.randn(4, 4)
    >>> a
    tensor([[ 0.1139,  0.2254, -0.1381,  0.3687],
            [ 1.0100, -1.1975, -0.0102, -0.4732],
            [-0.9240,  0.1207, -0.7506, -1.0213],
            [ 1.7809, -1.2960,  0.9384,  0.1438]])
    >>> torch.argmin(a)
    tensor(13)
    >>> torch.argmin(a, dim=1)
    tensor([ 2,  1,  3,  1])
    >>> torch.argmin(a, dim=1, keepdim=True)
    tensor([[2],
            [1],
            [3],
            [1]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    References

    https://yongqiang.blog.csdn.net/
    https://yongqiang.blog.csdn.net/article/details/121323744

  • 相关阅读:
    TPM零知识学习四 —— tpm2-tss源码安装
    开源 github flow的版本发布
    kafka
    数字孪生|成熟度等级
    黑马头条(day01)
    11种增加访问者在网站上平均停留时间的技巧
    Helm实战案例二:在Kubernetes(k8s)上使用helm安装部署日志管理系统EFK
    中阿科技论坛杂志中阿科技论坛杂志社中阿科技论坛编辑部2022年第7期目录
    Nginx__基础入门篇
    小程序常见操作
  • 原文地址:https://blog.csdn.net/chengyq116/article/details/127991992