• nn.PairwiseDistance 和 torch.cdist 和 MSELoss 计算距离


    下面是三种计算0

    nn.PairwiseDistance 和 torch.cdist

    L1 距离

    • 一维度L1距离
    
    output=torch.arange(1, 6).float()
    target=torch.ones_like(output).float()
    In [41]: output
    Out[41]: tensor([2., 3., 4., 5., 6., 7.])
    
    In [42]: target
    Out[42]: tensor([1., 1., 1., 1., 1., 1.])
    
    # pair wise dist p1
    In [13]: pdist = nn.PairwiseDistance(p=1)
    In [16]: p_dist=pdist(output.reshape(-1, 1), target.reshape(-1, 1))
    
    In [17]: p_dist # 约等于 torch.abs(target) - torch.abs(output)
    Out[17]: tensor([1.0000e-06, 1.0000e+00, 2.0000e+00, 3.0000e+00, 4.0000e+00])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    p_dist是怎么来的呢?就是公式:

    在这里插入图片描述
    里面的 x_i 表示 输入nn.PairwiseDistance的两个vector的每一行的距离。
    n 表示 embedding的dim,如果这个embedding是2维的,则是一个vector,而vector的每一行表示一个embedding, vector的列的数目表示embedding的dim。

    output=torch.arange(1, 6).float()
    target=torch.ones_like(output).float()
    In [41]: output
    Out[41]: tensor([2., 3., 4., 5., 6., 7.])
    
    In [42]: target
    Out[42]: tensor([1., 1., 1., 1., 1., 1.])
    
    In [105]: torch.cdist(output.reshape(-1, 1), target.reshape(-1, 1), p=1)
    Out[105]:
    tensor([[0., 0., 0., 0., 0.],
            [1., 1., 1., 1., 1.],
            [2., 2., 2., 2., 2.],
            [3., 3., 3., 3., 3.],
            [4., 4., 4., 4., 4.]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    可以看到对角线就是这个 pairwise distance(这个例子还看的不是很清楚,可以看后面的例子,是对角线没错)

    • 高维度L1距离
    In [106]: a=torch.Tensor([1,2,3])
    
    In [107]: a=torch.Tensor([1,2,3]).reshape(1, -1)
    
    In [108]: b=torch.Tensor([4,5,6]).reshape(1, -1)
    
    In [113]: pdist = nn.PairwiseDistance(p=1)
    
    In [114]: pdist(a, b)
    Out[114]: tensor([9.0000])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    对于高纬度向量,比如这里的3维度的两个vector。
    对应如下计算过程:

     (1-4)^1 + (2-5)^1 + (3-6)^1  =  9
    
    • 1

    更多的例子:

    In [54]: a=torch.randn(4,2)
    
    In [55]: b=torch.randn(4,2)
    
    In [67]: torch.cdist(a, b, p=1)
    Out[67]:
    tensor([[2.2387, 2.4500, 0.6863, 2.8237],
            [1.4554, 1.6667, 1.3217, 3.3143],
            [2.9661, 3.1774, 0.9236, 4.4335],
            [1.4922, 0.8932, 3.2191, 3.7818]])
            
    In [65]: pdist = nn.PairwiseDistance(p=1)
    
    In [66]: pdist(a, b)
    Out[66]: tensor([2.2387, 1.6667, 0.9236, 3.7818])
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    可以看到 PairwiseDistance 是逐个元素相减。
    cdist对角线的值跟 PairwiseDistance 得到的距离一致。

    L2 距离

    • 一维度L2距离
    import torch
    import torch.nn as nn
     
    output=torch.arange(1, 6).float()
    target=torch.ones_like(output).float()
    
    In [41]: output
    Out[41]: tensor([2., 3., 4., 5., 6., 7.])
    
    In [42]: target
    Out[42]: tensor([1., 1., 1., 1., 1., 1.])
    
    # pair wise dist p2
    In [22]: pdist = nn.PairwiseDistance(p=2)
        ...: p_dist=pdist(output.reshape(-1, 1), target.reshape(-1, 1))
    
    In [23]: p_dist # 
    Out[23]: tensor([1.0000e-06, 1.0000e+00, 2.0000e+00, 3.0000e+00, 4.0000e+00])
    
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    可以看到 PairwiseDistance 是逐个元素相减。

    • 高维度L2距离
    In [106]: a=torch.Tensor([1,2,3])
    
    In [107]: a=torch.Tensor([1,2,3]).reshape(1, -1)
    
    In [108]: b=torch.Tensor([4,5,6]).reshape(1, -1)
    
    In [111]: pdist = nn.PairwiseDistance(p=2)
    
    In [112]: pdist(a, b)
    Out[112]: tensor([5.1962])
    
    In [116]: torch.cdist(a, b, p=2)
    Out[116]: tensor([[5.1962]])
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    对于高纬度向量,比如这里的3维度的两个vector。
    对应如下计算过程:

    sqr( (1-4)^2 + (2-5)^2 + (3-6)^2 ) =  sqr( 27 ) = 5.1962
    
    • 1

    对于两个(1, dim)的embedding,得到的distance也是1。

    ## example 2
    a=torch.randn(4,2)
    
    b=torch.randn(4,2)
    
    In [68]: a
    Out[68]:
    tensor([[ 0.1118, -0.4733],
            [-0.5251, -0.3270],
            [-0.3294, -1.6420],
            [-0.9775,  1.1180]])
    
    In [69]: b
    Out[69]:
    tensor([[-1.5026,  0.1509],
            [-1.3087,  0.5561],
            [ 0.0379, -1.0857],
            [ 2.0742,  0.3880]])
    
    In [56]: pdist = nn.PairwiseDistance(p=2)
    
    In [57]: pdist(a, b).shape
    Out[57]: torch.Size([4])
    
    In [60]: pdist(a, b)
    Out[60]: tensor([1.7309, 1.1806, 0.6666, 3.1378])
    
    In [58]: torch.cdist(a, b, p=2)
    Out[58]:
    tensor([[1.7309, 1.7543, 0.6168, 2.1431],
            [1.0881, 1.1806, 0.9448, 2.6959],
            [2.1426, 2.4063, 0.6666, 3.1461],
            [1.1005, 0.6523, 2.4264, 3.1378]])
    
    In [59]: torch.cdist(a, b, p=2).shape
    Out[59]: torch.Size([4, 4])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36

    可以看到torch.cdist对角线的值跟 PairwiseDistance 得到的距离一致。
    torch.cdist 更多的是计算矩阵乘法的距离。PairwiseDistance 计算的是两个相同大小的 embedding 对应元素的距离。

    torch.cdist 和 torch.mm

    torch.cdist是距离,没有方向的,所有值都是正值。
    torch.mm是矩阵乘法,是有方向的,所以会有负值存在。

    torch.cdist 和 MSELoss

    
    In [187]: a = torch.randn(3, 5)
    
    In [188]: b = torch.randn(3, 5)
    
    In [196]: a
    Out[196]:
    tensor([[-1.7443, -0.2031, -1.0496, -1.1783,  1.1776],
            [ 0.6324,  1.8428,  0.0322, -2.3453,  1.2462],
            [ 0.2345,  0.4086, -0.3005, -0.4243, -0.2004]])
    
    In [197]: b
    Out[197]:
    tensor([[ 1.7902, -1.4599,  0.0531, -0.2376,  0.9399],
            [ 0.4925,  0.6450, -2.2556,  0.4855,  0.4957],
            [-2.0035,  0.2681, -0.1062, -1.5349,  0.2545]])
    
    In [198]: mse = torch.nn.MSELoss(reduction='none')
    
    In [199]: mse(a, b)
    Out[199]:
    tensor([[12.4926,  1.5796,  1.2161,  0.8849,  0.0565],
            [ 0.0196,  1.4347,  5.2341,  8.0137,  0.5633],
            [ 5.0090,  0.0197,  0.0377,  1.2335,  0.2069]])
    
    In [201]: (1.7443+1.7902)
    Out[201]: 3.5345
    
    In [202]: 3.5345*3.5345
    Out[202]: 12.492690249999999
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    可以看到 MSELoss是直接把对应元素的对应维度相减,然后计算平方;
    nn.PairwiseDistance(p=2) 是把对应元素的对应维度相减、求平方,所有维度求和,再开方。

    In [191]: mse = torch.nn.MSELoss(reduction = 'sum')
    
    In [193]: mse(a, b)
    Out[193]: tensor(38.0021)
    
    In [208]: mse = torch.nn.MSELoss(reduction='none')
    
    In [209]: mse(a, b).sum()
    Out[209]: tensor(38.0021)
    
    In [230]: torch.cdist(a, b, p=2)
    Out[230]:
    tensor([[4.0286, 3.2266, 1.4692],
            [4.0970, 3.9071, 3.3298],
            [2.7151, 2.2929, 2.5509]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    MSELoss对于同一个元素的不同维度之间,没用考虑不同维度的关联性。

    PairwiseDistance 和 MSELoss

    In [187]: a = torch.randn(3, 5)
    
    In [188]: b = torch.randn(3, 5)
    
    In [198]: mse = torch.nn.MSELoss(reduction='none')
    
    In [224]: mse(a, b).sum(1)
    Out[224]: tensor([16.2298, 15.2654,  6.5069])
    
    In [221]: pdist = nn.PairwiseDistance(p=2)
    In [225]: pdist(a, b)
    Out[225]: tensor([4.0286, 3.9071, 2.5509])
    
    In [228]: 4.0286*4.0286
    Out[228]: 16.22961796
    
    In [229]: 3.9071*3.9071
    Out[229]: 15.265430409999999
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    可以看到他们是有关系的

  • 相关阅读:
    [数据集][目标检测]狗狗表情识别VOC+YOLO格式3971张4类别
    胶囊网络深入理解
    软件测试肖sir__python之sys模块
    当你访问一个网页时,后台做了些什么?
    项目部署-jenkins
    HiveSQL分位数函数percentile()使用详解+实例代码
    ClassUtils.getClassFileName()方法具有什么功能呢?
    Linux信号(处理)
    【机器学习算法】穿越神经网络的迷雾:深入探索机器学习的核心算法
    微信公众号添加Word文档附件教程_公众号添加Excel、PDF、PPT、Zip等附件教程
  • 原文地址:https://blog.csdn.net/Shirelle_/article/details/125528489