码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • torch搭建神经网络(三)


    目录

    .1构建数据点集

    .2搭建简单的神经网络

    2.1网络构造

    2.2网络推理

    2.2.1定义前向传播

    2.2.2定义激活函数

    2.2.3先是实例化网络

    2.2.4定义一个优化器

     2.2.5定义损失函数

     .3训练数据集分析结果

    3.1训练过程

    3.1.1定义训练次数

    3.1.2输出预测结果

    3.1.3 输出损失值

    3.1.4 优化器设置

    3.2 可视化训练过程


    .1构建数据点集

    在搭建神经网络之间我们需要明确数据集是什么,训练的目的是什么

    下图第一个数据集是手写数字,第二个是行驶的车辆,第三个是语义分割的数据集

    最后的希望的结果是 

    目标是看到y=x^ 2的图像

    制作的数据集是一系列含有噪声的而二维数据在y=x^ 2周围

    1. import torch
    2. import torch.nn.functional as f
    3. import matplotlib.pyplot as plt
    4. import matplotlib;matplotlib.use('TkAgg')
    5. x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
    6. y=x.pow(2)+0.2*torch.rand(x.size())
    7. plt.scatter(x.data.numpy(),y.data.numpy())
    8. plt.show()

     以下为加了噪声的散点图

    .2搭建简单的神经网络

    2.1网络构造

    有了数据集之后,就可以开始搭建网络

    可以直接用torch内置的模版

    class Net(torch.nn.Module):

    然后定义构造方法;

    输入,隐藏层,输出都内置为参数

    1. class Net(torch.nn.Module):
    2. def __init__(self,n_feature,n_hiddenn_out) -> None:

    然后继承一下父类

    1. class Net(torch.nn.Module):
    2. def __init__(self,n_feature,n_hiddenn_out) -> None:
    3. super().__init__()

    定义一下隐藏层:

     self.hidden=torch.nn.Linear(n_feature,n_hidden)

    再定义输出层:

    self.predict=torch.nn.Linear(n_hidden,n_output)

    至此就完成了简单的网络构造

    2.2网络推理

    2.2.1定义前向传播

     def forward(self,x):

    2.2.2定义激活函数

     x=f.relu(self.hidden(x))

    这样就有非线性特征

    然后将激活后的值传入预测层

     x=self.predict(x)

    返回预测值

     return x

    完整网络:

    1. class Net(torch.nn.Module):
    2. def __init__(self,n_feature,n_hidden,n_output) -> None:
    3. super().__init__()
    4. self.hidden=torch.nn.Linear(n_feature,n_hidden)
    5. self.predict=torch.nn.Linear(n_hidden,n_output)
    6. def forward(self,x):
    7. x=f.relu(self.hidden(x))
    8. x=self.predict(x)
    9. return x

    有了网络之后就可以对其进行训练:

    2.2.3先是实例化网络

    net = Net(1,10,1)

    一个输入,隐藏层10个神经元,1个输出

    然后可以打印net

    print(net)
    

    2.2.4定义一个优化器

    optimizer=torch.optim.SGD(net.parameters(),lr=0.2)

     2.2.5定义损失函数

    loss_func=torch.nn.MSELoss()

    初始参数是随机的,也可以打印

    print(net.parameters())

    完整编码:

    1. net = Net(1,10,1)
    2. # print(net)
    3. optimizer=torch.optim.SGD(net.parameters(),lr=0.2)
    4. loss_func=torch.nn.MSELoss()
    5. # print(net.parameters())

     .3训练数据集分析结果

    3.1训练过程

    然后开始训练

    3.1.1定义训练次数

    训练500次

    for t in range(500):

    开始训练过程

    3.1.2输出预测结果

    让x依次传入网络乘以参数输出一个随机结果

        prediction=net.forward(x)

    3.1.3 输出损失值

        loss=loss_func(prediction,y)

    实际上就是预测值和真实值之间的差距

    对应x传入网络输出的值和真实的y值(和x一一对应的值)之间的损失

    3.1.4 优化器设置

    把优化器梯度归零

    optimizer.zero_grad()

    然后反向传播

        loss.backward()

    然后迭代模型更新

        optimizer.step()

    完整编码:

    1. for t in range(500):
    2. prediction=net.forward(x)
    3. loss=loss_func(prediction,y)
    4. optimizer.zero_grad()
    5. loss.backward()
    6. optimizer.step()

    3.2 可视化训练过程

    每五次显示一次训练结果

    if t%5==0:

    输出预测结果和原始数据之间的差距:

    1. plt.cla()
    2. plt.scatter(x.data.numpy(),y.data.numpy())
    3. plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

    打印损失值

    1. print('loss=:',loss.data)
    2. plt.pause(0.1)

    然后可视化

    1. plt.ioff()
    2. plt.show()

    运行效果图

     

     可以看到效果不断变好

    loss也是不断变小

     

  • 相关阅读:
    postman 发送post请求中的x-www-form-urlencoded和form-data的区别
    【嵌入式设计与实现】1 Keil MDKS TM32 CubeMX 的开发环境建立及Proteus仿真运行
    mac显示器如何显示docker container中的gui请求
    外包干了3个月,技术退步明显。。。。。
    【JavaSE】类和对象 (二) —— 封装、包以及 static 关键字
    数据湖:海量日志采集引擎Flume
    前端面试的话术集锦第 5 篇:高频考点( 类型转换 & 深浅拷贝 & 模块化机制等)
    ESP-IDF-V5.1.1使用websocket
    面试算法 二叉树的遍历,方法 :线索二叉树 ( morris ) ,前序遍历: 中序遍历: 后序遍历
    MySQL基础学习总结(四)
  • 原文地址:https://blog.csdn.net/weixin_50920579/article/details/126489932
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号