• Python当中的repeat()函数和repeat_interleave()函数剖析和对比


    最近在学习沐神的d2l的时候,深受其中代码的折磨,有些函数真的是从来没见过,组合起来更是让人头皮发麻,根本看不懂代码在写些什么。

    写这篇文章,主要是为了总结一下Python当中的repeat()函数和repeat_interleave()函数,这两个函数在应用于Pytorch和Numpy数组的时候得到的结果也是不一样的,所以有很大的槽点需要注意!

    首先是总结应用于Pytorch领域的repeat()函数和repeat_interleave()函数:

    1.repeat()

    话不多说,直接上代码:

    import torch
    
    # 创建一个张量
    original_tensor = torch.tensor([[1, 2], [3, 4]])
    
    # 沿着行和列方向分别重复张量
    repeated_tensor = original_tensor.repeat(2, 3)
    print(repeated_tensor)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    输出为:

    tensor([[1, 2, 1, 2, 1, 2],
            [3, 4, 3, 4, 3, 4],
            [1, 2, 1, 2, 1, 2],
            [3, 4, 3, 4, 3, 4]])
    
    
    • 1
    • 2
    • 3
    • 4
    • 5

    不难从输出当中得出结论:.repeat(2, 3)就是沿着第一个维度(行)重复 2 次,沿着第二个维度(列)重复 3 次,最终生成了一个 4x6 的张量。注意repeat是一组元素一组元素地重复,这与下面的repeat_interleave()函数是不相同的。

    2.repeat_interleave()

    该函数与repeat()函数的区别在于,它是沿着指定的维度复制张量元素

    ①不指定dim,重复次数为2次,表示将把给定的输入张量展平(flatten)为向量,然后将每个元素重复2次,并返回重复后的张量。

    a = torch.randn(3,2)
    a,a.repeat_interleave(2)
    
    • 1
    • 2

    输出为:

    (tensor([[-1.03, -0.32],
             [ 0.43,  0.78],
             [ 0.91, -0.11]]),
     tensor([-1.03, -1.03, -0.32, -0.32,  0.43,  0.43,  0.78,  0.78,  0.91,  0.91,
             -0.11, -0.11]))
    
    • 1
    • 2
    • 3
    • 4
    • 5

    ②输入二维张量,指定dim=0,重复次数为3次,表示把输入张量每行元素重复3次

    a = torch.randn(3,2)
    a,torch.repeat_interleave(a,3,dim=0)
    
    • 1
    • 2

    输出为:

    (tensor([[ 0.14,  1.47],
             [-1.52, -0.62],
             [-0.24, -0.27]]),
     tensor([[ 0.14,  1.47],
             [ 0.14,  1.47],
             [ 0.14,  1.47],
             [-1.52, -0.62],
             [-1.52, -0.62],
             [-1.52, -0.62],
             [-0.24, -0.27],
             [-0.24, -0.27],
             [-0.24, -0.27]]))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    ③输入二维张量,指定dim=1,重复次数为3次,表示把输入张量每列元素重复3次

    a = torch.randn(3,2)
    a,torch.repeat_interleave(a,3,dim=1)
    
    • 1
    • 2

    输出为:

    (tensor([[-0.81,  0.56],
             [-2.41, -0.56],
             [ 0.38, -0.90]]),
     tensor([[-0.81, -0.81, -0.81,  0.56,  0.56,  0.56],
             [-2.41, -2.41, -2.41, -0.56, -0.56, -0.56],
             [ 0.38,  0.38,  0.38, -0.90, -0.90, -0.90]]))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    ④输入二维张量,指定dim=0,重复次数为一个张量列表[n1,n2,n3],表示在(dim=0)对应行上面重复n1,n2,n3遍,张量列表的长度必须与dim=0的维度的长度一样,否则会报错

    a = torch.randn(3,2)
    a,torch.repeat_interleave(a,torch.tensor([2,3,4]),dim=0)
    #表示第一行重复2遍,第二行重复3遍,第三行重复4遍
    
    • 1
    • 2
    • 3

    输出为:

    (tensor([[-0.79,  0.54],
             [-0.47, -0.25],
             [-0.13,  1.03]]),
     tensor([[-0.79,  0.54],
             [-0.79,  0.54],
             [-0.47, -0.25],
             [-0.47, -0.25],
             [-0.47, -0.25],
             [-0.13,  1.03],
             [-0.13,  1.03],
             [-0.13,  1.03],
             [-0.13,  1.03]]))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    总结:可以看出,两个函数方法最大的区别就是repeat_interleave是一个元素一个元素地重复,而repeat是一组元素一组元素地重复

    那到这里就完了吗?完全没有!经过测试发现,以上都是repeat()函数和repeat_interleave()函数应用于pytorch的tensor张量,但当它们应用于numpy数组时,结果又是不一样的!

    例如:

    test_array = torch.arange(9).reshape(3, 3)
    print('采用torch tensor原始:\n', test_array)
    print('采用torch tensor的repeat函数:\n', test_array.repeat(2, 1))
    print('采用torch tensor的repeat_interleave函数:\n', test_array.repeat_interleave(2, dim=0))
    test_array2 = np.arange(9).reshape(3, 3)
    print('采用numpy array原始:\n', test_array2)
    print('采用numpy array的repeat函数:\n', test_array2.repeat(2, 1))
    print('采用numpy array的repeat_interleave函数:\n', test_array2.repeat_interleave(2, dim=0))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    我们运行上述代码,看看结果怎么样:

    采用torch tensor原始:
     tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]])
    采用torch tensor的repeat函数:
     tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8],
            [0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]])
    采用torch tensor的repeat_interleave函数:
     tensor([[0, 1, 2],
            [0, 1, 2],
            [3, 4, 5],
            [3, 4, 5],
            [6, 7, 8],
            [6, 7, 8]])
    采用numpy array原始:
     [[0 1 2]
     [3 4 5]
     [6 7 8]]
    采用numpy array的repeat函数:
     [[0 0 1 1 2 2]
     [3 3 4 4 5 5]
     [6 6 7 7 8 8]]
    Traceback (most recent call last):
      File "D:/PythonProject/DiveIntoDeepLearning(LiMu)/main.py", line 82, in <module>
        print('采用numpy array的repeat_interleave函数:\n', test_array2.repeat_interleave(2, dim=0))
    AttributeError: 'numpy.ndarray' object has no attribute 'repeat_interleave'
    
    Process finished with exit code 1
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32

    从输出结果可以得出以下结论:

    ①pytorch当中的numpy.repeat(2, 1)是指在第一个维度(行)上复制两次,在第二个维度(列)上复制1次,并且是一组元素一组元素地复制;Numpy当中的.repeat(2, 1)是指在第二个维度上(列,对应dim值为1)复制两次,并且是一个一个元素的复制

    ②numpy没有repeat_interleave函数

  • 相关阅读:
    SpringBoot 集成 kaptcha 验证码
    牛客网刷题记录 || 指针
    大数据课程L8——网站流量项目的SparkStreaming整合代码
    第1章 初识MyBatis框架
    Spark面试题(二)
    【动态规划-简单】剑指 Offer 10- II. 青蛙跳台阶问题
    微信推送平台-测试号定制推送
    (经典dp) I型 L型 铺盖2*n
    android Seekbar当点击的时候有一个圆圈
    Vue3中jsx父子传值、provide和inject、v-memo指令、Teleport内置组件、KeepAlive缓存组件、transition过渡组件
  • 原文地址:https://blog.csdn.net/m0_57317650/article/details/134477039