From d7a27651190db716338f731e9bf96c372719f16f Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 27 Jun 2025 16:45:39 -0300 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- test/dynamo/test_generator.py | 48 ++++++++- ...generators-ExceptionTest.test_except_throw | 0 ...onTest.test_except_throw_exception_context | 0 torch/_dynamo/output_graph.py | 1 + torch/_dynamo/symbolic_convert.py | 19 +++- torch/_dynamo/variables/functions.py | 101 +++--------------- 6 files changed, 80 insertions(+), 89 deletions(-) delete mode 100644 test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw delete mode 100644 test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw_exception_context diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index adf1e5aff0d3..6c1ba1c4c9bc 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -1345,8 +1345,10 @@ def fn(t): a = next(gen) try: gen.throw(ValueError) - except StopIteration: + except StopIteration as e: + assert len(e.args) == 0 return a + raise AssertionError("Expected StopIteration") t = torch.randn(2) y = self._compile_check(fn, (t,)) @@ -1480,6 +1482,50 @@ def fn(t): self._compile_check(fn) + def test_return_value_in_except_and_finally(self): + def whoo(): + try: + yield 1 + except ValueError: + return 2 # noqa: B901 + finally: + return 3 # noqa: B012, SIM107 + + def fn(t): + gen = whoo() + next(gen) + try: + gen.throw(ValueError) + except StopIteration as e: + assert e.args[0] == 3 + except Exception as e: + raise AssertionError from e + return t.sin() + + self._compile_check(fn) + + def test_return_None_in_except_and_finally(self): + def whoo(): + try: + yield 1 + except ValueError: + return 2 # noqa: B901 + finally: + return # noqa: B012, SIM107 + + def fn(t): + gen = whoo() + next(gen) + try: + gen.throw(ValueError) + except StopIteration as e: + assert len(e.args) == 0 + except Exception as e: + raise AssertionError from e + return t.sin() + + self._compile_check(fn) + instantiate_parametrized_tests(GeneratorTests) instantiate_parametrized_tests(TestGeneratorSend) diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw b/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw_exception_context b/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw_exception_context deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 4da00e4d76e0..0c6d39acfd1f 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1397,6 +1397,7 @@ def compile_subgraph( overridden_sources=overridden_sources, ) self.codegen_suffix(tx, stack_values_flat, pass1) + tx.close_local_generators() # Use `pass1.uses` to selectively cache multi-user variables into a # temporary local source. This (a). speeds up loading VTs with long diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 86697a07aa54..e197dfc365ce 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1006,7 +1006,7 @@ class ExceptionStack: # and "stack" sometimes refers to a C variable with the same name and the # exception stack, respectively. # - # The lifetime of an exception is (Python 3.11+): + # The lifetime of an exception in Python 3.11+ is: # + tx._raise_exception_variable(...) := sets the current_exception variable # + PUSH_EXC_INFO := pushes the current_exception to the *exception stack* # + POP_EXCEPT := pops TOS from the *exception stack* @@ -1099,6 +1099,7 @@ class InstructionTranslatorBase( instruction_pointer: Optional[int] current_instruction: Instruction block_stack: list[BlockStackEntry] + local_generators: list[LocalGeneratorObjectVariable] lineno: int kw_names: Optional[ConstantVariable] accept_prefix_inst: bool @@ -1226,6 +1227,13 @@ def inline_user_function_return(self, fn, args, kwargs): else: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + def close_local_generators(self): + assert isinstance(self, InstructionTranslator) + with temporarely_allow_writes_to_output_graph(self): + for gen in self.local_generators: + if not gen._is_generator_exhausted(): + gen.call_method(self, "close", [], {}) + def get_line_of_code_header(self, lineno=None): if lineno is None: lineno = self.lineno @@ -3233,6 +3241,7 @@ def __init__( self.start_point = None self.current_instruction = create_instruction("NOP") self.block_stack = [] + self.local_generators: list[LocalGeneratorObjectVariable] = [] # states before SETUP_WITH for checkpointing and fallback self.active_generic_context_managers: list[GenericContextWrappingVariable] = [] self.lineno = -1 @@ -3972,7 +3981,13 @@ def inline_call_(self): ): assert isinstance(self, InliningGeneratorInstructionTranslator) # When the generator returns None, we raise StopIteration - exc.raise_observed_exception(StopIteration, self) + args = [] + if ( + isinstance(self.symbolic_result, ConstantVariable) + and self.symbolic_result.value is not None + ): + args = [self.symbolic_result] + exc.raise_observed_exception(StopIteration, self, args=args) else: return self.symbolic_result else: diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index c98216aa2b40..c36ce1ad48b1 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -44,7 +44,6 @@ from .. import config, graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_rot_n, is_generator from ..exc import ( - get_dynamo_observed_exception, handle_observed_exception, InfiniteGeneratorError, ObservedException, @@ -582,6 +581,9 @@ def __init__( self.f_globals = f_globals self.inline_tracer = inline_tracer + root_tx = inline_tracer.output.root_tx + root_tx.local_generators.append(self) + def get_code(self): return self.code @@ -685,7 +687,8 @@ def force_apply_to_var_sequence(self, tx, fn) -> None: handle_observed_exception(tx) break - def _setup_exception(self, tx, exc): + def _setup_and_raise_exception(self, tx, exc): + # Raise an exception at the point where the generator is paused tracer = self._get_inline_tracer(tx) try: tracer._raise_exception_variable(exc) @@ -749,7 +752,7 @@ def call_method( # Raise GeneratorExit to see if user code catches it. Any other exception # is propagated to the parent frame. try: - self._setup_exception( + self._setup_and_raise_exception( tx, variables.ExceptionVariable(GeneratorExit, ()) ) # There's an extra block on Python 3.12+ to handle StopIteration @@ -798,93 +801,19 @@ def call_method( # returns the next value yielded by the generator. # * If the generator exits without yielding, raise StopIteration # * If the generator function does not catch the passed-in exception, - # or raises a different exception, then that exception propagates to the caller. + # or raises a different exception, then that new exception propagates to the caller. # Setup the exception table and jump target in case of try...finally tracer = self._get_inline_tracer(tx) - try: - # In Python 3.9, the exception is represented as a triple (typ, val, tb) - # In such cases, we re-raise the exception object given to avoid - # creating a new object, so that IS_OP works. - # See: https://github.com/pytorch/pytorch/pull/146496 - self._setup_exception(tx, args[1] if len(args) == 3 else args[0]) - except ObservedException: # noqa: TRY203 - # propagate the exception back to the parent caller - raise - - retval = self.next_variable(tx) - - # The exception raised before is still active. We need to check the exception - # table one more time to find the next target. But why? Let’s walk - # through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M - # - # z = 0 - # def whoo(): - # global z - # z = 0 - # try: - # yield 1 - # except ValueError: - # yield 2 - # finally: - # z += 1 - # z += 10 - # - # gen = whoo() - # next(gen) - # gen.throw(ValueError) - # print('z', z) -> z = 1 - # - # ... - # >> 58 PUSH_EXC_INFO - # - # 8 60 LOAD_GLOBAL 2 (ValueError) - # 70 CHECK_EXC_MATCH - # 72 POP_JUMP_IF_FALSE 7 (to 88) - # 74 POP_TOP - # - # 9 76 LOAD_CONST 3 (2) - # 78 YIELD_VALUE 3 <------ ValueError is still active here - # 80 RESUME 1 - # 82 POP_TOP - # 84 POP_EXCEPT - # 86 jump_backward 34 (to 20) - # ... - # - # ExceptionTable: - # 4 to 8 -> 124 [0] lasti - # 12 to 18 -> 58 [0] - # 20 to 56 -> 124 [0] lasti - # 58 to 82 -> 90 [1] lasti <------ move to 90 - # 84 to 86 -> 96 [0] - # 88 to 88 -> 90 [1] lasti - # 90 to 94 -> 96 [0] - # 96 to 116 -> 118 [1] lasti - # 118 to 122 -> 124 [0] lasti - # - # In this scenario, a generator can yield after `throw()` is called. Even - # after the exception is raised a few lines above, it remains active - # within the `78 YIELD_VALUE` instruction. When the generator resumes - # after the second yield on instruction `80 RESUME`, we cannot simply - # return the control flow to the next instruction. Instead, one must - # check the exception table (or equivalent) to find the next target - # In this case, it says the instruction pointer must be moved to 90. - # - # Without this step, if we let the trace proceed to the next - # instruction, it would follow the control flow where the exception - # raised by `throw()` was handled and swallowed, potentially leading - # to incorrect behavior. - exc_type = type("__InternalThrowException", (Exception,), {}) - try: - self._setup_exception(tx, variables.ExceptionVariable(exc_type, ())) - self.next_variable(tx) - except get_dynamo_observed_exception(exc_type): - # We should get back the exception raised before. - pass - else: - raise_observed_exception(RuntimeError, tracer) - return retval + # In Python 3.9, the exception is represented as a triple (typ, val, tb) + # In such cases, we raise the given object instead of creating a new + # one, so that IS_OP works. + # See: https://github.com/pytorch/pytorch/pull/146496 + self._setup_and_raise_exception(tx, args[1] if len(args) == 3 else args[0]) + + # If reaches here, it means user code captured the exception + return self.next_variable(tx) super().call_method(tx, name, args, kwargs) From 4d0eeb2103ed10c377b3d50c848665d950a2b0c7 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 27 Jun 2025 17:04:29 -0300 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- test/dynamo/test_generator.py | 44 ------------------------------- torch/_dynamo/symbolic_convert.py | 8 +----- 2 files changed, 1 insertion(+), 51 deletions(-) diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index 6c1ba1c4c9bc..7af9c91a1001 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -1482,50 +1482,6 @@ def fn(t): self._compile_check(fn) - def test_return_value_in_except_and_finally(self): - def whoo(): - try: - yield 1 - except ValueError: - return 2 # noqa: B901 - finally: - return 3 # noqa: B012, SIM107 - - def fn(t): - gen = whoo() - next(gen) - try: - gen.throw(ValueError) - except StopIteration as e: - assert e.args[0] == 3 - except Exception as e: - raise AssertionError from e - return t.sin() - - self._compile_check(fn) - - def test_return_None_in_except_and_finally(self): - def whoo(): - try: - yield 1 - except ValueError: - return 2 # noqa: B901 - finally: - return # noqa: B012, SIM107 - - def fn(t): - gen = whoo() - next(gen) - try: - gen.throw(ValueError) - except StopIteration as e: - assert len(e.args) == 0 - except Exception as e: - raise AssertionError from e - return t.sin() - - self._compile_check(fn) - instantiate_parametrized_tests(GeneratorTests) instantiate_parametrized_tests(TestGeneratorSend) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index e197dfc365ce..ba662abfac58 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -3981,13 +3981,7 @@ def inline_call_(self): ): assert isinstance(self, InliningGeneratorInstructionTranslator) # When the generator returns None, we raise StopIteration - args = [] - if ( - isinstance(self.symbolic_result, ConstantVariable) - and self.symbolic_result.value is not None - ): - args = [self.symbolic_result] - exc.raise_observed_exception(StopIteration, self, args=args) + exc.raise_observed_exception(StopIteration, self) else: return self.symbolic_result else: From 970da036157718c1219ed4edc0330f64f922a99d Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 27 Jun 2025 18:24:16 -0300 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- test/dynamo/test_generator.py | 4 +--- torch/_dynamo/output_graph.py | 2 +- torch/_dynamo/side_effects.py | 18 ++++++++++++++++++ torch/_dynamo/symbolic_convert.py | 9 --------- torch/_dynamo/variables/functions.py | 16 +++++++++------- 5 files changed, 29 insertions(+), 20 deletions(-) diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index 7af9c91a1001..0de71eaf1836 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -974,7 +974,7 @@ def test_close_with_side_effects(self): z = 0 def whoo(t): - nonlocal z + nonlocal z # noqa: F824 try: L.append(1) yield t.sin() @@ -1015,7 +1015,6 @@ def whoo(t): @torch.compile(backend="eager", fullgraph=True) def fn(t): - nonlocal z gen = whoo(t) i = next(gen) y = gen.close() @@ -1043,7 +1042,6 @@ def whoo(t): @torch.compile(backend="eager", fullgraph=fullgraph) def fn(t): - nonlocal z gen = whoo(t) i = next(gen) gen.close() diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 0c6d39acfd1f..3fa79cb8737d 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1397,7 +1397,7 @@ def compile_subgraph( overridden_sources=overridden_sources, ) self.codegen_suffix(tx, stack_values_flat, pass1) - tx.close_local_generators() + self.side_effects.close_local_generators() # Use `pass1.uses` to selectively cache multi-user variables into a # temporary local source. This (a). speeds up loading VTs with long diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 418c7ec7b685..3fa1a56a8e12 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -117,6 +117,7 @@ def __init__( self.keepalive = keepalive or [] self.save_for_backward = save_for_backward or [] self.tensor_hooks = tensor_hooks or {} + self.local_generators = [] # Used by MappingProxyVariable to graph break in case of any mutated # dict self._has_existing_dict_mutation = False @@ -190,6 +191,23 @@ def should_allow_externally_visible_side_effects_in_subtracer(self): and output_graph.current_tx.output.current_tracer.unsafe_allow_externally_visible_side_effects ) + def track_generator(self, gen): + self.local_generators.append(gen) + + def untrack_generator(self, gen): + self.local_generators.remove(gen) + + def close_local_generators(self): + from .symbolic_convert import temporarely_allow_writes_to_output_graph + + output_graph = self.output_graph_weakref() + if output_graph: + tx = output_graph.root_tx + with temporarely_allow_writes_to_output_graph(tx): + for gen in self.local_generators: + if not gen.is_generator_exhausted(): + gen.call_method(tx, "close", [], {}) + def is_reconstructing_generator(self): output_graph = self.output_graph_weakref() diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index ba662abfac58..b72b2f2f64eb 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1099,7 +1099,6 @@ class InstructionTranslatorBase( instruction_pointer: Optional[int] current_instruction: Instruction block_stack: list[BlockStackEntry] - local_generators: list[LocalGeneratorObjectVariable] lineno: int kw_names: Optional[ConstantVariable] accept_prefix_inst: bool @@ -1227,13 +1226,6 @@ def inline_user_function_return(self, fn, args, kwargs): else: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) - def close_local_generators(self): - assert isinstance(self, InstructionTranslator) - with temporarely_allow_writes_to_output_graph(self): - for gen in self.local_generators: - if not gen._is_generator_exhausted(): - gen.call_method(self, "close", [], {}) - def get_line_of_code_header(self, lineno=None): if lineno is None: lineno = self.lineno @@ -3241,7 +3233,6 @@ def __init__( self.start_point = None self.current_instruction = create_instruction("NOP") self.block_stack = [] - self.local_generators: list[LocalGeneratorObjectVariable] = [] # states before SETUP_WITH for checkpointing and fallback self.active_generic_context_managers: list[GenericContextWrappingVariable] = [] self.lineno = -1 diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index c36ce1ad48b1..c20934e487f9 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -581,8 +581,7 @@ def __init__( self.f_globals = f_globals self.inline_tracer = inline_tracer - root_tx = inline_tracer.output.root_tx - root_tx.local_generators.append(self) + inline_tracer.output.side_effects.track_generator(self) def get_code(self): return self.code @@ -647,7 +646,7 @@ def _get_inline_tracer(self, tx): def next_variable(self, tx): tracer = self._get_inline_tracer(tx) - if self._is_generator_exhausted(): + if self.is_generator_exhausted(): raise_observed_exception(StopIteration, tx) try: @@ -656,9 +655,12 @@ def next_variable(self, tx): # for Dynamo to behave correctly with patch.dict(counters, {"unimplemented": counters["inline_call"]}): return tracer.inline_call_() - except ObservedException as e: + except ObservedUserStopIteration: + tracer.output.side_effects.untrack_generator(self) + raise + except ObservedException: tracer.generator_exhausted = True - raise e + raise except InfiniteGeneratorError: # test/dynamo/test_misc.py::test_iterator_limit raise @@ -700,7 +702,7 @@ def _setup_and_raise_exception(self, tx, exc): def _is_generator_just_started(self): return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0 - def _is_generator_exhausted(self): + def is_generator_exhausted(self): return getattr(self.inline_tracer, "generator_exhausted", False) def call_method( @@ -745,7 +747,7 @@ def call_method( # See test GeneratorCloseCpythonTests::test_close_not_started tracer = self._get_inline_tracer(tx) - if self._is_generator_just_started() or self._is_generator_exhausted(): + if self._is_generator_just_started() or self.is_generator_exhausted(): tracer.generator_exhausted = True return variables.ConstantVariable(None)