• TensorBoard——Pytorch版使用(附带案例演示)


    TensorBoard是一个用于可视化机器学习实验结果的工具,可以帮助我们更好地理解和调试训练过程中的模型。

    在PyTorch中,我们可以使用TensorBoardX库来与TensorBoard进行交互。TensorBoardX是一个PyTorch的扩展,它允许我们将PyTorch的训练中的关键指标和摘要写入TensorBoard的事件文件中。

    一、TensorBoard的使用步骤:

    下面是使用TensorBoardX进行可视化的一些常见步骤:

    1. 安装TensorBoard

    确保你已安装TensorBoard。对于PyTorch用户,TensorBoard也可以独立安装:

    pip install tensorboard
    

    2. 在你的代码中配置TensorBoard

    使用PyTorch时,你可以通过`torch.utils.tensorboard`模块来使用TensorBoard。首先,导入`SummaryWriter`来记录事件:

    1. from torch.utils.tensorboard import SummaryWriter
    2. # 初始化SummaryWriter
    3. writer = SummaryWriter('runs/experiment_name')

    然后,在你的训练循环中,使用`writer.add_scalar`等方法来记录你感兴趣的信息,例如损失和准确率:

    1. for epoch in range(num_epochs):
    2. # 训练模型...
    3. loss = ...
    4. accuracy = ...
    5. # 记录损失和准确率
    6. writer.add_scalar('Loss/train', loss, epoch)
    7. writer.add_scalar('Accuracy/train', accuracy, epoch)
    8. # 关闭writer
    9. writer.close()

    3. 在PyCharm中启动TensorBoard

    接下来,有两种方法在PyCharm中查看TensorBoard:

    方案一使用Terminal

    1. 打开PyCharm的Terminal。
    2. 导航到你的项目目录。
    3. 使用以下命令启动TensorBoard:

    tensorboard --logdir=runs/
    

    在这要注意一点:pycharm终端默认使用的是base环境,所以终端前面会显示PS,需要进入项目所使用的环境中才可执行tensorboard --logdir=runs/,具体如何操作点击这里。否则会出现如下报错:

    方案二:配置PyCharm运行配置

    1. 在PyCharm中,点击右上角的“Add Configuration”。
    2. 点击"+",选择"Python"。
    3. 在"Script path"中,找到并输入`tensorboard`的执行文件路径。
    4. 在"Parameters"字段中,输入`--logdir=runs/`,确保路径与你的TensorBoard日志目录匹配。
    5. 保存配置,然后你可以通过点击运行按钮来启动TensorBoard。

     4. 浏览TensorBoard

    在TensorBoard启动后,通过浏览器访问TensorBoard界面,你可以看到损失、准确率、图像示例等多种类型的日志信息,这些都可以帮助你分析和改进你的模型。

    小提示

    • 当使用PyTorch时,`SummaryWriter`的路径(例如`runs/experiment_name`)定义了TensorBoard日志的存储位置。确保每次实验使用不同的名称,以便在TensorBoard中清晰地区分它们。
    • 利用TensorBoard的高级特性,如图像、图表和直方图记录,可以提供更多关于模型训练过程和结果的洞察。

    二、案例演示

    步骤1: 创建PyTorch模型

    首先,我们定义一个简单的线性回归模型。

    1. import torch
    2. import torch.nn as nn
    3. import numpy as np
    4. from torch.utils.tensorboard import SummaryWriter
    5. # 定义模型
    6. class LinearRegressionModel(nn.Module):
    7. def __init__(self):
    8. super(LinearRegressionModel, self).__init__()
    9. self.linear = nn.Linear(1, 1) # 输入和输出都是1维
    10. def forward(self, x):
    11. return self.linear(x)

    步骤2: 训练模型并记录日志

    接着,我们将准备数据、定义损失函数和优化器,并在训练循环中使用SummaryWriter来记录损失:

    1. # 准备数据
    2. x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
    3. [9.779], [6.182], [7.59], [2.167],
    4. [7.042], [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)
    5. y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
    6. [3.366], [2.596], [2.53], [1.221],
    7. [2.827], [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)
    8. x_train = torch.from_numpy(x_train)
    9. y_train = torch.from_numpy(y_train)
    10. # 初始化模型
    11. model = LinearRegressionModel()
    12. # 损失和优化器
    13. criterion = nn.MSELoss()
    14. optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    15. # 初始化SummaryWriter
    16. writer = SummaryWriter('runs/linear_regression_experiment')
    17. # 训练模型
    18. num_epochs = 100
    19. for epoch in range(num_epochs):
    20. # 转换为tensor
    21. inputs = x_train
    22. targets = y_train
    23. # 前向传播
    24. outputs = model(inputs)
    25. loss = criterion(outputs, targets)
    26. # 反向传播和优化
    27. optimizer.zero_grad()
    28. loss.backward()
    29. optimizer.step()
    30. # 记录损失
    31. writer.add_scalar('Loss/train', loss.item(), epoch)
    32. if (epoch+1) % 10 == 0:
    33. print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
    34. # 关闭SummaryWriter
    35. writer.close()

    步骤3: 在PyCharm中启动TensorBoard

    (训练开始或完成之后)按照之前的指导,在PyCharm中使用Terminal或配置运行配置来启动TensorBoard。确保TensorBoard的--logdir参数设置为runs/,与SummaryWriter的初始化路径一致。

    这里采用Terminal的方案:

    步骤4: 观察TensorBoard

    点击生成的链接http://localhost:6006/即可查看结果:

  • 相关阅读:
    vue中实现文件批量打包压缩下载(以及下载跨域问题分析)
    kubelet节点压力驱逐
    2022-08-26 Unity视频播放4——全景视频
    2023版:深度比较几种.NET Excel导出库的性能差异
    PHP —— 一份前端开发工程师够用的PHP知识点(持续更新)
    RocketMq(二)-访问面板搭建及问题修复
    C#中的Web抓取:避免被阻挡
    Spark性能优化实战总结
    cap理论、base 定理、分布式事务的理解与相互关系
    C++编译链接详解
  • 原文地址:https://blog.csdn.net/m0_61878383/article/details/136552258