diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index 9d7318105c90..daa784d0b548 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -1009,7 +1009,7 @@ def test_close_with_side_effects(self): z = 0 def whoo(t): - nonlocal z + nonlocal z # noqa: F824 try: L.append(1) yield t.sin() @@ -1050,7 +1050,6 @@ def whoo(t): @torch.compile(backend="eager", fullgraph=True) def fn(t): - nonlocal z gen = whoo(t) i = next(gen) y = gen.close() @@ -1078,7 +1077,6 @@ def whoo(t): @torch.compile(backend="eager", fullgraph=fullgraph) def fn(t): - nonlocal z gen = whoo(t) i = next(gen) gen.close() @@ -1380,8 +1378,10 @@ def fn(t): a = next(gen) try: gen.throw(ValueError) - except StopIteration: + except StopIteration as e: + assert len(e.args) == 0 return a + raise AssertionError("Expected StopIteration") t = torch.randn(2) y = self._compile_check(fn, (t,)) diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw b/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_leaking3 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw rename to test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_generator_leaking3 diff --git a/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw_exception_context b/test/dynamo_expected_failures/CPython313-test_generators-ExceptionTest.test_except_throw_exception_context deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index caa7b6fef530..73295fc95cec 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1431,6 +1431,11 @@ def compile_subgraph( ) self.codegen_suffix(tx, stack_values_flat, pass1) + # Close all generators opened while tracing. Needs to be done after + # pass1, as PyCodegen might try to reconstruct the generator, which + # sets LocalGeneratorObjectVariable.remaining_items + self.side_effects.close_local_generators() + # Use `pass1.uses` to selectively cache multi-user variables into a # temporary local source. This (a). speeds up loading VTs with long # chained source, and (b). avoids redundantly saving single-user VT diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 58ed0da5fb2d..c75f65665bc7 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -59,6 +59,7 @@ if TYPE_CHECKING: from torch._dynamo.output_graph import OutputGraph from torch._dynamo.symbolic_convert import InstructionTranslatorBase + from torch._dynamo.variables.functions import LocalGeneratorFunctionVariable from torch._dynamo.variables.lists import ListVariable @@ -134,6 +135,7 @@ def __init__( self.keepalive = keepalive or [] self.save_for_backward = save_for_backward or [] self.tensor_hooks = tensor_hooks or {} + self.local_generators: list[LocalGeneratorFunctionVariable] = [] # Used by MappingProxyVariable to graph break in case of any mutated # dict self._has_existing_dict_mutation = False @@ -228,6 +230,23 @@ def should_allow_externally_visible_side_effects_in_subtracer(self) -> bool: and output_graph.current_tx.output.current_tracer.unsafe_allow_externally_visible_side_effects ) + def track_generator(self, gen: "LocalGeneratorFunctionVariable") -> None: + self.local_generators.append(gen) + + def untrack_generator(self, gen: "LocalGeneratorFunctionVariable") -> None: + self.local_generators.remove(gen) + + def close_local_generators(self) -> None: + from .symbolic_convert import temporarily_allow_writes_to_output_graph + + output_graph = self.output_graph_weakref() + if output_graph: + tx = output_graph.root_tx + with temporarily_allow_writes_to_output_graph(tx): + for gen in self.local_generators: + if not gen.is_generator_exhausted(): + gen.call_method(tx, "close", [], {}) + def is_reconstructing_generator(self) -> bool: output_graph = self.output_graph_weakref() diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 8e5a1ef80393..21e316661c40 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -375,7 +375,7 @@ def save_and_restart_speculation_log(tx: "InstructionTranslatorBase"): @contextlib.contextmanager -def temporarely_allow_writes_to_output_graph(tx: "InstructionTranslatorBase"): +def temporarily_allow_writes_to_output_graph(tx: "InstructionTranslatorBase"): try: tmp = tx.output.should_exit tx.output.should_exit = False @@ -1019,7 +1019,7 @@ class ExceptionStack: # and "stack" sometimes refers to a C variable with the same name and the # exception stack, respectively. # - # The lifetime of an exception is (Python 3.11+): + # The lifetime of an exception in Python 3.11+ is: # + tx._raise_exception_variable(...) := sets the current_exception variable # + PUSH_EXC_INFO := pushes the current_exception to the *exception stack* # + POP_EXCEPT := pops TOS from the *exception stack* diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 4bdcecf3b3c2..63f139d3404f 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -44,7 +44,6 @@ from .. import config, graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_rot_n, is_generator from ..exc import ( - get_dynamo_observed_exception, handle_observed_exception, InfiniteGeneratorError, ObservedException, @@ -639,6 +638,8 @@ def __init__( self.f_globals = f_globals self.inline_tracer = inline_tracer + inline_tracer.output.side_effects.track_generator(self) + def get_code(self): return self.code @@ -667,13 +668,13 @@ def reconstruct(self, codegen: "PyCodegen"): from torch._dynamo.symbolic_convert import ( InstructionTranslator, save_and_restart_speculation_log, - temporarely_allow_writes_to_output_graph, + temporarily_allow_writes_to_output_graph, ) tx = InstructionTranslator.current_tx() save = save_and_restart_speculation_log(tx) disallow = disallow_side_effects_in_generator(tx) - temp = temporarely_allow_writes_to_output_graph(tx) + temp = temporarily_allow_writes_to_output_graph(tx) with save, disallow, temp: tracer = self._get_inline_tracer(tx) @@ -702,7 +703,7 @@ def _get_inline_tracer(self, tx): def next_variable(self, tx): tracer = self._get_inline_tracer(tx) - if self._is_generator_exhausted(): + if self.is_generator_exhausted(): raise_observed_exception(StopIteration, tx) try: @@ -711,9 +712,12 @@ def next_variable(self, tx): # for Dynamo to behave correctly with patch.dict(counters, {"unimplemented": counters["inline_call"]}): return tracer.inline_call_() - except ObservedException as e: + except ObservedUserStopIteration: + tracer.output.side_effects.untrack_generator(self) + raise + except ObservedException: tracer.generator_exhausted = True - raise e + raise except InfiniteGeneratorError: # test/dynamo/test_misc.py::test_iterator_limit raise @@ -747,7 +751,8 @@ def force_apply_to_var_sequence(self, tx, fn) -> None: handle_observed_exception(tx) break - def _setup_exception(self, tx, exc): + def _setup_and_raise_exception(self, tx, exc): + # Raise an exception at the point where the generator is paused tracer = self._get_inline_tracer(tx) try: tracer._raise_exception_variable(exc) @@ -759,7 +764,7 @@ def _setup_exception(self, tx, exc): def _is_generator_just_started(self): return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0 - def _is_generator_exhausted(self): + def is_generator_exhausted(self): return getattr(self.inline_tracer, "generator_exhausted", False) def call_method( @@ -804,14 +809,14 @@ def call_method( # See test GeneratorCloseCpythonTests::test_close_not_started tracer = self._get_inline_tracer(tx) - if self._is_generator_just_started() or self._is_generator_exhausted(): + if self._is_generator_just_started() or self.is_generator_exhausted(): tracer.generator_exhausted = True return variables.ConstantVariable(None) # Raise GeneratorExit to see if user code catches it. Any other exception # is propagated to the parent frame. try: - self._setup_exception( + self._setup_and_raise_exception( tx, variables.ExceptionVariable(GeneratorExit, ()) ) # There's an extra block on Python 3.12+ to handle StopIteration @@ -860,93 +865,19 @@ def call_method( # returns the next value yielded by the generator. # * If the generator exits without yielding, raise StopIteration # * If the generator function does not catch the passed-in exception, - # or raises a different exception, then that exception propagates to the caller. + # or raises a different exception, then that new exception propagates to the caller. # Setup the exception table and jump target in case of try...finally tracer = self._get_inline_tracer(tx) - try: - # In Python 3.9, the exception is represented as a triple (typ, val, tb) - # In such cases, we re-raise the exception object given to avoid - # creating a new object, so that IS_OP works. - # See: https://github.com/pytorch/pytorch/pull/146496 - self._setup_exception(tx, args[1] if len(args) == 3 else args[0]) - except ObservedException: # noqa: TRY203 - # propagate the exception back to the parent caller - raise - - retval = self.next_variable(tx) - - # The exception raised before is still active. We need to check the exception - # table one more time to find the next target. But why? Let’s walk - # through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M - # - # z = 0 - # def whoo(): - # global z - # z = 0 - # try: - # yield 1 - # except ValueError: - # yield 2 - # finally: - # z += 1 - # z += 10 - # - # gen = whoo() - # next(gen) - # gen.throw(ValueError) - # print('z', z) -> z = 1 - # - # ... - # >> 58 PUSH_EXC_INFO - # - # 8 60 LOAD_GLOBAL 2 (ValueError) - # 70 CHECK_EXC_MATCH - # 72 POP_JUMP_IF_FALSE 7 (to 88) - # 74 POP_TOP - # - # 9 76 LOAD_CONST 3 (2) - # 78 YIELD_VALUE 3 <------ ValueError is still active here - # 80 RESUME 1 - # 82 POP_TOP - # 84 POP_EXCEPT - # 86 jump_backward 34 (to 20) - # ... - # - # ExceptionTable: - # 4 to 8 -> 124 [0] lasti - # 12 to 18 -> 58 [0] - # 20 to 56 -> 124 [0] lasti - # 58 to 82 -> 90 [1] lasti <------ move to 90 - # 84 to 86 -> 96 [0] - # 88 to 88 -> 90 [1] lasti - # 90 to 94 -> 96 [0] - # 96 to 116 -> 118 [1] lasti - # 118 to 122 -> 124 [0] lasti - # - # In this scenario, a generator can yield after `throw()` is called. Even - # after the exception is raised a few lines above, it remains active - # within the `78 YIELD_VALUE` instruction. When the generator resumes - # after the second yield on instruction `80 RESUME`, we cannot simply - # return the control flow to the next instruction. Instead, one must - # check the exception table (or equivalent) to find the next target - # In this case, it says the instruction pointer must be moved to 90. - # - # Without this step, if we let the trace proceed to the next - # instruction, it would follow the control flow where the exception - # raised by `throw()` was handled and swallowed, potentially leading - # to incorrect behavior. - exc_type = type("__InternalThrowException", (Exception,), {}) - try: - self._setup_exception(tx, variables.ExceptionVariable(exc_type, ())) - self.next_variable(tx) - except get_dynamo_observed_exception(exc_type): - # We should get back the exception raised before. - pass - else: - raise_observed_exception(RuntimeError, tracer) - return retval + # In Python 3.9, the exception is represented as a triple (typ, val, tb) + # In such cases, we raise the given object instead of creating a new + # one, so that IS_OP works. + # See: https://github.com/pytorch/pytorch/pull/146496 + self._setup_and_raise_exception(tx, args[1] if len(args) == 3 else args[0]) + + # If reaches here, it means user code captured the exception + return self.next_variable(tx) super().call_method(tx, name, args, kwargs)