Skip to content

Commit 671a9d1

Browse files
mikaylagawarecki authored and pytorchmergebot committed
Add warning for module full backward hook when no input requires gradient (#155339)
Pull Request resolved: #155339 Approved by: https://github.com/Skylion007
1 parent e25ce0f commit 671a9d1

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

test/nn/test_module_hooks.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1445,7 +1445,14 @@ def hook(mod, grad_input, grad_output):
14451445
mod.register_full_backward_hook(hook)
14461446

14471447
# This should run and trigger the hook properly
1448-
mod(inp).sum().backward()
1448+
with self.assertWarnsRegex(
1449+
UserWarning,
1450+
(
1451+
"Full backward hook is firing when gradients are computed with "
1452+
"respect to module outputs since no inputs require gradients"
1453+
),
1454+
):
1455+
mod(inp).sum().backward()
14491456
self.assertEqual(hook_called[0], 1)
14501457

14511458
return_val = "grad_input"

torch/utils/hooks.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,11 @@ def hook(_, grad_output):
223223
# Special case if no input required gradients, this hook should call the user
224224
# hook directly
225225
if self.input_tensors_index is None:
226+
warnings.warn("Full backward hook is firing when gradients are computed "
227+
"with respect to module outputs since no inputs require gradients. See "
228+
"https://docs.pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook " # noqa: B950
229+
"for more details.",
230+
stacklevel=5)
226231
grad_inputs = self._pack_with_none([], [], self.n_inputs)
227232
for user_hook in self.user_hooks:
228233
res = user_hook(self.module, grad_inputs, self.grad_outputs)

0 commit comments

Comments
 (0)