[generator] Close all open generators in compile_subgraph #157149

Status: Open. Wants to merge 11 commits into base branch gh/guilhermeleobas/193/base.
8 changes: 4 additions & 4 deletions test/dynamo/test_generator.py
@@ -1009,7 +1009,7 @@ def test_close_with_side_effects(self):
z = 0

def whoo(t):
nonlocal z
nonlocal z # noqa: F824
try:
L.append(1)
yield t.sin()
@@ -1050,7 +1050,6 @@ def whoo(t):

@torch.compile(backend="eager", fullgraph=True)
def fn(t):
nonlocal z
gen = whoo(t)
i = next(gen)
y = gen.close()
@@ -1078,7 +1077,6 @@ def whoo(t):

@torch.compile(backend="eager", fullgraph=fullgraph)
def fn(t):
nonlocal z
gen = whoo(t)
i = next(gen)
gen.close()
@@ -1380,8 +1378,10 @@ def fn(t):
a = next(gen)
try:
gen.throw(ValueError)
except StopIteration:
except StopIteration as e:
assert len(e.args) == 0
return a
raise AssertionError("Expected StopIteration")

t = torch.randn(2)
y = self._compile_check(fn, (t,))
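
Note on the behavior the tightened assertion checks: in CPython, when a generator catches the exception passed to throw() and then finishes without yielding again, throw() raises a bare StopIteration whose args tuple is empty. A minimal standalone sketch of that contract (the generator body here is illustrative, not the whoo defined in the test):

    def gen_fn(t):
        try:
            yield t
        except ValueError:
            pass  # swallow the thrown exception and fall off the end

    g = gen_fn(1)
    next(g)                  # advance to the first yield
    try:
        g.throw(ValueError)  # generator exits without yielding again
    except StopIteration as e:
        assert e.args == ()  # bare StopIteration, no payload
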
5 changes: 5 additions & 0 deletions torch/_dynamo/output_graph.py
@@ -1431,6 +1431,11 @@ def compile_subgraph(
)
self.codegen_suffix(tx, stack_values_flat, pass1)

# Close all generators opened while tracing. Needs to be done after
# pass1, as PyCodegen might try to reconstruct the generator, which
# sets LocalGeneratorObjectVariable.remaining_items
self.side_effects.close_local_generators()
Review comment (Contributor):
add a comment about why this is here.
also TODO(rzou): is this actually the right place?


# Use `pass1.uses` to selectively cache multi-user variables into a
# temporary local source. This (a). speeds up loading VTs with long
# chained source, and (b). avoids redundantly saving single-user VT
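
For context on what is at stake here: in plain Python, closing a suspended generator raises GeneratorExit at the yield point, which runs any pending finally (or except GeneratorExit) blocks, and those side effects must not be lost when the subgraph is compiled. A minimal sketch of that behavior, independent of Dynamo:

    cleanup = []

    def gen_fn(t):
        try:
            yield t
        finally:
            cleanup.append("ran")  # side effect triggered by close()

    g = gen_fn(1)
    next(g)     # g is now suspended inside the try block
    g.close()   # raises GeneratorExit inside gen_fn; the finally block runs
    assert cleanup == ["ran"]
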
19 changes: 19 additions & 0 deletions torch/_dynamo/side_effects.py
@@ -59,6 +59,7 @@
if TYPE_CHECKING:
from torch._dynamo.output_graph import OutputGraph
from torch._dynamo.symbolic_convert import InstructionTranslatorBase
from torch._dynamo.variables.functions import LocalGeneratorFunctionVariable
from torch._dynamo.variables.lists import ListVariable


@@ -134,6 +135,7 @@ def __init__(
self.keepalive = keepalive or []
self.save_for_backward = save_for_backward or []
self.tensor_hooks = tensor_hooks or {}
self.local_generators: list[LocalGeneratorFunctionVariable] = []
# Used by MappingProxyVariable to graph break in case of any mutated
# dict
self._has_existing_dict_mutation = False
@@ -228,6 +230,23 @@ def should_allow_externally_visible_side_effects_in_subtracer(self) -> bool:
and output_graph.current_tx.output.current_tracer.unsafe_allow_externally_visible_side_effects
)

def track_generator(self, gen: "LocalGeneratorFunctionVariable") -> None:
self.local_generators.append(gen)

def untrack_generator(self, gen: "LocalGeneratorFunctionVariable") -> None:
self.local_generators.remove(gen)

def close_local_generators(self) -> None:
from .symbolic_convert import temporarily_allow_writes_to_output_graph

output_graph = self.output_graph_weakref()
if output_graph:
tx = output_graph.root_tx
with temporarily_allow_writes_to_output_graph(tx):
for gen in self.local_generators:
if not gen.is_generator_exhausted():
gen.call_method(tx, "close", [], {})

def is_reconstructing_generator(self) -> bool:
output_graph = self.output_graph_weakref()

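
A plain-Python analogue of the bookkeeping these three methods add (GeneratorTracker and its method names are made up for illustration, not Dynamo APIs): generators are registered when created, deregistered once exhausted, and anything still suspended is closed at the end.

    import inspect

    class GeneratorTracker:
        def __init__(self):
            self._gens = []

        def track(self, gen):
            self._gens.append(gen)

        def untrack(self, gen):
            self._gens.remove(gen)

        def close_all(self):
            # Mirror close_local_generators: only close generators that
            # have not already run to completion.
            for gen in self._gens:
                if inspect.getgeneratorstate(gen) != "GEN_CLOSED":
                    gen.close()

    def gen_fn():
        yield 1
        yield 2

    tracker = GeneratorTracker()
    g = gen_fn()
    tracker.track(g)
    next(g)              # g is suspended after the first yield
    tracker.close_all()  # g is closed here
    assert inspect.getgeneratorstate(g) == "GEN_CLOSED"
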
4 changes: 2 additions & 2 deletions torch/_dynamo/symbolic_convert.py
@@ -375,7 +375,7 @@ def save_and_restart_speculation_log(tx: "InstructionTranslatorBase"):


@contextlib.contextmanager
def temporarely_allow_writes_to_output_graph(tx: "InstructionTranslatorBase"):
def temporarily_allow_writes_to_output_graph(tx: "InstructionTranslatorBase"):
try:
tmp = tx.output.should_exit
tx.output.should_exit = False
@@ -1019,7 +1019,7 @@ class ExceptionStack:
# and "stack" sometimes refers to a C variable with the same name and the
# exception stack, respectively.
#
# The lifetime of an exception is (Python 3.11+):
# The lifetime of an exception in Python 3.11+ is:
# + tx._raise_exception_variable(...) := sets the current_exception variable
# + PUSH_EXC_INFO := pushes the current_exception to the *exception stack*
# + POP_EXCEPT := pops TOS from the *exception stack*
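
To see the 3.11+ opcodes the reworded comment refers to, one can disassemble a small try/except; this sketch assumes a CPython 3.11+ interpreter and only prints the bytecode:

    import dis

    def f():
        try:
            raise ValueError("boom")
        except ValueError:
            pass

    # On CPython 3.11+ the output includes PUSH_EXC_INFO on entry to the
    # except block and POP_EXCEPT when it is left.
    dis.dis(f)
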
117 changes: 24 additions & 93 deletions torch/_dynamo/variables/functions.py
@@ -44,7 +44,6 @@
from .. import config, graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
from ..exc import (
get_dynamo_observed_exception,
handle_observed_exception,
InfiniteGeneratorError,
ObservedException,
@@ -629,6 +628,8 @@ def __init__(
self.f_globals = f_globals
self.inline_tracer = inline_tracer

inline_tracer.output.side_effects.track_generator(self)

def get_code(self):
return self.code

@@ -657,13 +658,13 @@ def reconstruct(self, codegen: "PyCodegen"):
from torch._dynamo.symbolic_convert import (
InstructionTranslator,
save_and_restart_speculation_log,
temporarely_allow_writes_to_output_graph,
temporarily_allow_writes_to_output_graph,
)

tx = InstructionTranslator.current_tx()
save = save_and_restart_speculation_log(tx)
disallow = disallow_side_effects_in_generator(tx)
temp = temporarely_allow_writes_to_output_graph(tx)
temp = temporarily_allow_writes_to_output_graph(tx)

with save, disallow, temp:
tracer = self._get_inline_tracer(tx)
@@ -692,7 +693,7 @@ def _get_inline_tracer(self, tx):
def next_variable(self, tx):
tracer = self._get_inline_tracer(tx)

if self._is_generator_exhausted():
if self.is_generator_exhausted():
raise_observed_exception(StopIteration, tx)

try:
@@ -701,9 +702,12 @@ def next_variable(self, tx):
# for Dynamo to behave correctly
with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
return tracer.inline_call_()
except ObservedException as e:
except ObservedUserStopIteration:
tracer.output.side_effects.untrack_generator(self)
raise
except ObservedException:
tracer.generator_exhausted = True
raise e
raise
except InfiniteGeneratorError:
# test/dynamo/test_misc.py::test_iterator_limit
raise
@@ -737,7 +741,8 @@ def force_apply_to_var_sequence(self, tx, fn) -> None:
handle_observed_exception(tx)
break

def _setup_exception(self, tx, exc):
def _setup_and_raise_exception(self, tx, exc):
# Raise an exception at the point where the generator is paused
tracer = self._get_inline_tracer(tx)
try:
tracer._raise_exception_variable(exc)
@@ -749,7 +754,7 @@ def _is_generator_just_started(self):
def _is_generator_just_started(self):
return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0

def _is_generator_exhausted(self):
def is_generator_exhausted(self):
return getattr(self.inline_tracer, "generator_exhausted", False)

def call_method(
@@ -794,14 +799,14 @@ def call_method(
# See test GeneratorCloseCpythonTests::test_close_not_started

tracer = self._get_inline_tracer(tx)
if self._is_generator_just_started() or self._is_generator_exhausted():
if self._is_generator_just_started() or self.is_generator_exhausted():
tracer.generator_exhausted = True
return variables.ConstantVariable(None)

# Raise GeneratorExit to see if user code catches it. Any other exception
# is propagated to the parent frame.
try:
self._setup_exception(
self._setup_and_raise_exception(
tx, variables.ExceptionVariable(GeneratorExit, ())
)
# There's an extra block on Python 3.12+ to handle StopIteration
@@ -850,93 +855,19 @@ def call_method(
# returns the next value yielded by the generator.
# * If the generator exits without yielding, raise StopIteration
# * If the generator function does not catch the passed-in exception,
# or raises a different exception, then that exception propagates to the caller.
# or raises a different exception, then that new exception propagates to the caller.

# Setup the exception table and jump target in case of try...finally
tracer = self._get_inline_tracer(tx)
try:
# In Python 3.9, the exception is represented as a triple (typ, val, tb)
# In such cases, we re-raise the exception object given to avoid
# creating a new object, so that IS_OP works.
# See: https://github.com/pytorch/pytorch/pull/146496
self._setup_exception(tx, args[1] if len(args) == 3 else args[0])
except ObservedException: # noqa: TRY203
# propagate the exception back to the parent caller
raise

retval = self.next_variable(tx)

# The exception raised before is still active. We need to check the exception
# table one more time to find the next target. But why? Let’s walk
# through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M
#
# z = 0
# def whoo():
# global z
# z = 0
# try:
# yield 1
# except ValueError:
# yield 2
# finally:
# z += 1
# z += 10
#
# gen = whoo()
# next(gen)
# gen.throw(ValueError)
# print('z', z) -> z = 1
#
# ...
# >> 58 PUSH_EXC_INFO
#
# 8 60 LOAD_GLOBAL 2 (ValueError)
# 70 CHECK_EXC_MATCH
# 72 POP_JUMP_IF_FALSE 7 (to 88)
# 74 POP_TOP
#
# 9 76 LOAD_CONST 3 (2)
# 78 YIELD_VALUE 3 <------ ValueError is still active here
# 80 RESUME 1
# 82 POP_TOP
# 84 POP_EXCEPT
# 86 jump_backward 34 (to 20)
# ...
#
# ExceptionTable:
# 4 to 8 -> 124 [0] lasti
# 12 to 18 -> 58 [0]
# 20 to 56 -> 124 [0] lasti
# 58 to 82 -> 90 [1] lasti <------ move to 90
# 84 to 86 -> 96 [0]
# 88 to 88 -> 90 [1] lasti
# 90 to 94 -> 96 [0]
# 96 to 116 -> 118 [1] lasti
# 118 to 122 -> 124 [0] lasti
#
# In this scenario, a generator can yield after `throw()` is called. Even
# after the exception is raised a few lines above, it remains active
# within the `78 YIELD_VALUE` instruction. When the generator resumes
# after the second yield on instruction `80 RESUME`, we cannot simply
# return the control flow to the next instruction. Instead, one must
# check the exception table (or equivalent) to find the next target
# In this case, it says the instruction pointer must be moved to 90.
#
# Without this step, if we let the trace proceed to the next
# instruction, it would follow the control flow where the exception
# raised by `throw()` was handled and swallowed, potentially leading
# to incorrect behavior.
exc_type = type("__InternalThrowException", (Exception,), {})

try:
self._setup_exception(tx, variables.ExceptionVariable(exc_type, ()))
self.next_variable(tx)
except get_dynamo_observed_exception(exc_type):
# We should get back the exception raised before.
pass
else:
raise_observed_exception(RuntimeError, tracer)
return retval
# In Python 3.9, the exception is represented as a triple (typ, val, tb)
# In such cases, we raise the given object instead of creating a new
# one, so that IS_OP works.
# See: https://github.com/pytorch/pytorch/pull/146496
self._setup_and_raise_exception(tx, args[1] if len(args) == 3 else args[0])

# If reaches here, it means user code captured the exception
return self.next_variable(tx)

super().call_method(tx, name, args, kwargs)

Expand Down
Loading