在原本代码中额外添加如下几行即可实现查看模型结构:
- from tensorboardX import SummaryWriter # 用于进行可视化
-
- # 1. 来用tensorflow进行可视化
- with SummaryWriter("./log", comment="sample_model_visualization") as sw:
- sw.add_graph(modelviz, sampledata)
安装完torch之后,再安装tensorboardX
pip install tensorboardX -i https://pypi.tuna.tsinghua.edu.cn/simple
运行下面代码
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from tensorboardX import SummaryWriter # 用于进行可视化
-
- class modelViz(nn.Module):
- def __init__(self):
- super(modelViz, self).__init__()
- self.conv1 = nn.Conv2d(3, 16, 3, 1, padding=1)
- self.bn1 = nn.BatchNorm2d(16)
- self.conv2 = nn.Conv2d(16, 64, 3, 1, padding=1)
- self.bn2 = nn.BatchNorm2d(64)
- self.conv3 = nn.Conv2d(64, 10, 3, 1, padding=1)
- self.bn3 = nn.BatchNorm2d(10)
-
- def forward(self, x):
- x = self.bn1(self.conv1(x))
- x = F.relu(x)
- x = self.bn2(self.conv2(x))
- x = F.relu(x)
- x = self.bn3(self.conv3(x))
- x = F.relu(x)
- return x
-
-
- if __name__ == "__main__":
- # 首先来搭建一个模型
- modelviz = modelViz()
- # 创建输入
- sampledata = torch.rand(1, 3, 4, 4)
- # 看看输出结果对不对
- out = modelviz(sampledata)
- print(out) # 测试有输出,网络没有问题
-
- # 1. 来用tensorflow进行可视化
- with SummaryWriter("./log", comment="sample_model_visualization") as sw:
- sw.add_graph(modelviz, sampledata)
-
- # # 2. 保存成pt文件后进行可视化
- # torch.save(modelviz, "./log/modelviz.pt")
运行代码后会在"./log"路径下生成一个tfevents文件,在终端中进入代码的主目录下执行命令:
tensorboard --logdir=./
然后会输出
- (base) jie@dell:~/桌面/fno_task$ tensorboard --logdir=./
- TensorFlow installation not found - running with reduced feature set.
-
- NOTE: Using experimental fast data loading logic. To disable, pass
- "--load_fast=false" and report issues on GitHub. More details:
- https://github.com/tensorflow/tensorboard/issues/4784
-
- Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
- TensorBoard 2.16.2 at http://localhost:6006/ (Press CTRL+C to quit)
然后按照提示打开浏览器,输入上面这个网址就可以看到我们搭建的网络结构了,如下图所示,可以双击打开每一个节点查看其内容。也可以查看详细的结构以及每一层的输入输出shape。通过双击模型的组件实现展示网络细节和收起细节。
结束!!!
官网详细和介绍使用链接:https://www.tensorflow.org/tensorboard/graphs?hl=zh-cn
tips:tensorboard是适用于tensorflow,而tensorboardX可以适用pytorch
tips: 如果你在虚拟环境cd到log的上一级文件夹,那么按照上面的路径就得不到你想要的可视化结果,路径不正确,应该输入
tensorboard --logdir=./log/
参考链接:https://blog.csdn.net/Vertira/article/details/127326470