• pytorch张量创建、张量复制


    pytorch张量创建、张量复制

    首先注意一点:在torch中,可导张量计算出的新张量也是可导的,新张量与原张量具有可导连接,那么原张量就不是叶子张量,新张量成了叶子张量。

    创建方式一:torch.tensor()

    torch.tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) → Tensor

    torch.tensor只能从指定的数据创建,但是可以指定数据属性,是否可微分等属性。pin_memory是将张量放置到锁业内存中,所以这个张量只能被cpu使用。

    import torch
    a = [1, 2, 3]
    b = torch.tensor(a, requires_grad=True, dtype=torch.float64)
    
    • 1
    • 2
    • 3
    创建方式二:torch.Tensor

    按照形状创建,如果输入列表,就按照指定数据创建。

    整数:torch.ShortTensor 16位,torch.IntTensor 32位,torch.LongTensor 64位

    浮点:torch.FloatTensor=torch.Tensor 32位,torch.DoubleTensor 64位

    注意:torch.Tensor(int1, int2,int3)会创建[int1, int2,int3]形状的张量,如果传入列表元组等,就会返回该列表元组张量。

    import torch
    torch.Tensor(3) 
    '''tensor([-2.6853e+05,  1.9983e-42,  2.3694e-38])'''
    torch.Tensor(3, 1) 
    '''
    tensor([[3.2842e-15],
            [3.1714e+00],
            [2.3694e-38]])
    '''
    torch.Tensor([3, 1])
    '''
    tensor([3., 1.])
    '''
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    同设备内复制 - tensor.data /tensor.detach()/tensor.clone的区别

    这三个单独会用都会和原张量有牵扯:

    1. tensor.data和tensor.detach():随着原张量的数值变化而变化。剥离开了原张量的微分图。
    2. tensor.clone() : 还处于原张量的微分图中。复制了原张量的数值。也就是tesnor.clone().bachward()后,原张量的微分图会进行一次反向传导。
    3. 完全没牵扯:tensor.clone().detach()

    举例:

    import torch
    
    a = torch.tensor(1, requires_grad=True, dtype=torch.float32)
    b = a * 2
    
    b_data = b.data
    b_detach = b.detach()
    b_clone = b.clone()
    print(b, b_data, b_detach, b_clone)
    '''
    tensor(2., grad_fn=) tensor(2.) tensor(2.) tensor(2., grad_fn=)
    '''
    # 当其中一个改变时,tensor.data, tensor.detach也会改变。tensor.clone不会改变。
    b_detach.zero_()
    print(b, b_data, b_detach, b_clone)
    '''
    tensor(0., grad_fn=) tensor(0.) tensor(0.) tensor(2., grad_fn=)
    '''
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    当tensor.detach或者tensor.data改变数值时,并不会影响原张量的微分传导结果。

    import torch
    
    a = torch.tensor(1, requires_grad=True, dtype=torch.float32)
    b = a * 2
    
    b_data = b.data
    b_detach = b.detach()
    b_clone = b.clone()
    
    # a的微分结果不受影响
    b_detach.zero_()
    b.backward(retain_graph=True)
    print(a.grad)
    
    # 如果原张量本身变化,则会受到影响。
    b.zero_()
    a.grad.zero_()
    b.backward()
    print(a.grad)
    '''
    tensor(2.)
    tensor(0.)
    '''
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23

    tensor.clone会保持原张量的微分传导图,并会叠加到结果上。

    import torch
    
    a = torch.tensor(1, requires_grad=True, dtype=torch.float32)
    b = a * 2
    
    b_clone = b.clone()
    
    b.backward(retain_graph=True)
    print(a.grad)
    b_clone.backward()
    print(a.grad)
    '''
    tensor(2.)
    tensor(4.)
    '''
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    跨设备复制

    方法很多,实际使用就用以下这种:

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    temp = torch.tensor(2)
    temp.to(deivce) # 如果有gpu就放到gpu.
    temp = temp.cpu() # 复制到cpu上
    
    • 1
    • 2
    • 3
    • 4
  • 相关阅读:
    C#上位机系列(1)—项目的建立
    系列四、FileReader和FileWriter
    如何在复现LaneNet车道线检测项目时,采用网上博主制作数据集的方法来只做自己的数据集,当把此数据集投喂进网络训练时(采用Pytorch库)会报如下的错误?
    考虑阶梯式碳交易机制与电制氢的综合能源系统热电优化(Matlab代码实现)
    Java百题大战
    java街边熟食店卤菜网上商城系统springboot+vue
    跟着李老师学线代——矩阵(持续更新)
    JAVA 0基础 基本数据类型之间的互相转换
    微服务OR单体架构
    大数据技术基础实验十五:Storm实验——实时WordCountTopology
  • 原文地址:https://blog.csdn.net/Akun_2217/article/details/136331882