与tensorflow模型与caffe模型不同,当前的pytorch没有官方的直观查看网络结构的工具,google了下pytorch的网络解析的方法,发现可以将pytorch的model转换成为events文件使用tensorboard查看,记录之。

安装插件

  • TensorboardX,TensorboardX支持scalar, image, figure, histogram, audio, text, graph, onnx_graph, embedding, pr_curve and videosummaries等不同的可视化展示方式,具体介绍移步至项目Github 观看详情。使用下面的命令安装

    pip install tensorboardX
    
  • 安装tensorboard,参考命令

    pip install tensorboard
    

具体过程

参考代码

#-*-coding:utf-8-*-
import torch
import torchvision
from torch.autograd import Variable
from tensorboardX import SummaryWriter

# 模拟输入数据
input_data = Variable(torch.rand(16, 3, 224, 224))

# 从torchvision中导入已有模型
net = torchvision.models.resnet18()

# 声明writer对象,保存的文件夹
writer = SummaryWriter(log_dir='./log', comment='resnet18')
with writer:
    writer.add_graph(net, (input_data,))

该代码中14行声明一个writer对象,分别表示events存放的目录,comment表示事件的title,然后使用如下的方式打开tensorboard

tensorboard --logpath=D:\log --port=6006

然后按照命令行提示打开即可。

参考链接