Tensorboard在训练模型时很有用,可以看训练过程中loss的变化。之前用于Tensorflow框架,自Pytorch1.1之后,Pytorch也加了这个功能。
从torch.utils.tensorboard
中导入SummaryWriter
类。
SummaryWriter类可以在指定文件夹生成一个事件文件,这个事件文件可以对TensorBoard解析。
首先实例化一个SummaryWriter
的类,参数代表保存的文件夹的名称
writer = SummaryWriter("logs") #文件夹名称
def add_scalar(
self,
tag,
scalar_value,
global_step=None,
walltime=None,
new_style=False,
double_precision=False,
):
上面是官方定义的参数,这个方法是添加标量的意思。
tag
:所画图标的title,str类型,注意引号
scalar_value
:需要保存的数值,对应y轴的y值
global_step
:全局的步数,对应x轴
这里的tag起到唯一标识图像的作用,如果需要两幅图,则重新改写tag!!!
绘制y = x 的图像。
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
for i in range(100):
writer.add_scalar("y = x", i, i)
writer.close()
def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
'''
Args:
tag (string): Data identifier
img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
global_step (int): Global step value to record
walltime (float): Optional override default walltime (time.time())
seconds after epoch of event
Shape:
img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
convert a batch of tensor into 3xHxW format or call ``add_images`` and let us do the job.
Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitable as long as
corresponding ``dataformats`` argument is passed, e.g. ``CHW``, ``HWC``, ``HW``.
'''
其中,img_tensor
参数类型要求为:torch.Tensor
、numpy.array
或者string
类型。global_step
为步骤,int
类型。
img_tensor
默认的图片尺寸格式为(3,H,W),但是一般我们的图片格式为(H,W,3),因此需要对图片格式进行调整,调整的方法如下:设置dataformats='HWC'
即可!!!
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np
writer = SummaryWriter("logs")
img_path = "C:\\Users\\hp\\PycharmProjects\pythonProject\\Pytorch_Learning\\flower_data\\train\daisy\\5547758_eea9edfd54_n.jpg"
img_PIL = Image.open(img_path)
img_array = np.array(img_PIL)
print(type(img_array))
writer.add_image("test", img_array, 1, dataformats='HWC')
for i in range(100):
writer.add_scalar("y = x", i, i)
writer.close()
在PyCharm中,找到项目所在的terminal,
tensorboard --logdir=事件文件所在的文件夹名 --port=端口号
在弹出的网址上打开即可。
tensorboard版本过低,卸载后重新安装最新版本。
tensorflow 的版本和tensorboard的版本不匹配。升级tensorflow版本为2.0版本即可,执行下面命令会自动升级tensorflow版本和tensorboard版本:
pip install tensorflow==2.0
重新安装grpcio
pip uninstall grpcio
pip install grpcio
打开的文件位置不存在,将文件夹位置的绝对路径复制,然后重新执行terminal语句即可