• 神经网络(十六)Pytorch实机运行的一些细节


    一、Dataset的参数

            dataset使用时transform的参数要由ToTensor()修改为torchvision.transforms.ToTensor()。完整函数如下:

    1. train_data=torchvision.datasets.CIFAR10("../data",train=True,
    2. transform=torchvision.transforms.ToTensor(),download=False)

            训练集同理

    二、网络模型

            ①线性层的参数格式

    nn.Linear(64*4*4,64),     #卷积核(三围),卷积核尺寸

            ②序列器的使用

    1. self.model = nn.Sequential(#内容)
    2. x = self.model(x) #调用

    三、Cuda的使用

            Cuda一旦使用,损失函数、数据、神经网络必须同时运行在cuda上,否则将会报错

            ①检测Cuda设备

    1. if torch.cuda.is_available():
    2. print("检测到了cuda")

            ②将数据转换为cuda

                    将数据转换为cuda类型有以下两种方法

    1. imgs.to('cuda') #或者.to('cpu')
    2. imgs.cuda()

            ③损失函数的cuda调用

    1. loss_fn=nn.CrossEntropyLoss()
    2. loss_fn=loss_fn.cuda() #将其转换为cuda
    3. loss=loss_fn(outputs.to('cuda'),targes.to('cuda')) #调用

            ④神经网络的cuda调用

    1. mynet=MyNerNet().cuda() #将网络cuda实例化
    2. outputs=mynet(imgs.to('cuda')) #使用cuda网络

            ⑤将cuda类型回调

    imgs.to('cpu')    #在将其用于其他用途时需要先转换回cpu型

    四、完整示例

    1. #网络模型Model.py
    2. from msilib import sequence
    3. from turtle import forward
    4. import torch.nn as nn #神经网络库
    5. import torch.nn.functional as F #函数库
    6. class MyNerNet(nn.Module):
    7. def __init__(self):
    8. super(MyNerNet,self).__init__()
    9. self.model = nn.Sequential(
    10. nn.Conv2d(3,32,5,1,2), #卷积
    11. nn.MaxPool2d(2), #池化
    12. nn.Conv2d(32,32,5,1,2,),
    13. nn.MaxPool2d(2),
    14. nn.Conv2d(32,64,5,1,2),
    15. nn.MaxPool2d(2),
    16. nn.Flatten(), #展平
    17. nn.Linear(64*4*4,64), #线性层
    18. nn.Linear(64,10) #线性分类器
    19. )
    20. def forward(self,x): #传递函数
    21. x = self.model(x)
    22. return x
    1. #网络调用
    2. from pickletools import optimize
    3. import torch #导入torch库
    4. import torchvision #导入图像处理库
    5. from torch.utils.data import DataLoader as Loader #加载器
    6. from Model import * #加载自建模型
    7. import torch.nn as nn #引入神经网络支持
    8. train_data=torchvision.datasets.CIFAR10("E:\\CxxDemo\\Python\\Cnn_AlexNet_Test\\data",train=True,transform=torchvision.transforms.ToTensor(),download=False) #训练集dataset
    9. test_data=torchvision.datasets.CIFAR10("E:\\CxxDemo\\Python\\Cnn_AlexNet_Test\\data",train=False,transform=torchvision.transforms.ToTensor(),download=False) #测试集dataset
    10. train_data_size=len(train_data)
    11. test_data_size=len(test_data)
    12. train_loader = Loader(train_data,batch_size=64,shuffle=True) #训练集加载器
    13. test_loader = Loader(test_data,batch_size=64,shuffle=True) #测试集加载器
    14. mynet=MyNerNet().cuda() #将网络实例化
    15. #构建损失函数
    16. loss_fn=nn.CrossEntropyLoss() #交叉损失函数
    17. loss_fn=loss_fn.cuda() #调用cuda
    18. learing_rate=0.01 #学习率
    19. optimizer = torch.optim.SGD(mynet.parameters(),lr=learing_rate) #随机梯度下降
    20. #设置计数器
    21. total_train_step=0 #训练次数
    22. total_test_step=0 #测试次数
    23. epoch=10 #训练轮次
    24. #开始训练
    25. for i in range(epoch):
    26. print("---第{}论训练".format(i+1))
    27. # mynet.train() #开始训练!!!
    28. for data in train_loader:
    29. imgs,targes=data #拆包
    30. outputs=mynet(imgs.to('cuda')) #使用网络
    31. loss=loss_fn(outputs.to('cuda'),targes.to('cuda')) #计算损失函数
    32. optimizer.zero_grad() #梯度清零
    33. loss.backward() #前向传递
    34. optimizer.step() #逐步优化
    35. total_train_step+=1 #训练计数
    36. print("训练次数:{},Loss:{}".format(total_train_step,loss.item()))
    37. #开始测试
    38. total_test_loss = 0 #总损失函数计数
    39. with torch.no_grad(): #不设置梯度(保证不进行调优)
    40. for data in test_loader:
    41. imgs,targets = data #拆包
    42. outputs = mynet(imgs.to('cuda')) #使用网络
    43. loss = loss_fn(outputs.to('cuda'),targets.to('cuda')) #计算损失函数
    44. total_test_loss = total_test_loss + loss #添加此次部分损失函数
    45. print("整个测试集上的Loss:{}".format(total_test_loss))
    46. total_test_step = total_test_step + 1
    47. #保存每轮的模型
    48. #torch.save(mynet,"MyNerNet_Ver{}.pth".format(total_train_step))

             调用结果如下

  • 相关阅读:
    Hudi Spark源码学习总结-spark.read.format(“hudi“).load(2)
    RustChinaConf 2024(Rust中国大会2024)号集令
    C语言——计算数组长度
    Java修仙传之神奇的ES(基础使用)
    如今市面上有什么冷门生意可做
    Django(21):使用Celery任务框架
    俄罗斯套娃 (Matryoshka) 嵌入模型概述
    学习 vite + vue3 + pinia + ts(四)setup异步返回 async setup
    CesiumJS【Basic】- #024 加载webm文件(Primitive方式)
    c#中使用Task.WhenAll
  • 原文地址:https://blog.csdn.net/weixin_37878740/article/details/127414854