pip install torchviz
python
import torch
from torchviz import make_dot
# 创建一个简单的模型
model = torch.nn.Sequential(
torch.nn.Linear(2, 2),
torch.nn.ReLU(),
torch.nn.Linear(2, 1)
)
# 创建输入数据
x = torch.randn(1, 2)
# 前向传播
y = model(x)
# 可视化计算图
dot = make_dot(y, params=dict(model.named_parameters()))
dot.render("computational_graph", format="png") # 保存为 PNG 图片
出现错误:提示没有dot命令,
解决方法:安装graphviz
