
torch.onnx.export does not preserve weight names

🐛 Bug

I am exporting an ONNX graph with torch.onnx.export and then visualising it in Netron. The biases keep their names (layerX.bias), but the weights' names are not preserved: they show up as anonymous numbered inputs (%33 to %36 in the output below).

To Reproduce

import os

import torch
import torch.nn as nn


class MlpNaive(nn.Module):
    def __init__(self, input_dim, H, basis_fun_output=7):
        super(MlpNaive, self).__init__()
        self.input_dim = input_dim
        self.H = H
        self.basis_fun_output = basis_fun_output
        self.layer1 = torch.nn.Linear(self.input_dim, self.H)
        self.layer2 = torch.nn.ReLU()
        self.layer3 = nn.Dropout(0.3)
        self.layer4 = torch.nn.Linear(self.H, self.H)
        self.layer5 = torch.nn.ReLU()
        self.layer6 = nn.Dropout(0.3)
        self.layer7 = torch.nn.Linear(self.H, self.H)
        self.layer8 = torch.nn.ReLU()
        self.layer9 = nn.Dropout(0.3)
        # layer10-layer12 are created here but never called in forward(),
        # so they do not appear in the exported graph.
        self.layer10 = torch.nn.Linear(self.H, self.H)
        self.layer11 = torch.nn.ReLU()
        self.layer12 = nn.Dropout(0.3)
        self.layersf = torch.nn.Linear(self.H, self.basis_fun_output)
        self.layersff = torch.nn.ReLU()

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = self.layer7(x)
        x = self.layer8(x)
        x = self.layer9(x)
        x = self.layersf(x)
        x = self.layersff(x)
        return x


# basis_functions is not shown in the original script; this hypothetical
# placeholder has 8 entries because the exported output is Float(8).
basis_functions = range(8)

input_net = 500
hidden = 50
dummy_input = torch.ones(input_net)
model = MlpNaive(input_net, hidden, len(basis_functions))
path = "sandbox"
os.makedirs(path, exist_ok=True)  # make sure the output directory exists
torch.onnx.export(model, dummy_input, os.path.join(path, "MlpNaive.onnx"),
                  verbose=True, opset_version=11, export_params=False, training=True,
                  input_names=["input"], output_names=["output"])

Current output (with verbose=True):

graph(%input.1 : Float(500),
      %layer1.bias : Float(50),
      %layer4.bias : Float(50),
      %layer7.bias : Float(50),
      %layersf.bias : Float(8),
      %33 : Float(500, 50),
      %34 : Float(50, 50),
      %35 : Float(50, 50),
      %36 : Float(50, 8)):
  %12 : Float(50) = onnx::MatMul(%input.1, %33) # /Users/tommaso/Documents/Code/CSEM/New-EQ-learn/env/lib/python3.7/site-packages/torch/nn/functional.py:1612:0
  %13 : Float(50) = onnx::Add(%12, %layer1.bias)
  %14 : Float(50) = onnx::Relu(%13) # /Users/tommaso/Documents/Code/CSEM/New-EQ-learn/env/lib/python3.7/site-packages/torch/nn/functional.py:1063:0
  %15 : Float(50), %16 : Tensor = onnx::Dropout[ratio=0.29999999999999999](%14) # /Users/tommaso/Documents/Code/CSEM/New-EQ-learn/env/lib/python3.7/site-packages/torch/nn/functional.py:936:0
  %18 : Float(50) = onnx::MatMul(%15, %34) # /Users/tommaso/Documents/Code/CSEM/New-EQ-learn/env/lib/python3.7/site-packages/torch/nn/functional.py:1612:0
  %19 : Float(50) = onnx::Add(%18, %layer4.bias)
  %20 : Float(50) = onnx::Relu(%19) # /Users/tommaso/Documents/Code/CSEM/New-EQ-learn/env/lib/python3.7/site-packages/torch/nn/functional.py:1063:0
  %21 : Float(50), %22 : Tensor = onnx::Dropout[ratio=0.29999999999999999](%20) # /Users/tommaso/Documents/Code/CSEM/New-EQ-learn/env/lib/python3.7/site-packages/torch/nn/functional.py:936:0
  %24 : Float(50) = onnx::MatMul(%21, %35) # /Users/tommaso/Documents/Code/CSEM/New-EQ-learn/env/lib/python3.7/site-packages/torch/nn/functional.py:1612:0
  %25 : Float(50) = onnx::Add(%24, %layer7.bias)
  %26 : Float(50) = onnx::Relu(%25) # /Users/tommaso/Documents/Code/CSEM/New-EQ-learn/env/lib/python3.7/site-packages/torch/nn/functional.py:1063:0
  %27 : Float(50), %28 : Tensor = onnx::Dropout[ratio=0.29999999999999999](%26) # /Users/tommaso/Documents/Code/CSEM/New-EQ-learn/env/lib/python3.7/site-packages/torch/nn/functional.py:936:0
  %30 : Float(8) = onnx::MatMul(%27, %36) # /Users/tommaso/Documents/Code/CSEM/New-EQ-learn/env/lib/python3.7/site-packages/torch/nn/functional.py:1612:0
  %31 : Float(8) = onnx::Add(%30, %layersf.bias)
  %output : Float(8) = onnx::Relu(%31) # /Users/tommaso/Documents/Code/CSEM/New-EQ-learn/env/lib/python3.7/site-packages/torch/nn/functional.py:1063:0
  return (%output)
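The shapes in this dump suggest why the weight names are lost. nn.Linear stores its weight as (out_features, in_features), so layer1.weight has shape (50, 500), while the graph input %33 is Float(500, 50). For an input that is not 2-D, torch.nn.functional.linear computes input.matmul(weight.t()) and then adds the bias, which matches the MatMul/Add pairs above (the MatMul lines point at functional.py:1612). The tensor the graph consumes therefore appears to be the transposed weight, a new tensor, rather than the parameter itself, so it seems to receive a fresh numeric name, while the bias is used unchanged and keeps its name. For the first layer, assuming this reading of the trace:

$$
y = x\,W^{\top} + b, \qquad x \in \mathbb{R}^{500},\quad W \in \mathbb{R}^{50 \times 500},\quad W^{\top} \in \mathbb{R}^{500 \times 50},\quad b \in \mathbb{R}^{50}
$$

Here W is layer1.weight, b is layer1.bias (which keeps its name in the graph), and W^T corresponds to the anonymous %33.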

Expected behavior

I would like the output to be:

graph(%input.1 : Float(500),
      %layer1.bias : Float(50),
      %layer4.bias : Float(50),
      %layer7.bias : Float(50),
      %layersf.bias : Float(8),
      %layer1.weight : Float(500, 50),
      %layer4.weight : Float(50, 50),
      %layer7.weight : Float(50, 50),
      %layersf.weight : Float(50, 8)):
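Because each anonymous input has the transposed shape of one Linear weight, the expected names can in principle be recovered from shapes alone. This is a minimal sketch, not an exporter feature: parse_inputs and recover_weight_names are hypothetical helpers, the graph lines are copied from the verbose dump above, and the weight list is written by hand to mirror what model.named_parameters() would report for the Linear layers. It assumes the anonymous inputs appear in the order the layers are called in forward(), since layer4 and layer7 have identical shapes and only their order tells them apart.

```python
import re

# Graph input lines copied from the verbose dump (current output).
GRAPH_INPUTS = [
    "%input.1 : Float(500)",
    "%layer1.bias : Float(50)",
    "%layer4.bias : Float(50)",
    "%layer7.bias : Float(50)",
    "%layersf.bias : Float(8)",
    "%33 : Float(500, 50)",
    "%34 : Float(50, 50)",
    "%35 : Float(50, 50)",
    "%36 : Float(50, 8)",
]

# Hand-written stand-in for the Linear weights in declaration order.
# nn.Linear(in_features, out_features).weight has shape (out_features, in_features).
LINEAR_WEIGHTS = [
    ("layer1.weight", (50, 500)),
    ("layer4.weight", (50, 50)),
    ("layer7.weight", (50, 50)),
    ("layer10.weight", (50, 50)),  # defined in __init__ but never called in forward()
    ("layersf.weight", (8, 50)),
]

# Matches e.g. "%33 : Float(500, 50)" -> name "33", dims "500, 50".
INPUT_PATTERN = re.compile(r"%(\S+) : Float\(([\d, ]+)\)")


def parse_inputs(lines):
    """Return [(name, shape_tuple)] for every graph input line that matches."""
    parsed = []
    for line in lines:
        match = INPUT_PATTERN.search(line)
        if match:
            shape = tuple(int(dim) for dim in match.group(2).split(","))
            parsed.append((match.group(1), shape))
    return parsed


def recover_weight_names(graph_inputs, weights):
    """Map purely numeric input names to weight names whose transposed shape matches.

    The first still-unused weight with a matching shape wins, so the result
    relies on graph inputs and weights being listed in the same order.
    Returns the mapping and the names of weights that were never matched.
    """
    unused = list(weights)
    mapping = {}
    for name, shape in graph_inputs:
        if not name.isdigit():
            continue  # already named, e.g. input.1 or layer1.bias
        for i, (weight_name, weight_shape) in enumerate(unused):
            if weight_shape[::-1] == shape:
                mapping[name] = weight_name
                del unused[i]
                break
    return mapping, [weight_name for weight_name, _ in unused]


mapping, leftover = recover_weight_names(parse_inputs(GRAPH_INPUTS), LINEAR_WEIGHTS)
for anonymous, real in mapping.items():
    print(f"%{anonymous} -> {real}")
print("unmatched:", leftover)
```

The leftover layer10.weight is expected: layer10 is constructed in __init__ but never called in forward(), so it never reaches the graph. Matching by shape and order is fragile in general; it works here only because the correspondence is 1:1 and the order is known.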

Environment

  • PyTorch Version (e.g., 1.0): 1.5.1
  • OS (e.g., Linux): macOS Mojave 10.14.6
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: python 3.7.6

cc @houseroad @spandantiwari @lara-hdr @BowenBao @neginraoof

pytorch/pytorch

Reply from TommasoBendinelli

Thank you for the answer! In this case I think the mapping between the exported weights and the module parameters is 1:1, so preserving the names should be possible. I am using an ONNX visualisation tool to learn about neural network architectures, and matching names make that much easier.


source:https://uonfu.com/