在搭建生成对抗网络时,我发现会用到一个将张量从计算图剥离的函数:detach(),这进一步激发了我的求知欲,我想知道这个detach函数到底是怎么执行的,我猜想把计算图可视化应该可以解决这个疑问,通过万能的互联网果真是有此类工具PyTorchViz,Github地址为:https://github.com/szagoruyko/pytorchviz,根据人家给出的Readme文件可知在进行代码编写前首先需要进行环境配置,安装人家需要的第三方库。
具体步骤如下:
1.在工作环境(比如通过Anconda创建的虚拟环境)下依次安装graphviz和torchviz:
1)安装graphviz
pip install graphviz
2)安装torchviz
pip install git+https://github.com/szagoruyko/pytorchviz
这是通过在线方式安装Github上的项目,也可以直接使用:
pip install torchviz
2.Windows系统下,还得单独安装软件graphviz,否则将会报错:
1)上官网:https://graphviz.org/download/,对应自己系统选择32/64位的exe文件;
2)双击exe进行安装;
在此过程中需要注意的就这一步,可选择是否自动添加环境变量,建议直接选第二个:用于所有用户
3)配置环境变量;
如果安装过程中没有选择自动添加环境变量,则在安装之后手动添加环境变量,即将GraphViz的bin目录路径添加到Path变量中。“我的电脑->属性->高级系统设置->环境变量->系统变量->Path->新建”,然后将该路径放入即可。
4)测试是否安装成功;
Win+R打开命令行窗口,输入以下命令看到如下结果证明安装成功。
dot -version
更多案例可查看官网示例文件:https://github.com/szagoruyko/pytorchviz/blob/master/examples.ipynb
此处我主要想知道卷积神经网络(CNN)的计算图长啥样。
import torch
import torch.nn as nn
from torchviz import make_dot,make_dot_from_trace
class DemoModel(nn.Module):
def __init__(self):
super(DemoModel, self).__init__()
self.model=nn.Sequential(
nn.Conv2d(in_channels=1,out_channels=8,kernel_size=3,padding='same'),
nn.AvgPool2d(kernel_size=2), # (32,1,14,14)
nn.BatchNorm2d(num_features=8),
nn.ReLU(),
nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding='same'),
nn.ReLU(),
nn.Flatten(),
nn.Linear(in_features=14*14*16,out_features=10),
nn.Softmax()
)
def forward(self,x):
return self.model(x)
model=DemoModel()
inp_tensor=torch.ones(size=(1,1,28,28),requires_grad=True)
out=model(inp_tensor)
# 重点就是下面这两行代码
graph=make_dot(out,params=dict(model.named_parameters()),) # 生成计算图结构表示
# graph=make_dot(out,params=dict(model.named_parameters()),show_attrs=True,show_saved=True) # 生成计算图结构表示
# graph=make_dot(out) # 生成计算图结构表示
# graph=make_dot(out,dict(list(model.named_parameters())+[('x',inp_tensor)])) # 生成计算图结构表示
graph.render(filename='cnn',view=False,format='png') # 将源码写入文件,并对图结构进行渲染
# filename:默认生成文件名为filename+'.gv'.s
# view:表示是否使用默认软件打开生成的文件
# format:表示生成文件的格式,可为pdf、png等格式
执行此代码之后,可看到代码同目录下生成两个文件,一个是写入源码的文件,一个则是渲染过后的网络计算图。
网络计算图如下,可以对照网络结构查看:
对于make_dot函数当这两参数show_attrs=True, show_saved=True可同时记录用于反向传播的相关信息,此时计算图变为:
可以看到这俩参数带来的改变还是很大的!
此外,make_dot()函数在调用时可通过其他的方式生成计算图,比如:
(1)不使用模型参数
graph=make_dot(out) # 生成计算图结构表示
(2)集合模型参数和输入张量
graph=make_dot(out,dict(list(model.named_parameters())+[('x',inp_tensor)])) # 生成计算图结构表示