
Poor error message when trying to jit a function instead of a module (RuntimeError: Cannot insert a Tensor that requires grad as a constant.) #55282

@oliver-batchelor

🐛 Bug

Tracing a plain function that runs a model, rather than tracing the model itself, produces a confusing error:

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
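The error text itself names two fixes. The following is a minimal sketch of both for the Conv2d from the reproduction below. It assumes `torch.nn.functional.conv2d` is an acceptable functional stand-in for the module's forward pass, and the name `run_with_weights` is illustrative. The first fix passes the weights in as inputs. The second freezes the parameters so the tracer can embed them as constants, which means the traced graph holds fixed weights and is usually only suitable for inference.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Conv2d(5, 10, (3, 3))
example = torch.randn(1, 5, 64, 64)


# Fix 1 ("making it ... input"): pass the weights explicitly, so the tracer
# records them as graph inputs instead of captured tensors.
def run_with_weights(batch, weight, bias):
    return F.conv2d(batch, weight, bias)


traced_inputs = torch.jit.trace(
    run_with_weights, (example, model.weight, model.bias)
)

# Fix 2 ("detaching the gradient"): stop the parameters from requiring grad,
# so the tracer is allowed to bake them into the graph as constants.
model.requires_grad_(False)


def run(batch):
    return model(batch)


traced_frozen = torch.jit.trace(run, example)
```

With the first fix the caller must supply the weights on every call; with the second they are fixed at trace time.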

It would be better if the message recommended using an nn.Module instead.
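The following is a minimal sketch of the nn.Module route, assuming the wrapper function's logic should be kept. The class name `Runner` is hypothetical. Moving the function body into `forward` registers `model` as a submodule, so the tracer sees its weights as module parameters rather than captured tensors. When the function adds nothing beyond calling the model, tracing `model` directly works the same way.

```python
import torch
import torch.nn as nn


class Runner(nn.Module):
    """Hypothetical wrapper that holds the model as a submodule."""

    def __init__(self, model: nn.Module):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule,
        # so its parameters become part of the traced module.
        self.model = model

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        # Same body as the original `run` function.
        return self.model(batch)


model = nn.Conv2d(5, 10, (3, 3))
traced = torch.jit.trace(Runner(model), torch.zeros(1, 5, 64, 64))

out = traced(torch.zeros(1, 5, 64, 64))
print(out.shape)  # torch.Size([1, 10, 62, 62])
```

The spatial size drops from 64 to 62 because the 3x3 kernel is applied without padding.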

To Reproduce

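  import torch
  import torch.nn as nn

  model = nn.Conv2d(5, 10, (3, 3))

  # A plain function that closes over `model`; its parameters require grad.
  def run(batch):
      return model(batch)

  # Raises: RuntimeError: Cannot insert a Tensor that requires grad as a constant. ...
  traced = torch.jit.trace(run, torch.zeros(1, 5, 64, 64))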

Expected behavior

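A clearer error message that explains how to fix the problem, for example by suggesting that the nn.Module be traced directly instead of a function that calls it.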
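Until the message changes, a user-side sketch of the requested hint could look like the following. The helper `trace_with_hint` and its wording are hypothetical and not part of PyTorch. The helper also matches on the error text, which is brittle if that text changes between releases.

```python
import torch
import torch.nn as nn

_CONST_ERR = "Cannot insert a Tensor that requires grad as a constant"


def trace_with_hint(fn, example_inputs):
    """Hypothetical helper: torch.jit.trace plus a hint for one specific error."""
    try:
        return torch.jit.trace(fn, example_inputs)
    except RuntimeError as err:
        # Only rewrite the tracer's constant error; re-raise anything else as-is.
        if _CONST_ERR not in str(err):
            raise
        name = getattr(fn, "__name__", type(fn).__name__)
        raise RuntimeError(
            f"{err}\n\nHint: {name!r} uses tensors that require grad but are not "
            "inputs, e.g. the parameters of an nn.Module called from a plain "
            "function. Trace the nn.Module itself, or move the function body "
            "into the forward() of an nn.Module."
        ) from err


model = nn.Conv2d(5, 10, (3, 3))


def run(batch):
    return model(batch)


example = torch.zeros(1, 5, 64, 64)
traced_model = trace_with_hint(model, example)  # tracing the module works

try:
    trace_with_hint(run, example)
except RuntimeError as err:
    print(err)  # original message followed by the hint
```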

Environment

Collecting environment information...
PyTorch version: 1.8.1
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 9.0.0 (https://github.com/conda-forge/clangdev-feedstock 284a3d5d88509307bcfba64b055653ee347371db)
CMake version: version 3.16.0

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 11.1.105
GPU models and configuration: GPU 0: GeForce RTX 2070
Nvidia driver version: 455.32.00
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] numpy-quaternion==2020.9.5.14.42.2
[pip3] pytorch-tools==0.1
[pip3] torch==1.8.1
[pip3] torch2trt==0.1.0
[pip3] torchaudio==0.8.0a0+e4e171a
[pip3] torchvision==0.2.2
[pip3] trtorch==0.3.0a0
[conda] Could not collect


cc @gmagogsfm

Labels: oncall: jit
