• Pytorch框架学习记录2——TensorBoard的使用(一)


    Pytorch框架学习记录2——TensorBoard的使用

    Tensorboard在训练模型时很有用,可以看训练过程中loss的变化。之前用于Tensorflow框架,自Pytorch1.1之后,Pytorch也加了这个功能。

    1. TensorBoard的使用

    torch.utils.tensorboard中导入SummaryWriter类。

    1.1 SummaryWriter类使用

    SummaryWriter类可以在指定文件夹生成一个事件文件,这个事件文件可以对TensorBoard解析。

    首先实例化一个SummaryWriter的类,参数代表保存的文件夹的名称

    writer = SummaryWriter("logs")  #文件夹名称
    
    • 1

    1.2 writer.add_scalar()方法:

    def add_scalar(
        self,
        tag,
        scalar_value,
        global_step=None,
        walltime=None,
        new_style=False,
        double_precision=False,
    ):
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    上面是官方定义的参数,这个方法是添加标量的意思。

    tag:所画图标的title,str类型,注意引号

    scalar_value:需要保存的数值,对应y轴的y值

    global_step:全局的步数,对应x轴

    这里的tag起到唯一标识图像的作用,如果需要两幅图,则重新改写tag!!!

    1.4 举例

    绘制y = x 的图像。

    from torch.utils.tensorboard import SummaryWriter
    
    
    writer = SummaryWriter("logs")
    
    for i in range(100):
        writer.add_scalar("y = x", i, i)
    
    writer.close()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    1.5 add_image()方法

    def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
    '''
            Args:
                tag (string): Data identifier
                img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
                global_step (int): Global step value to record
                walltime (float): Optional override default walltime (time.time())
                  seconds after epoch of event
            Shape:
                img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
                convert a batch of tensor into 3xHxW format or call ``add_images`` and let us do the job.
                Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitable as long as
                corresponding ``dataformats`` argument is passed, e.g. ``CHW``, ``HWC``, ``HW``.
    '''
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    其中,img_tensor参数类型要求为:torch.Tensornumpy.array或者string类型。global_step为步骤,int类型。

    img_tensor默认的图片尺寸格式为(3,H,W),但是一般我们的图片格式为(H,W,3),因此需要对图片格式进行调整,调整的方法如下:设置dataformats='HWC'即可!!!

    from torch.utils.tensorboard import SummaryWriter
    from PIL import Image
    import numpy as np
    
    writer = SummaryWriter("logs")
    
    img_path = "C:\\Users\\hp\\PycharmProjects\pythonProject\\Pytorch_Learning\\flower_data\\train\daisy\\5547758_eea9edfd54_n.jpg"
    img_PIL = Image.open(img_path)
    img_array = np.array(img_PIL)
    print(type(img_array))
    writer.add_image("test", img_array, 1, dataformats='HWC')
    
    for i in range(100):
        writer.add_scalar("y = x", i, i)
    
    writer.close()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    在这里插入图片描述

    1.6 事件文件打开方法:

    在PyCharm中,找到项目所在的terminal,

    tensorboard --logdir=事件文件所在的文件夹名 --port=端口号
    
    • 1

    在弹出的网址上打开即可。

    在这里插入图片描述
    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yfqvOSZM-1658972434470)(C:\Users\hp\AppData\Roaming\Typora\typora-user-images\image-20220728092728557.png)]

    2. 常见的报错情况

    2.1 ImportError: TensorBoard logging requires TensorBoard version 1.15 or above

    tensorboard版本过低,卸载后重新安装最新版本。

    2.2 AttributeError: module ‘tensorflow._api.v1.io’ has no attribute ‘gfile’

    tensorflow 的版本和tensorboard的版本不匹配。升级tensorflow版本为2.0版本即可,执行下面命令会自动升级tensorflow版本和tensorboard版本:

    pip install tensorflow==2.0

    2.3 ModuleNotFoundError: No module named ‘grpc’

    重新安装grpcio

    pip uninstall grpcio

    pip install grpcio

    2.4 No dashboards are active for the current data set.

    打开的文件位置不存在,将文件夹位置的绝对路径复制,然后重新执行terminal语句即可

  • 相关阅读:
    ssh连接远程服务器,并在终端安装anaconda
    【分析笔记】全志平台 TWI 上拉电压异常的问题
    SSTI模板注入(flask) 学习总结
    java学习之包
    异步FIFO设计的仿真与综合技术(3)
    【Python】PySpark 数据计算 ① ( RDD#map 方法 | RDD#map 语法 | 传入普通函数 | 传入 lambda 匿名函数 | 链式调用 )
    uni-app运行到微信开发者工具-没有打印的情况
    Javascript V8引擎与Blob对象
    Android源码——Configuration源码解析
    分享一下微信小程序里的预约链接怎么做
  • 原文地址:https://blog.csdn.net/qq_45955883/article/details/126027999