Ask questions: torch.utils.tensorboard.SummaryWriter.add_graph does not support non-tensor inputs

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

1. Run the script below:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
# from tensorboardX import SummaryWriter

# bug 1: bool type inputs
class Net_1(nn.Module):
    """Two-layer MLP whose ``forward`` takes a Python bool flag (bug 1 repro).

    Passing ``(input_x, True)`` as the example inputs to ``add_graph`` fails,
    because ``torch.jit.trace`` only accepts tensors (and containers of
    tensors) as example inputs.
    """

    def __init__(self, dropout=0.5):
        super().__init__()
        # Submodules are registered in the original order (fc1, fc2, dropout)
        # so parameter order and state_dict keys are unchanged.
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 10)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, use_dropout=False):
        """Project ``x`` (last dim 120) through fc1 -> ReLU -> fc2 -> ReLU.

        When ``use_dropout`` is truthy, dropout is applied to the hidden
        activations between the two linear layers.
        """
        hidden = F.relu(self.fc1(x))
        if not use_dropout:
            return F.relu(self.fc2(hidden))
        return F.relu(self.fc2(self.dropout(hidden)))

# Bug 1 repro: the example inputs contain a Python bool, which torch.jit.trace
# (used internally by add_graph) rejected in the reported environment.
with SummaryWriter("bugs") as writer:
    net = Net_1()
    input_x = torch.randn((2, 120))
    example_inputs = (input_x, True)  # positional args for Net_1.forward
    writer.add_graph(net, example_inputs)

# bug 2: None inputs (e.g. an optional argument left at its default value)
class Net_2(nn.Module):
    """MLP with two optional extra inputs ``y`` and ``z`` (bug 2 repro).

    The issue reports that passing ``None`` for one of these in the example
    inputs to ``add_graph`` fails to trace.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(120, 84)
        self.fc4 = nn.Linear(84, 10)

    def forward(self, x, y=None, z=None):
        """Sum the ReLU'd 84-dim projections of every non-None input, then map to 10."""
        acc = F.relu(self.fc1(x))
        # y is folded in before z, which preserves the original addition order.
        for branch, extra in ((self.fc2, y), (self.fc3, z)):
            if extra is not None:
                acc = acc + F.relu(branch(extra))
        return F.relu(self.fc4(acc))

# Bug 2 repro: y is passed explicitly as None; z is a real tensor.
with SummaryWriter("bugs") as writer:
    net = Net_2()
    input_x = torch.randn((2, 120))
    input_y = None  # deliberately absent: Net_2.forward skips the fc2 branch
    input_z = torch.randn((2, 120))
    example_inputs = (input_x, input_y, input_z)
    writer.add_graph(net, example_inputs)

# bug 3: list inputs (dicts and other Python built-in types such as int and str likely hit the same problem)
class Net_3(nn.Module):
    """MLP that applies a caller-chosen sequence of hidden layers (bug 3 repro).

    ``forward`` takes ``index``, a plain Python list of layer indices. Passing
    such a list in the example inputs to ``add_graph`` is what fails to trace
    in the reported environment; this class change does not affect that.
    """

    def __init__(self):
        super(Net_3, self).__init__()
        # nn.ModuleList rather than a plain list, so the 10 hidden layers are
        # registered as submodules. Their weights then appear in parameters()
        # and state_dict() (and are seen by optimizers), and follow
        # .to()/.train()/.eval(). With a plain list they were silently untracked.
        self.fc_list = nn.ModuleList(nn.Linear(120, 120) for _ in range(10))
        self.fc_n = nn.Linear(120, 10)

    def forward(self, x, index: list = None):
        """Apply ``fc_list[i]`` + ReLU for each ``i`` in ``index`` (in order), then ``fc_n`` + ReLU.

        ``index`` may repeat entries. When it is ``None``, only the output layer
        is applied.
        """
        if index is not None:
            for i in index:
                x = F.relu(self.fc_list[i](x))
        return F.relu(self.fc_n(x))

# Bug 3 repro: the second positional argument is a plain Python list of ints.
with SummaryWriter("bugs") as writer:
    net = Net_3()
    input_x = torch.randn((2, 120))
    index = [1, 5, 1, 7, 0]  # hidden layers to apply, in order (repeats allowed)
    example_inputs = (input_x, index)
    writer.add_graph(net, example_inputs)

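This is the resulting traceback. It comes from bug 1 (the bool input): note the `w.add_graph(net, (input_x, True))` call and `Tuple[Tensor, bool]` in the error. Bugs 2 and 3 fail in a similar way.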

Error occurs, No graph saved
Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/Applications/", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/Applications/", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/wangyuanzheng/Downloads/xxxxxxx/project/albert_pytorch/dev/", line 25, in <module>
    w.add_graph(net, (input_x, True))
  File "/Users/wangyuanzheng/anaconda3/envs/CCFBigData-torch/lib/python3.7/site-packages/torch/utils/tensorboard/", line 682, in add_graph
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose))
  File "/Users/wangyuanzheng/anaconda3/envs/CCFBigData-torch/lib/python3.7/site-packages/torch/utils/tensorboard/", line 239, in graph
    raise e
  File "/Users/wangyuanzheng/anaconda3/envs/CCFBigData-torch/lib/python3.7/site-packages/torch/utils/tensorboard/", line 234, in graph
    trace = torch.jit.trace(model, args)
  File "/Users/wangyuanzheng/anaconda3/envs/CCFBigData-torch/lib/python3.7/site-packages/torch/jit/", line 858, in trace
    check_tolerance, _force_outplace, _module_class)
  File "/Users/wangyuanzheng/anaconda3/envs/CCFBigData-torch/lib/python3.7/site-packages/torch/jit/", line 997, in trace_module
    module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, _force_outplace)
RuntimeError: Type 'Tuple[Tensor, bool]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced (toTraceableIValue at ../torch/csrc/jit/pybind_utils.h:298)
frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&) + 135 (0x110c479e7 in libc10.dylib)
frame #1: torch::jit::toTraceableIValue(pybind11::handle) + 1280 (0x110246740 in libtorch_python.dylib)
frame #2: torch::jit::toTypedStack(pybind11::tuple const&) + 31 (0x1102e7edf in libtorch_python.dylib)
frame #3: void pybind11::cpp_function::initialize<torch::jit::script::initJitScriptBindings(_object*)::$_16, void, torch::jit::script::Module&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, pybind11::function, pybind11::tuple, pybind11::function, bool, pybind11::name, pybind11::is_method, pybind11::sibling>(torch::jit::script::initJitScriptBindings(_object*)::$_16&&, void (*)(torch::jit::script::Module&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, pybind11::function, pybind11::tuple, pybind11::function, bool), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(pybind11::detail::function_call&)::__invoke(pybind11::detail::function_call&) + 147 (0x11031e4e3 in libtorch_python.dylib)
frame #4: pybind11::cpp_function::dispatcher(_object*, _object*, _object*) + 3372 (0x10fe57d3c in libtorch_python.dylib)
<omitting python frames>

Expected behavior

`writer.add_graph` should run normally and save the model graph.


Collecting environment information...
PyTorch version: 1.3.0
Is debug build: No
CUDA used to build PyTorch: None
OS: Mac OSX 10.14.6
GCC version: Could not collect
CMake version: Could not collect
Python version: 3.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Versions of relevant libraries:
[pip] numpy==1.17.2
[pip] torch==1.3.0
[pip] torchvision==0.4.1
[conda] torch 1.3.0 pypi_0 pypi
[conda] torchvision 0.4.1 pypi_0 pypi

Additional context

1. tensorboardX's `SummaryWriter.add_graph` has the same bug as `torch.utils.tensorboard`.
2. Beyond this bug, it would help if `add_graph` accepted a dict of keyword arguments for `model.forward()`, not only a tuple of positional arguments.


Answer from lanpa:

This has already been addressed here (the link was not preserved in this copy):

Once the JIT team supports this, it will be easy to visualize.

