码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • Torch知识点总结【持续更新中......】


    文章目录

        • 1、with torch.no_grad()
        • 2、【bug】:RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1.
        • 3、argparse.ArgumentParser解析
        • 4、torch和numpy的相互转换
        • 5、如何加载pkl模型文件
          • 5.1 torch.load()加载模型及其map_location参数
        • 6、expand与expand_as
          • 6.1 expand方法
          • 6.2 expand_as方法
        • 7、bug:【TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.】
        • 8、bug:【one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [16, 10, 120, 120]], which is output 0 of ClampBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).】
        • 9、PyTorch中clone()、detach()
        • 10、torch.repeat()
        • 11、Pytorch中index_select()
        • 12、Pytorch中.new()的作用
        • 13、torch.max()
        • 14、torch.ones(),torch.add(),torch.zeros(),torch.squeeze()
        • 15、torch中的grad与backward
          • 15.1 理解optimizer.zero_grad(), loss.backward(), optimizer.step()的作用及原理
        • 16、torch.utils.data.DataLoader
        • 17、transforms.ToTensor和transforms.Normalize
        • 18、pytorch之多GPU使用,nn.DataParallel
        • 19、torch.split
        • 20、torch.contiguous()方法

    1、with torch.no_grad()

    主要有几个重要的点:
    1、torch.no_grad上一个上下文管理器,在你确定不需要调用Tensor.backward()时可以用torch.no_grad来屏蔽梯度计算;
    2、在被torch.no_grad管控下计算得到的tensor,它的requires_grad就是False;

    参考:链接1,这个链接2也不错

    2、【bug】:RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1.

    自己总结:主要是维度不匹配
    参考:链接1,链接2

    3、argparse.ArgumentParser解析

    参考:链接1,链接2

    4、torch和numpy的相互转换

    代码:

    import math
    import torch
    import numpy as np
    import pandas as pd
    
    A = np.array([[1,2,3],[6,5,3]])
    print(A, '\n')
    B = torch.from_numpy(A)  #将numpy 转换化为 tensor
    print(B)
    C = B.numpy()#tensor 转换化为 numpy 但是对该numpy进行修改会改变其他的的值
     # 对C修改后 A,B 都会相应的改变
    C[1] = 0
    print(A, '\n')
    print(B)
    print(C)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    输出:

    [[1 2 3]
     [6 5 3]] 
    
    tensor([[1, 2, 3],
            [6, 5, 3]], dtype=torch.int32)
    [[1 2 3]
     [0 0 0]] 
    
    tensor([[1, 2, 3],
            [0, 0, 0]], dtype=torch.int32)
    [[1 2 3]
     [0 0 0]]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    几种转换形式的区别
    代码:

    import torch
    import numpy as np
    
    #创建一个numpy array的数组
    array = np.array([1,2,3,4])
    
    #将numpy array转换为torch tensor
    tensor = torch.tensor(array)
    Tensor = torch.Tensor(array)
    as_tensor = torch.as_tensor(array)
    from_array = torch.from_numpy(array)
    
    print(array.dtype)      #int32
    #查看torch默认的数据类型
    print(torch.get_default_dtype())    #torch.float32
    
    #对比几种不同方法之间的异同
    print(tensor.dtype)     #torch.int32
    print(Tensor.dtype)     #torch.float32
    print(as_tensor.dtype)  #torch.int32
    print(from_array.dtype) #torch.int32
    array[0] = 10
    
    print(tensor)     # tensor([1, 2, 3, 4], dtype=torch.int32)
    print(Tensor)     # tensor([1., 2., 3., 4.])
    print(as_tensor)  #tensor([10,  2,  3,  4], dtype=torch.int32)
    print(from_array) #tensor([10,  2,  3,  4], dtype=torch.int32)
    # 后面两种数据改变,前面不变
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28

    参考:链接1,链接2

    5、如何加载pkl模型文件

    import pickle
    with open(weights_path, 'rb') as f:
         obj = f.read()
         weights = {key: weight_dict for key, weight_dict in pickle.loads(obj, encoding='latin1').items()}
         model.load_state_dict(weights)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    参考:链接

    5.1 torch.load()加载模型及其map_location参数

    参考:链接1,链接2,这个链接最系统:链接3

    6、expand与expand_as

    6.1 expand方法
    x =torch.tensor([1,2,3,4])
    x.shape
    torch.Size([4])
    
    #x拓展一维,变1x4
    x1 = x.expend(1,4)
    x1
    tensor([[1, 2, 3, 4]])
    x1.shape
    torch.Size([1, 4])
    
    #x1拓展一维,增加2行,变2x4,多加的一行重复原值
    x2 = x1.expend(2,1,4)
    x2
    tensor([[[1, 2, 3, 4]],
            [[1, 2, 3, 4]]])
    torch.Size([2, 1, 4])
    
    #x3拓展一维,增加2行,变为2x2x1x4,多加的一行重复原值
    x3 = x2.expand(2,2,1,4)
    x3 
    tensor([[[[1, 2, 3, 4]],
    
             [[1, 2, 3, 4]]],
    
    
            [[[1, 2, 3, 4]],
    
             [[1, 2, 3, 4]]]])
    torch.Size([2, 2, 1, 4])
    
    #x4直接拓展2个维度,变为2x1x4,
    x4 = x.expand(2,1,4)
    x4
    
    tensor([[[1, 2, 3, 4]],
    
            [[1, 2, 3, 4]]])
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39

    参数为传入指定shape,在原shape数据上进行高维拓维,根据维度值进行重复赋值。

    注意:
    1.只能拓展维度,比如 A的shape为 2x4的,不能 A.expend(1,4),只能保证原结构不变,在前面增维,比如A.shape(1,1,4)
    2.可以增加多维,比如x的shape为(4),x.expend(2,2,1,4)只需保证本身是4
    3.不能拓展低维,比如x的shape为(4),不能x.expend(4,2).
    参考:链接

    6.2 expand_as方法

    参考:链接

    7、bug:【TypeError: can’t convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.】

    参考:链接,这个链接也不错:链接

    8、bug:【one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [16, 10, 120, 120]], which is output 0 of ClampBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).】

    参考:链接1,链接2

    9、PyTorch中clone()、detach()

    参考:链接1,链接2

    10、torch.repeat()

    参考:链接1

    11、Pytorch中index_select()

    参考:链接1

    12、Pytorch中.new()的作用

    参考:链接1

    13、torch.max()

    参考:链接1

    14、torch.ones(),torch.add(),torch.zeros(),torch.squeeze()

    参考:链接1

    15、torch中的grad与backward

    参考:链接1

    15.1 理解optimizer.zero_grad(), loss.backward(), optimizer.step()的作用及原理

    参考:链接1

    16、torch.utils.data.DataLoader

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, num_workers=16, shuffle=True,
                                              pin_memory=True, drop_last=True)
    
    • 1
    • 2

    这里trainset需要满足2种格式中的其中一种即可,这里介绍其中一种格式,trainset对象需要具有__getitem__()和__len__()2种实例方法,下图是官方解释截图,参考:链接
    在这里插入图片描述

    参考:链接1,链接2,链接3

    17、transforms.ToTensor和transforms.Normalize

    参考:链接1

    18、pytorch之多GPU使用,nn.DataParallel

    参考:链接1,链接2

    19、torch.split

    参考:csdn链接1

    20、torch.contiguous()方法

    参考:链接1

  • 相关阅读:
    数字图像处理—python
    接收请求参数及数据回显
    嵌入式行业有无年龄危机?算不算青春饭?
    Hazelcast系列(六): TCP-IP发现机制
    pytorch深度学习快速入门
    Docker容器只有JRE没有JDK使用Jattach导出内存快照
    Vue 入门
    GIS原理篇 地图投影
    如何用SSH克隆GitHub项目
    JAVA毕业设计课设源码分享50+例
  • 原文地址:https://blog.csdn.net/qq_23022733/article/details/126745081
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号