
[dynamo] [guard] Add caching for inside torch.compile.disable function to avoid unnecessary recompilation. #157566


Open

thenumberouscode wants to merge 1 commit into main from bugfix_recompile

Conversation

thenumberouscode (Contributor) commented Jul 3, 2025

A user function defined inside a torch.compile region and decorated with torch._dynamo.disable always triggers recompilation, because the disabled inner function is passed as an argument to the generated resume_in_xx function. In the current implementation, that argument is a new object on every call, so the ID_MATCH guard always fails and recompilation is triggered.
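A minimal repro sketch of the pattern described here (the shape is assumed from this description and issue #157399, not copied from the reported script):

    import torch

    @torch.compile
    def outer(x):
        @torch._dynamo.disable
        def inner(y):
            return y + 1

        # Calling the disabled function forces a graph break; `inner` is then
        # passed into the generated resume function. Since `inner` is a fresh
        # object on every call to `outer`, the resume function's ID_MATCH
        # guard fails each time and a recompile is triggered.
        return inner(x)

    for _ in range(3):
        outer(torch.randn(3))  # before the fix: recompiles on every call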

Fixes #157399
@xmfan

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos

pytorch-bot commented Jul 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157566

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 2b2bb11 with merge base f636736:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla bot commented Jul 3, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: thenumberouscode / name: zyl_keep_moving (2b2bb11)

thenumberouscode (Contributor, Author)

@pytorchbot label "release notes: dynamo"

thenumberouscode (Contributor, Author) commented Jul 4, 2025

@anijain2305 Could you please take a look at my code? I would really appreciate your feedback

thenumberouscode changed the title from "[dynamo] Fix the bug where using disable inside of compile always triggers recompilation." to "[dynamo] [guard] Change the guard type of inside disable function to avoid unnecessary recompilation." on Jul 4, 2025
thenumberouscode (Contributor, Author)

@xmfan Could you please take a look at my code? I would really appreciate your feedback

anijain2305 self-requested a review July 4, 2025 18:58
@@ -1387,7 +1387,9 @@ def as_python_constant(self):

     @classmethod
     def create_with_source(cls, value, source):
         if not is_wrapper_or_member_descriptor(value):
+            if inspect.getattr_static(value, "_torchdynamo_disable", False):
+                install_guard(source.make_guard(GuardBuilder.TYPE_MATCH))
anijain2305 (Contributor)

Maybe install the FUNCTION_MATCH guard on value._torchdynamo_disable. Something along the lines of:

    fn_obj = value._torchdynamo_disable
    install_guard(AttrSource(source, "_torchdynamo_disable").make_guard(GuardBuilder.FUNCTION_MATCH))

A reviewer:

I'm new to dynamo, so forgive the silly question, but could we add a new rule to trace_rules.manual_torch_name_rule_map instead?

    "torch._dynamo.disable": SkipFunctionVariable,

It wouldn't introduce additional code, but I haven't tested it beyond the happy path. @anijain2305 @thenumberouscode

thenumberouscode (Contributor, Author)

> Maybe install the FUNCTION_MATCH guard on value._torchdynamo_disable. Something along the lines of:
>
>     fn_obj = value._torchdynamo_disable
>     install_guard(AttrSource(source, "_torchdynamo_disable").make_guard(GuardBuilder.FUNCTION_MATCH))

Your code indeed offers a more granular solution. I'm just waiting for the results of all the workflows.

thenumberouscode (Contributor, Author)

> manual_torch_name_rule_map

Have you tested it in your local environment? I’m afraid it may not be effective in solving this problem.

The same reviewer:

I just gave it a quick test to see if the bug occurs, and it doesn't, but I didn't test it further.

mlazos (Contributor)

@anijain2305 I don't think this is sound; _torchdynamo_disable is only ever True, right? At the root of the issue is that the ID of a local nested function changes on each call, and I don't think we can guarantee it's semantically the same function (which is why we usually check function ID). We could have any number of closures inside the function and pass them as arguments to the resume function based on different conditions, and we wouldn't recompile the resume function even though we should.

thenumberouscode (Contributor, Author) commented Jul 7, 2025

@mlazos I’m a bit confused about something. To me, a function that doesn’t need to be compiled means we don’t have to worry about its implementation, arguments, or any other details; we just pass it to Python’s default interpreter. cc @anijain2305

thenumberouscode (Contributor, Author)

@mlazos I think I understand what you're pointing out. You mean that my bug fix would stop ID-checking all user functions decorated with torch._dynamo.disable, whether they are nested functions or not; that is too wide open.

mlazos (Contributor)

Yeah, your second comment is similar to what I'm saying. The reason we check functions by ID is that you can't check the semantic equivalence of functions, and even with nested functions it still won't work. Say you have different nested functions and you conditionally pass them to the resume function. It looks like this:

    @torch.compile()
    def outer(x, cond):
        @torch._dynamo.disable()
        def fn0(y):
            return y + 1

        @torch._dynamo.disable()
        def fn1(y):
            return y + 2

        if cond:
            f = fn0
        else:
            f = fn1

        torch._dynamo.graph_break()
        # there will be a resume function here
        return f(x)

In this case, if you flip cond from True to False, the outer function will recompile, and the resume function after the graph break also needs to recompile, because its behavior will change.

So you need to guard on the ID of f even though it will change on every call to outer. If the resume function doesn't properly recompile when f changes, this code will be incorrect.

thenumberouscode (Contributor, Author)

Cool, your understanding of the guard is very deep; I've learned a lot from you. Now, let's fix it according to your suggestions.

thenumberouscode (Contributor, Author)

@pytorchbot label "ciflow/trunk"

pytorch-bot added the ciflow/trunk label Jul 5, 2025
thenumberouscode (Contributor, Author)

@anijain2305 Hi, a workflow test failed (https://github.com/pytorch/pytorch/actions/runs/16063185244/job/45383170740?pr=157566) due to a timeout, as indicated in the log. I suspect this failure is unrelated to my bug fix. Could we please retry the test? Thank you!

soulitzer added the triaged label Jul 7, 2025
pytorch-bot removed the ciflow/trunk label Jul 8, 2025
thenumberouscode (Contributor, Author)

@mlazos @anijain2305 I have added a cache for dynamically created user functions decorated with torch._dynamo.disable, allowing us to use the real user function IDs with the ID_MATCH guard without needing recompilation. Please check out my new code and share your feedback when you get a chance.

thenumberouscode changed the title from "[dynamo] [guard] Change the guard type of inside disable function to avoid unnecessary recompilation." to "[dynamo] [guard] Add caching for inside torch.compile.disable function to avoid unnecessary recompilation." on Jul 8, 2025
thenumberouscode force-pushed the bugfix_recompile branch 2 times, most recently from 24410c2 to 7f8143d on July 8, 2025 08:04
mlazos (Contributor) commented Jul 8, 2025

> @mlazos @anijain2305 I have added a cache for dynamically created user functions decorated with torch._dynamo.disable, allowing us to use the real user function IDs with the ID_MATCH guard without needing recompilation. Please check out my new code and share your feedback when you get a chance.

Can you explain a little more how it works? Looking at the code, it isn't clear to me why it works, haha.

thenumberouscode (Contributor, Author)

> Can you explain a little more how it works? Looking at the code, it isn't clear to me why it works, haha.

Sure, calls to resume functions always have the following format:

    __resume_at_xx(torch._dynamo.disable(_create_nested_fn(my_fn, '__co_consts[4]', 'my_fn', None, None, None, None)), x)

The first argument will always be a new object, so we cannot check its ID. Instead, we can check my_fn, which is the user's function. That's why I added caching in the _create_nested_fn function; it ensures that the user function's ID does not change. Additionally, torch._dynamo.disable records the original user function in _torchdynamo_orig_callable:

    _fn._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]
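A minimal sketch of that memoization idea, assuming a simplified signature (the real _create_nested_fn takes the extra arguments shown in the call above, and the actual patch may key its cache differently):

    import types

    _nested_fn_cache: dict = {}

    def _create_nested_fn(code, f_globals, name, defaults=None, closure=None):
        # Memoize on the code object so repeated resume-function calls return
        # the same function object, keeping id(fn) stable for the ID_MATCH guard.
        key = (code, name)
        fn = _nested_fn_cache.get(key)
        if fn is None:
            fn = types.FunctionType(code, f_globals, name, defaults, closure)
            _nested_fn_cache[key] = fn
        return fn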

mlazos (Contributor) commented Jul 9, 2025

> That's why I added caching in the _create_nested_fn function; it ensures that the user function's ID does not change.

Ah, this makes sense now. I didn't realize how _create_nested_fn was used; basically, you will now return the same object any time the same code is passed in. One question: when will the cache ever be cleared? We should clear it once the output code using that entry is deleted. An example finalizer I added is here.

We should also clear it if torch._dynamo.reset() is called. Otherwise this looks great!

thenumberouscode (Contributor, Author) commented Jul 11, 2025

> Ah, this makes sense now. I didn't realize how _create_nested_fn was used; basically, you will now return the same object any time the same code is passed in. One question: when will the cache ever be cleared? We should clear it once the output code using that entry is deleted. An example finalizer I added is here.
>
> We should also clear it if torch._dynamo.reset() is called. Otherwise this looks great!

@mlazos After testing locally, I found that the finalizer in OutputGraph doesn't work in this scenario. Since our cache is shared across different OutputGraph instances, we cannot clear it when a single graph is torn down. The cache needs to persist throughout the entire compilation process, similar to how code caches operate.

To address this, I implemented a CreateNestedFnCache to manage all function caches, which is cleared by torch._dynamo.reset().

Please take a look at my new commit to see the implementation, and don't hesitate to share any feedback or suggestions!
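A hypothetical sketch of such a cache holder: the class name CreateNestedFnCache comes from the comment above, but the methods shown are assumptions rather than the actual PyTorch implementation:

    class CreateNestedFnCache:
        cache: dict = {}

        @classmethod
        def get(cls, key):
            return cls.cache.get(key)

        @classmethod
        def set(cls, key, value):
            cls.cache[key] = value

        @classmethod
        def clear(cls):
            # Hooked into torch._dynamo.reset() so cached function objects
            # don't persist across a full Dynamo reset.
            cls.cache.clear()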

thenumberouscode (Contributor, Author)

> Looks good, thanks!

@mlazos The workflow failed again for an unrelated lint problem. Could we just run the workflow again? Thanks.

The failed workflows:

  1. https://github.com/pytorch/pytorch/actions/runs/16433277272/job/46683063908?pr=157566 (lint failure)
  2. https://github.com/pytorch/pytorch/actions/runs/16433277255/job/46683401413?pr=157566 (unstable workflow)

…o.disable to obtain their actual IDs, thereby avoiding ID_MATCH guard failures.
pytorch-bot removed the ciflow/trunk label Jul 28, 2025
thenumberouscode (Contributor, Author)

@mlazos @anijain2305 Can we trigger another workflow run? Appreciate it.

thenumberouscode (Contributor, Author)

@pytorchbot merge

pytorch-bot added the ciflow/trunk label Jul 28, 2025
pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

yangw-dev (Contributor)

@pytorchbot revert -m "failed an odd internal test, please reach out to metamate to fix it, D79112610" -c ghfirst

pytorchmergebot (Collaborator)

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Aug 1, 2025
… function to avoid unnecessary recompilation. (#157566)"

This reverts commit 8e07c98.

Reverted #157566 on behalf of https://github.com/yangw-dev due to failed an odd internal test, please reach out to metamate to fix it, D79112610 ([comment](#157566 (comment)))
pytorchmergebot (Collaborator)

@thenumberouscode your PR has been successfully reverted.

pytorchmergebot added the Reverted and ci-no-td labels Aug 1, 2025
pytorch-bot dismissed stale reviews from mlazos and anijain2305 August 1, 2025 01:27

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

yangw-dev pushed a commit that referenced this pull request Aug 1, 2025
…n to avoid unnecessary recompilation. (#157566)

A user function defined inside a torch.compile region and decorated with torch._dynamo.disable always triggers recompilation, because the disabled inner function is passed as an argument to the generated resume_in_xx function. In the current implementation, that argument is a new object on every call, so the ID_MATCH guard always fails and recompilation is triggered.

Fixes #157399
@xmfan

Pull Request resolved: #157566
Approved by: https://github.com/mlazos, https://github.com/anijain2305
yangw-dev pushed a commit that referenced this pull request Aug 1, 2025
… function to avoid unnecessary recompilation. (#157566)"

This reverts commit 8e07c98.

Reverted #157566 on behalf of https://github.com/yangw-dev due to failed an odd internal test, please reach out to metamate to fix it, D79112610 ([comment](#157566 (comment)))
thenumberouscode (Contributor, Author)

@yangw-dev Which unit test failed? I'll fix it.

mlazos (Contributor) commented Aug 5, 2025

> D79112610

Hi @yangw-dev, can you point me to the test failure? I only see one build failure on that diff.

thenumberouscode (Contributor, Author)

@yangw-dev Can you let me know which unit test failed? cc @mlazos

thenumberouscode (Contributor, Author)

@mlazos, since we don't have @yangw-dev's help here, what should we do next? I've put a lot of time into this PR and really don't want it to go to waste.

mlazos (Contributor) commented Aug 11, 2025

> @mlazos, since we don't have @yangw-dev's help here, what should we do next? I've put a lot of time into this PR and really don't want it to go to waste.

No worries, we'll figure this out. I was at an offsite last week, so I didn't have a chance to follow up. I'll take a look this week and try to debug the diff.

thenumberouscode (Contributor, Author) commented Aug 11, 2025

> No worries, we'll figure this out. I was at an offsite last week, so I didn't have a chance to follow up. I'll take a look this week and try to debug the diff.

Thanks a lot for your help!

Labels: ci-no-td, ciflow/trunk, Merged, module: dynamo, open source, release notes: dynamo, Reverted, triaged

Successfully merging this pull request may close these issues: [dynamo] using disable inside of compile always recompiles

8 participants