
Add utility to get computed kernel in torch.library #158393


Open
mikaylagawarecki wants to merge 12 commits into base: gh/mikaylagawarecki/320/base

Conversation

mikaylagawarecki (Contributor) commented Jul 15, 2025

Adds `OperatorEntry::getComputedKernelForDispatchKey`, which returns the `KernelFunction` corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key.

  • Specifically, it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and is invalidated when that `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds it is removed from `kernels_`, which happens when the corresponding impl is deregistered).
  • `SafeKernelFunction` can be called via `callBoxed`; the validity of the token is checked before the call happens.
  • `SafeKernelFunction` is pybind-ed and `getComputedKernelForDispatchKey` is exposed to the frontend via `torch.library.get_kernel` (see the usage sketch below).

Related to #155330
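For orientation, a minimal usage sketch based on the docstring added in this PR (the exact `call_boxed` signature was still being discussed in review, see below):

```python
import torch

# Get the computed kernel for aten::add.Tensor under the CPU dispatch key.
kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU")

# Call the kernel directly. The token held by the returned SafeKernelFunction
# is validated first, so calling after the impl is deregistered raises an
# error instead of hitting a dangling KernelFunction.
op_handle = torch._C._dispatch_find_schema_or_throw("aten::add", "Tensor")
a = torch.tensor([1.0, 2.0])
b = torch.tensor([3.0, 4.0])
result = kernel.call_boxed(op_handle, torch._C.DispatchKeySet("CPU"), a, b)
```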

Stack from ghstack (oldest at bottom):

cc @albanD


pytorch-bot bot commented Jul 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158393

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b224df3 with merge base 556e2a7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

mikaylagawarecki added a commit that referenced this pull request Jul 15, 2025
ghstack-source-id: c7793b0
Pull Request resolved: #158393
mikaylagawarecki marked this pull request as draft July 16, 2025 14:44
mikaylagawarecki added a commit that referenced this pull request Jul 16, 2025
ghstack-source-id: 439450f
Pull Request resolved: #158393
mikaylagawarecki added a commit that referenced this pull request Jul 24, 2025
ghstack-source-id: e876743
Pull Request resolved: #158393
mikaylagawarecki added a commit that referenced this pull request Jul 24, 2025
ghstack-source-id: 762f9b6
Pull Request resolved: #158393
mikaylagawarecki added labels "module: python frontend" and "topic: new features" on Jul 24, 2025
mikaylagawarecki requested a review from albanD on Jul 24, 2025 21:35
mikaylagawarecki added the label "release notes: python_frontend" and removed the label "module: python frontend" on Jul 24, 2025
mikaylagawarecki marked this pull request as ready for review on Jul 24, 2025 21:37
mikaylagawarecki added a commit that referenced this pull request Jul 24, 2025
ghstack-source-id: 5925282
Pull Request resolved: #158393
mikaylagawarecki requested a review from zou3519 on Jul 25, 2025 14:48
albanD (Collaborator) left a comment


Sounds pretty good!
Only small questions!

torch/library.py Outdated
op = op._name

if isinstance(dispatch_key, str):
dispatch_key = torch._C.DispatchKey.__members__[dispatch_key]
Collaborator


What is the error you get when passing a wrong dispatch key here?

Contributor Author


oops, added proper error handling here
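For reference, a hypothetical sketch of what that validation could look like (the actual check added in the PR may differ):

```python
import torch

def _to_dispatch_key(dispatch_key):
    # Map a string like "CPU" to torch._C.DispatchKey, raising a descriptive
    # error instead of a bare KeyError for unknown names.
    if isinstance(dispatch_key, str):
        members = torch._C.DispatchKey.__members__
        if dispatch_key not in members:
            raise ValueError(f"Unknown dispatch key: {dispatch_key!r}")
        return members[dispatch_key]
    return dispatch_key
```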

auto [annotatedKernel, _] =
    computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k);

return SafeKernelFunction(&annotatedKernel.kernel);
Collaborator


I think it would be nice to grab the debug string here and add that to the __repr__ we get from python?

mikaylagawarecki (Contributor Author) commented Aug 5, 2025


This gives something like

SafeKernelFunction(debug='registered at /data/users/mg1998/pytorch/build/aten/src/ATen/RegisterCompositeExplicitAutograd_0.cpp:2309')

Do you think that is meaningful enough, or should we add more info?

Collaborator


Yes, it is 100% super useful. These error messages saved me a few times for multiple-registration errors!
And for python users, it should point to their code directly. Which is even better so they know which function this is!


// List of tokens that need to be invalidated when this KernelFunction is
// destroyed
mutable std::vector<std::weak_ptr<KernelToken>> tokens_;
Collaborator


Why mutable?

Collaborator


Also why weak_ptr and not shared_ptr?

mikaylagawarecki (Contributor Author) commented Aug 5, 2025


why mutable

I think this is necessary in order to make registerToken const, which was in turn needed to allow SafeKernelFunction to take a const KernelFunction*. Removing this would necessitate const_cast-ing annotatedKernel.kernel in getComputedKernelForDispatchKey, wdyt?

why weak_ptr and not shared_ptr

What's the benefit of shared_ptr over weak_ptr here? If we use weak_ptr, the KernelToken dies with the SafeKernelFunction, which I think is what we want to achieve here (?)
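For illustration only, a small Python weakref analogy of the ownership being discussed (not PyTorch code): the SafeKernelFunction owns its token, the KernelFunction keeps only a weak reference, so the token's lifetime follows the SafeKernelFunction and deregistration just flips a validity flag.

```python
import weakref

class KernelToken:
    # Stands in for the C++ KernelToken: owned by a SafeKernelFunction.
    def __init__(self):
        self.valid = True

class KernelFunction:
    # Mirrors tokens_ above: only weak references, so this side never
    # extends a token's lifetime.
    def __init__(self):
        self._tokens = []

    def register_token(self, token):
        self._tokens.append(weakref.ref(token))

    def invalidate_tokens(self):
        # Analogue of what happens when the C++ KernelFunction is destroyed.
        for ref in self._tokens:
            token = ref()
            if token is not None:
                token.valid = False

class SafeKernelFunction:
    def __init__(self, kernel):
        self._token = KernelToken()  # strong reference: dies with this object
        kernel.register_token(self._token)

    def call_boxed(self, *args):
        if not self._token.valid:
            raise RuntimeError("kernel has been deregistered")
        # ...would forward to the real kernel here
```

With shared_ptr instead of weak_ptr, the kernels_ side would keep tokens alive even after every SafeKernelFunction referencing them is gone, which is exactly the lifetime the weak_ptr choice avoids.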

mikaylagawarecki added a commit that referenced this pull request Aug 5, 2025
ghstack-source-id: 1c15180
Pull Request resolved: #158393
mikaylagawarecki requested a review from albanD on August 5, 2025 19:32
mikaylagawarecki added a commit that referenced this pull request Aug 5, 2025
ghstack-source-id: ec2351a
Pull Request resolved: #158393
mikaylagawarecki added a commit that referenced this pull request Aug 5, 2025
ghstack-source-id: 72c37fb
Pull Request resolved: #158393
mikaylagawarecki added a commit that referenced this pull request Aug 5, 2025
ghstack-source-id: 8f11139
Pull Request resolved: #158393
torch/library.py Outdated
>>> kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU")
>>>
>>> # Use the kernel to call the operator
>>> op_handle = torch._C._dispatch_find_schema_or_throw("aten::add", "Tensor")
Collaborator


Can we store the handle on the object returned by get_kernel so the user doesn't have to do this by any chance?

Contributor Author


Ok, I did that and made call_boxed access this within its implementation

torch/library.py Outdated
>>> op_handle = torch._C._dispatch_find_schema_or_throw("aten::add", "Tensor")
>>> a = torch.tensor([1.0, 2.0])
>>> b = torch.tensor([3.0, 4.0])
>>> result = kernel.call_boxed(op_handle, torch._C.DispatchKeySet("CPU"), a, b)
Collaborator


Why do we have a keyset here? Isn't this supposed to be just a kernel? We're not really going through full dispatch here, right?

mikaylagawarecki (Contributor Author) commented Aug 6, 2025


We're not really going through full dispatch here, right?

That was my expectation, unless I'm missing something about boxed functions. I'm unsure what the right thing to do here is, though, as it seems callBoxed and call on KernelFunction always require the DispatchKeySet.

It seems this is the calling convention (?) Is it possible that the KernelFunction we're getting here would still have to redispatch (?)
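For context, a hypothetical usage sketch assuming the change discussed above, where get_kernel stores the op handle and call_boxed looks it up internally; whether the explicit DispatchKeySet stays in the signature is still open in this thread:

```python
import torch

kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU")
a = torch.tensor([1.0, 2.0])
b = torch.tensor([3.0, 4.0])
# No explicit op handle anymore (hypothetical signature); the keyset is kept
# here only because KernelFunction::callBoxed takes one on the C++ side.
result = kernel.call_boxed(torch._C.DispatchKeySet("CPU"), a, b)
```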

mikaylagawarecki added a commit that referenced this pull request Aug 6, 2025
ghstack-source-id: 502dccd
Pull Request resolved: #158393
mikaylagawarecki added a commit that referenced this pull request Aug 8, 2025
ghstack-source-id: a752cca
Pull Request resolved: #158393
mikaylagawarecki requested a review from albanD on August 8, 2025 17:59
mikaylagawarecki added a commit that referenced this pull request Aug 8, 2025
ghstack-source-id: a4e17da
Pull Request resolved: #158393
Labels
release notes: python_frontend, topic: new features