[Trace Python Dispatcher] Support FuncTorchInterpreter #144444
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144444
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 2 Unrelated Failures
As of commit 19d87e8 with merge base d44c390:
NEW FAILURE - The following job has failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([10, 20, 30, 40])
self.assertEqual(fn(x, y), torch.tensor([11, 24, 39, 56]))
```
Can you test that a second call to `fn` does not recompile? I'm suspicious of the ID_MATCH (you should just be able to guard on the level, key, and metadata of the Interpreter).
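A minimal sketch of the requested check, using `torch._dynamo.testing.CompileCounter` as the backend and assuming `vmap` traces without a graph break (the point of this PR); the body of `fn` is an assumption reconstructed from the expected output `[11, 24, 39, 56]` (`x**2 + y` under `vmap`):

```python
import torch
from torch._dynamo.testing import CompileCounter

cnt = CompileCounter()

@torch.compile(backend=cnt)
def fn(x, y):
    # Assumed body: matches the expected output in the test above.
    return torch.func.vmap(lambda a, b: a * a + b)(x, y)

x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([10, 20, 30, 40])
fn(x, y)
fn(x, y)  # a cache hit should leave frame_count unchanged
assert cnt.frame_count == 1, "second call recompiled"
```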
Unfortunately it would recompile, because the interpreter's ID changes each time the compiled function is called. Do you think we should add a new C++ guard for `Interpreter`, like the `DispatchKeySet` guard I added in the last PR?
Yeah, otherwise we are going to recompile every time a function that needs `retrieve_current_functorch_interpreter` gets called. Every time `vmap` is called, a fresh interpreter is created.
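A hedged sketch of that behavior, assuming `retrieve_current_functorch_interpreter` from `torch._functorch.pyfunctorch` (the helper named above) can be called while a transform is active:

```python
import torch
from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter

seen_ids = []

def f(x):
    # Record the identity of the interpreter driving this transform.
    seen_ids.append(id(retrieve_current_functorch_interpreter()))
    return x * 2

x = torch.arange(4)
torch.func.vmap(f)(x)
torch.func.vmap(f)(x)

# Each vmap call yields a fresh interpreter object, so an ID_MATCH guard on
# it can never be re-hit, even though level/key metadata would match.
print(seen_ids[0] != seen_ids[1])  # expected: True
```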
Is there an easy way to add a Python guard on the `.level`, `.key`, `.batch_size`, and `.randomness` of the Interpreter?
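A hedged sketch of such a metadata-based comparison, written as plain Python rather than a real dynamo guard; `level()`/`key()` live on `FuncTorchInterpreter`, while `batch_size()`/`randomness()` exist only on `VmapInterpreter`:

```python
def interpreter_guard_fields(interp):
    # Fields a guard could compare instead of the object's id().
    fields = {"level": interp.level(), "key": interp.key()}
    if hasattr(interp, "batch_size"):  # vmap-specific metadata
        fields["batch_size"] = interp.batch_size()
        fields["randomness"] = interp.randomness()
    return fields

# A guard built from these fields would be re-hit across vmap calls, since
# the metadata (unlike the object identity) is stable across equivalent
# transforms.
```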
Upon reviewing this again, I found my previous assessment was incorrect. The recompilation has a different cause; it is not the guard on the interpreter. Because we always generate the interpreter within the compiled region, there are no use cases where an interpreter is passed into the compiled region that would require guarding. A thorough search of the codebase confirmed this.
As such, only the sourceless builder is used in this case. We could either remove the builder branch where source is not None, or leave it as is until we encounter a scenario where recompilation due to ID_MATCH becomes a genuine issue; we can't construct a proper unit test for it either.
```python
elif isinstance(
    value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType)
):
```
Are we sure that `DispatchKey` and `TransformType` actually behave like Python enums? (I don't care if they don't, this is close enough, but I'm just curious.)
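A quick hedged check; the result depends on how the bindings are generated (pybind11-style enums are not `enum.Enum` subclasses, while Python-side wrappers are):

```python
import enum
import torch

for cls in (torch.DispatchKey, torch._C._functorch.TransformType):
    print(cls.__name__, issubclass(cls, enum.Enum))
```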
If they do, then I want some easy way for a developer to add support for more of these C++ enums. Maybe shove them into a list somewhere near `EnumVariable`.
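A sketch of that suggestion; `SUPPORTED_CPP_ENUM_TYPES` is a hypothetical name for a tuple kept next to `EnumVariable`, so supporting a new C++-backed enum becomes a one-line change:

```python
import enum
import torch

# Hypothetical registry, kept next to EnumVariable.
SUPPORTED_CPP_ENUM_TYPES = (
    torch.DispatchKey,
    torch._C._functorch.TransformType,
)

# Illustrative use inside the variable builder:
# elif isinstance(value, (enum.Enum, *SUPPORTED_CPP_ENUM_TYPES)):
#     ...build an EnumVariable...
```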
I think so, as `DispatchKey` and `TransformType` are pure Python classes extending `enum.Enum`, even though they are bound to their corresponding C++ types.
No action required here; it might be nice to have an easier way to add more of these in the future.
Looks fine; I want to check the recompilation and guarding.
Cool, I buy it.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
@pytorchbot merge -f "irrelevant failure"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames