Visualizing a network is often the quickest way to understand it: a graph is much easier to follow than a printed network description. In this post, I will show how to create a visual graph of a (PyTorch) network using Netron, including the intermediate layer shapes.
Netron is an excellent tool for network visualization. At the moment, however, its PyTorch support is experimental and not stable. The most reliable approach is to convert the PyTorch model to ONNX and then open the ONNX file in Netron.
The following example exports ResNet50 to ONNX and annotates the graph with shape information.
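import torch
from torchvision import models
import onnx
from onnx import shape_inference

# Use the GPU if one is available; the export works on CPU as well.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
PATH = 'resnet.onnx'

# Randomly initialized ResNet50; the weights do not affect the graph structure.
model = models.resnet50().to(DEVICE)
model.eval()  # export in inference mode

# The dummy input fixes the shapes recorded in the graph: (batch, channels, height, width).
dummy_input = torch.randn(1, 3, 224, 224).to(DEVICE)
torch.onnx.export(model, dummy_input, PATH, verbose=False)

# Run ONNX shape inference and save the result; this is required for displaying intermediate shapes.
onnx.save(shape_inference.infer_shapes(onnx.load(PATH)), PATH)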
To view the graph, install Netron (for example with pip install netron) and run
$ netron resnet.onnx
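As a sanity check on the shapes Netron displays, the first two can be worked out by hand. The sketch below applies the standard output-size formula for convolution and pooling layers to the ResNet50 stem, assuming the usual torchvision settings (a 7x7 convolution with stride 2 and padding 3 producing 64 channels, followed by a 3x3 max pool with stride 2 and padding 1). The function names are hypothetical and only serve this illustration.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a conv or pooling layer (floor mode, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def stem_shapes(size, batch=1, channels=64):
    """Shapes after conv1 and after maxpool in the ResNet50 stem (assumed settings)."""
    after_conv = conv_out(size, kernel=7, stride=2, padding=3)
    after_pool = conv_out(after_conv, kernel=3, stride=2, padding=1)
    return [
        (batch, channels, after_conv, after_conv),
        (batch, channels, after_pool, after_pool),
    ]


for name, shape in zip(["conv1", "maxpool"], stem_shapes(224)):
    print(name, shape)
```

If the edges leaving the first convolution and the max pool in Netron show different shapes for a 224x224 input, the export or the shape inference step is worth rechecking.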
That's it!