Skip to content

Commit e71853e

Browse files
committed
Update on "[dynamo, nested graph breaks] prevent excessive recompilations"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela [ghstack-poisoned]
2 parents 3a04f56 + 631cc0a commit e71853e

File tree

3 files changed, with 7 additions and 6 deletions.

test/dynamo/test_nested_graph_breaks.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,6 @@ def f3(x):
304304
self.assertEqual(cnts.frame_count, 3)
305305
self.assertEqual(cnts.op_count, 7)
306306

307-
@unittest.expectedFailure
308307
@torch._dynamo.config.patch(recompile_limit=1, fail_on_recompile_limit_hit=True)
309308
def test_no_recompiles(self):
310309
global f1, f2, f3

torch/_dynamo/resume_execution.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,9 @@ def generate(
311311
stack_ctx_vars: tuple[tuple[int, tuple[Any, ...]], ...],
312312
argnames_ctx_vars: tuple[tuple[str, tuple[Any, ...]], ...],
313313
null_idxes: tuple[int, ...],
314-
has_nested: bool,
314+
# mainly used to ensure distinct code objects per stack trace,
315+
# which prevents excessive recompilation of inner frames
316+
nested_code_objs: tuple[types.CodeType],
315317
) -> types.CodeType:
316318
assert offset is not None
317319
assert not (
@@ -332,7 +334,7 @@ def generate(
332334
stack_ctx_vars,
333335
argnames_ctx_vars,
334336
null_idxes,
335-
has_nested,
337+
nested_code_objs,
336338
)
337339

338340
is_py311_plus = sys.version_info >= (3, 11)
@@ -466,7 +468,7 @@ def update(
466468
)
467469

468470
# Call nested resume function
469-
if has_nested:
471+
if nested_code_objs:
470472
prefix.extend(
471473
[
472474
# set up __nested_resume_fns[-1] call

torch/_dynamo/symbolic_convert.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2506,7 +2506,7 @@ def create_call_resume_at(self, inst, all_stack_locals_metadata):
25062506

25072507
# build the resume function for each frame
25082508
resume_names = []
2509-
resume_codes = []
2509+
resume_codes: list[types.CodeType] = []
25102510
for i, meta in enumerate(all_stack_locals_metadata):
25112511
cur_tx = txes[i]
25122512
if cur_tx is self:
@@ -2594,7 +2594,7 @@ def create_call_resume_at(self, inst, all_stack_locals_metadata):
25942594
tuple(meta.stack_ctx_args),
25952595
tuple(meta.locals_ctx_args),
25962596
tuple(meta.stack_null_idxes),
2597-
self is not cur_tx,
2597+
tuple(resume_codes),
25982598
)
25992599
resume_codes.append(new_code)
26002600

0 commit comments

Comments
 (0)