Commit fc25c68
ydwu4 authored and pytorchmergebot committed
[hop][exc] make UncapturedHigherOrderOpError print user code and avoid re-raise (#159296)
After the change, the error stack trace is attached with the user code stack and collapsed into a single exception (without the "Scroll up to find out what causes the graph break" message). For example:

```python
class Test(torch.nn.Module):
    def forward(self, c, x):
        def cond_fn(c, x):
            return c > 0 and x.size(0) < 20

        def body_fn(c, x):
            return c - 1, x.sin()

        return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, x))
```

Now gives the following error message:

```python
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1705, in test_while_loop_size_mismatch_tensor_expansion
    self._run_test(
    ~~~~~~~~~~~~~~^
        model=WhileLoopModels.SizeMismatchTensorExpansion(),
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<2 lines>...
        dynamic=dynamic,
        ^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1417, in _run_test
    result = model(*inputs_with_counters)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1053, in forward
    return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, x))
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 176, in while_loop
    return torch.compile(
           ~~~~~~~~~~~~~~
        _while_loop_op_wrapper, backend=backend, fullgraph=True
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    )(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())
    ~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 804, in compile_wrapper
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1595, in __call__
    result = self._torchdynamo_orig_backend(
        frame, cache_entry, self.hooks, frame_state, skip=1
    )
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1353, in __call__
    result = self._inner_convert(
        frame, cache_entry, hooks, frame_state, skip=skip + 1
    )
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 682, in __call__
    result = _compile(
        frame.f_code,
    ...<16 lines>...
        convert_frame_box=self._box,
    )
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1172, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_utils_internal.py", line 98, in wrapper_function
    return function(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 858, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 897, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1461, in transform_code_object
    transformations(instructions, code_options)
    ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 300, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 818, in transform
    tracer.run()
    ~~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3528, in run
    super().run()
    ~~~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 852, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2240, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1200, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 91, in graph_break_as_hard_error
    raise exc.with_traceback(sys.exc_info()[2]) from None
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 77, in graph_break_as_hard_error
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1287, in call_function
    ) = speculate_subgraph(
        ~~~~~~~~~~~~~~~~~~^
        tx,
        ^^^
    ...<33 lines>...
        supports_aliasing=self.supports_aliasing,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 877, in speculate_subgraph
    raise ex
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 718, in speculate_subgraph
    output = f.call_function(tx, args, sub_kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 580, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1217, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3733, in inline_call
    return tracer.inline_call_()
           ~~~~~~~~~~~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3936, in inline_call_
    self.run()
    ~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 852, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2240, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1200, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 580, in call_function
    return super().call_function(tx, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1217, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3733, in inline_call
    return tracer.inline_call_()
           ~~~~~~~~~~~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3936, in inline_call_
    self.run()
    ~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
    while self.step():
          ~~~~~~~~~^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
    self.dispatch_table[inst.opcode](self, inst)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 830, in inner
    unimplemented_v2(
    ~~~~~~~~~~~~~~~~^
        gb_type="Data-dependent branching",
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<5 lines>...
        ],
        ^^
    )
    ^
  File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 580, in unimplemented_v2
    raise Unsupported(msg)
torch._dynamo.exc.UncapturedHigherOrderOpError: while_loop doesn't work unless it is captured completely with torch.compile. Got Data-dependent branching
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
  Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
  Hint: Use `torch.cond` to express dynamic control flow.

  Developer debug context: attempted to jump with TensorVariable()

 For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0170.html

from user code:
   File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 167, in _while_loop_op_wrapper
    return while_loop_op(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 137, in flat_cond_fn
    return cond_fn(*carried, *additional)
  File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1047, in cond_fn
    return c > 0 and x.size(0) < 20

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

To execute this test, run the following from the base repo dir:
    python test/inductor/test_control_flow.py WhileLoopTests.test_while_loop_size_mismatch_tensor_expansion_device_cpu_dynamic_False

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
```

Pull Request resolved: #159296
Approved by: https://github.com/zou3519
1 parent 5a40c57 commit fc25c68

File tree

3 files changed: +30 -30 lines changed

test/higher_order_ops/test_invoke_subgraph.py

Lines changed: 9 additions & 27 deletions
```diff
@@ -1195,17 +1195,11 @@ def fn(x, y):
         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
 
         with self.assertRaisesRegex(
-            RuntimeError,
-            "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
-        ) as cm:
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "Encountered aliasing during higher order op tracing",
+        ):
             opt_fn(x, y)
 
-        cause = cm.exception.__cause__
-        self.assertIsInstance(cause, torch._dynamo.exc.Unsupported)
-        self.assertTrue(
-            "Encountered aliasing during higher order op tracing" in str(cause)
-        )
-
     def test_input_input_aliasing(self):
         @nested_compile_region
         def gn(x, y):
@@ -1219,17 +1213,11 @@ def fn(x):
         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
 
         with self.assertRaisesRegex(
-            RuntimeError,
-            "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
-        ) as cm:
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "Encountered aliasing during higher order op tracing",
+        ):
             opt_fn(x)
 
-        cause = cm.exception.__cause__
-        self.assertIsInstance(cause, torch._dynamo.exc.Unsupported)
-        self.assertTrue(
-            "Encountered aliasing during higher order op tracing" in str(cause)
-        )
-
     def test_output_output_aliasing(self):
         @nested_compile_region
         def gn(x):
@@ -1244,17 +1232,11 @@ def fn(x):
         opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
 
         with self.assertRaisesRegex(
-            RuntimeError,
-            "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph",
-        ) as cm:
+            torch._dynamo.exc.UncapturedHigherOrderOpError,
+            "Encountered aliasing during higher order op tracing",
+        ):
             opt_fn(x)
 
-        cause = cm.exception.__cause__
-        self.assertIsInstance(cause, torch._dynamo.exc.Unsupported)
-        self.assertTrue(
-            "Encountered aliasing during higher order op tracing" in str(cause)
-        )
-
     def test_mod_attr_aliasing(self):
         class MutateParam(torch.nn.Module):
             def __init__(self):
```
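All three tests simplify in the same way because the graph-break reason is now part of the `UncapturedHigherOrderOpError` message itself, so a single `assertRaisesRegex` replaces the old `as cm` plus `__cause__` unwrapping. A toy, self-contained sketch of the pattern (the exception type and message are stand-ins, since a real repro needs `nested_compile_region`):

```python
# Hedged sketch of the assertion-pattern change using a toy exception;
# only the test style matters here, not the error text.
import unittest

class ToyHardError(RuntimeError):
    pass

def opt_fn():
    # Stand-in for the compiled function: the reason is now embedded
    # directly in the hard error's message and nothing is chained.
    raise ToyHardError(
        "op doesn't work unless it is captured completely. "
        "Got Encountered aliasing during higher order op tracing"
    ) from None

class PatternTest(unittest.TestCase):
    def test_direct_match(self):
        # One assertRaisesRegex now covers both the type and the reason.
        with self.assertRaisesRegex(
            ToyHardError, "Encountered aliasing during higher order op tracing"
        ):
            opt_fn()

if __name__ == "__main__":
    unittest.main()
```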

torch/_dynamo/exc.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -264,7 +264,14 @@ class UnsafeScriptObjectError(TorchDynamoException):
 
 
 class UncapturedHigherOrderOpError(TorchDynamoException):
-    pass
+    def __init__(self, msg: str, real_stack: Optional[StackSummary] = None) -> None:
+        super().__init__(msg)
+        self.msg = msg
+        self.real_stack = (
+            real_stack
+            if real_stack is not None
+            else torch._guards.TracingContext.extract_stack()
+        )
 
 
 class IncorrectUsage(Exception):
```
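The constructor now records a `real_stack` (the user-code frames being traced) eagerly at construction time, falling back to the tracing context's stack when the raiser does not pass one. A minimal standard-library sketch of the same pattern, since `torch._guards.TracingContext.extract_stack` is Dynamo-internal; `HardError` is a stand-in name, and the fallback here captures the current Python stack rather than Dynamo's notion of the traced user stack:

```python
# Hedged sketch using only the stdlib.
import traceback
from traceback import StackSummary
from typing import Optional

class HardError(Exception):
    def __init__(self, msg: str, real_stack: Optional[StackSummary] = None) -> None:
        super().__init__(msg)
        self.msg = msg
        # Capture a stack summary eagerly so it survives even if the
        # exception is later re-raised with a different traceback.
        self.real_stack = (
            real_stack
            if real_stack is not None
            else StackSummary.extract(traceback.walk_stack(None))
        )

err = HardError("boom")
print("".join(err.real_stack.format()))  # frames recorded at construction time
```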

torch/_dynamo/variables/higher_order_ops.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -77,8 +77,19 @@ def graph_break_as_hard_error(*args, **kwargs):
            try:
                return fn(*args, **kwargs)
            except (Unsupported, ObservedException) as e:
-               msg = " Scroll up to find out what causes the graph break."
-               raise UncapturedHigherOrderOpError(reason + msg) from e
+               import sys
+
+               if isinstance(e, Unsupported):
+                   exc = UncapturedHigherOrderOpError(
+                       f"{reason} Got {e.msg}", e.real_stack
+                   )
+               else:
+                   msg = e.msg if hasattr(e, "msg") else type(e)
+                   real_stack = e.real_stack if hasattr(e, "real_stack") else None
+                   exc = UncapturedHigherOrderOpError(
+                       f"{reason} Got {msg}", real_stack
+                   )
+               raise exc.with_traceback(sys.exc_info()[2]) from None
 
        return graph_break_as_hard_error
```
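The load-bearing line in the new code is `raise exc.with_traceback(sys.exc_info()[2]) from None`: it swaps the exception type while keeping the original frames (so the user-code location still prints) and suppresses the implicit "During handling of the above exception, another exception occurred" chain that the old `from e` re-raise produced. A minimal sketch with toy exception types; `SoftError` stands in for `Unsupported`/`ObservedException` and `HardError` for `UncapturedHigherOrderOpError`:

```python
import sys
import traceback

class SoftError(Exception):
    pass

class HardError(Exception):
    pass

def trace_user_code():
    raise SoftError("Data-dependent branching")

def as_hard_error():
    try:
        trace_user_code()
    except SoftError as e:
        exc = HardError(f"op doesn't work unless captured completely. Got {e}")
        # with_traceback() grafts the in-flight traceback onto the new
        # exception, so trace_user_code()'s frame still appears; `from None`
        # sets __suppress_context__ and hides the implicit exception chain.
        raise exc.with_traceback(sys.exc_info()[2]) from None

try:
    as_hard_error()
except HardError:
    traceback.print_exc()  # one traceback, no "During handling of..." block
```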
