码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • python模型训练


    目录

    1、新建模型   train_model.py

    2、运行模型

    (1)首先会下载data文件库

    (2)完成之后会开始训练模型(10次)

    3、 训练好之后,进入命令集

     4、输入命令:python -m tensorboard.main --logdir="C:\Users\15535\Desktop\day6\train"

    (1)目录的绝对路径获得方法

     5、打开网页可视化图形

    (1)运行完之后会自动有一个网址,点进去

     (2)显示


    1、新建模型   train_model.py

    1. import torch
    2. import torchvision.transforms
    3. from torch.utils.tensorboard import SummaryWriter
    4. from torchvision import datasets
    5. from torch.utils.data import DataLoader
    6. import torch.nn as nn
    7. from torch.nn import CrossEntropyLoss
    8. #step1.下载数据集
    9. train_data=datasets.CIFAR10('./data',train=True,\
    10. transform=torchvision.transforms.ToTensor(),
    11. download=True)
    12. test_data=datasets.CIFAR10('./data',train=False,\
    13. transform=torchvision.transforms.ToTensor(),
    14. download=True)
    15. print(len(train_data))
    16. print(len(test_data))
    17. #step2.数据集打包
    18. train_data_loader=DataLoader(train_data,batch_size=64,shuffle=False)
    19. test_data_loader=DataLoader(test_data,batch_size=64,shuffle=False)
    20. #step3.搭建网络模型
    21. class My_Module(nn.Module):
    22. def __init__(self):
    23. super(My_Module,self).__init__()
    24. #64*32*32*32
    25. self.conv1=nn.Conv2d(in_channels=3,out_channels=32,\
    26. kernel_size=5,padding=2)
    27. #64*32*16*16
    28. self.maxpool1=nn.MaxPool2d(2)
    29. #64*32*16*16
    30. self.conv2=nn.Conv2d(in_channels=32,out_channels=32,\
    31. kernel_size=5,padding=2)
    32. #64*32*8*8
    33. self.maxpool2=nn.MaxPool2d(2)
    34. #64*64*8*8
    35. self.conv3=nn.Conv2d(in_channels=32,out_channels=64,\
    36. kernel_size=5,padding=2)
    37. #64*64*4*4
    38. self.maxpool3=nn.MaxPool2d(2)
    39. #线性化
    40. self.flatten=nn.Flatten()
    41. self.linear1=nn.Linear(in_features=1024,out_features=64)
    42. self.linear2=nn.Linear(in_features=64,out_features=10)
    43. def forward(self,input):
    44. #input:64,3,32,32
    45. output1=self.conv1(input)
    46. output2=self.maxpool1(output1)
    47. output3=self.conv2(output2)
    48. output4=self.maxpool2(output3)
    49. output5=self.conv3(output4)
    50. output6=self.maxpool3(output5)
    51. output7=self.flatten(output6)
    52. output8=self.linear1(output7)
    53. output9=self.linear2(output8)
    54. return output9
    55. my_model=My_Module()
    56. # print(my_model)
    57. loss_func=CrossEntropyLoss()#衡量模型训练的过程(输入输出之间的差值)
    58. #优化器,lr越大模型就越“聪明”
    59. optim = torch.optim.SGD(my_model.parameters(),lr=0.001)
    60. writer=SummaryWriter('./train')
    61. #################################训练###############################
    62. for looptime in range(10): #模型训练的次数:10
    63. print("------looptime:{}------".format(looptime+1))
    64. num=0
    65. loss_all=0
    66. for data in (train_data_loader):
    67. num+=1
    68. #前向
    69. imgs, targets = data
    70. output = my_model(imgs)
    71. loss_train = loss_func(output,targets)
    72. loss_all=loss_all+loss_train
    73. if num%100==0:
    74. print(loss_train)
    75. #后向backward 三步法 获取最小的损失函数
    76. optim.zero_grad()
    77. loss_train.backward()
    78. optim.step()
    79. # print(output.shape)
    80. loss_av=loss_all/len(test_data_loader)
    81. print(loss_av)
    82. writer.add_scalar('train_loss',loss_av,looptime)
    83. writer.close()
    84. #################################验证#########################
    85. with torch.no_grad():
    86. accuracy=0
    87. test_loss_all=0
    88. for data in test_data_loader:
    89. imgs,targets = data
    90. output = my_model(imgs)
    91. loss_test = loss_func(output,targets)
    92. #output.argmax(1)---输出标签
    93. accuracy=(output.argmax(1)==targets).sum()
    94. test_loss_all = test_loss_all+loss_test
    95. test_loss_av = test_loss_all/len(test_data_loader)
    96. acc_av = accuracy/len(test_data_loader)
    97. print("测试集的平均损失{},测试集的准确率{}".format(test_loss_av,acc_av))
    98. writer.add_scalar('test_loss',test_loss_av,looptime)
    99. writer.add_scalar('acc',acc_av,looptime)
    100. writer.close()

    2、运行模型

    (1)首先会下载data文件库

    (2)完成之后会开始训练模型(10次)

    3、 训练好之后,进入命令集

     4、输入命令:python -m tensorboard.main --logdir="C:\Users\15535\Desktop\day6\train"

    (1)目录的绝对路径获得方法

    执行下面的操作自动复制

     

     

     5、打开网页可视化图形

    (1)运行完之后会自动有一个网址,点进去

     (2)显示

  • 相关阅读:
    HTML网页设计制作——初音动漫(6页) dreamweaver作业静态HTML网页设计模板
    内存操作函数(memcpy、memmove、memset、memcmp)---- C语言
    CSS中主要定位方式
    基于单片机的北斗定位无人机救火系统(两种程序:单片机与android系统app程序源码)
    EasyCVR平台如何实现超低延时的安防视频监控直播?
    将字符串转换为小写形式字符串.casefold()
    LintCode 1753: Doing Homework Algorithms Medium
    35、CSS进阶——行盒的垂直对齐以及图片底部白边
    1859. 将句子排序
    Powershell 7.x中UTF-8环境中文乱码解决办法
  • 原文地址:https://blog.csdn.net/2301_79561199/article/details/136396684
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号