Add utility to get computed kernel in torch.library #158393
base: gh/mikaylagawarecki/320/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158393
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit b224df3 with merge base 556e2a7.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Sounds pretty good!
Only small questions!
torch/library.py
Outdated
op = op._name

if isinstance(dispatch_key, str):
    dispatch_key = torch._C.DispatchKey.__members__[dispatch_key]
What is the error you get when passing a wrong dispatch key here?
Oops, added proper error handling here.
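For reference, a minimal sketch of what that validation could look like; the helper name and the exact error message here are illustrative assumptions, not necessarily what landed in the PR:

```python
import torch

def _to_dispatch_key(dispatch_key):
    # Accept either a torch._C.DispatchKey member or its string name,
    # and fail loudly on unknown strings instead of a bare KeyError.
    if isinstance(dispatch_key, str):
        members = torch._C.DispatchKey.__members__
        if dispatch_key not in members:
            raise ValueError(
                f"Unknown dispatch key '{dispatch_key}'; expected one of "
                f"{', '.join(members)}"
            )
        dispatch_key = members[dispatch_key]
    return dispatch_key
```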
auto [annotatedKernel, _] =
    computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k);

return SafeKernelFunction(&annotatedKernel.kernel);
I think it would be nice to grab the debug string here and add that to the `__repr__` we get from Python?
This gives something like
`SafeKernelFunction(debug='registered at /data/users/mg1998/pytorch/build/aten/src/ATen/RegisterCompositeExplicitAutograd_0.cpp:2309')`
Do you think that is meaningful enough, or should we add more info?
Yes, it is 100% super useful. These error messages have saved me a few times with multiple-registration errors!
And for Python users, it should point to their code directly, which is even better, so they know which function this is!
// List of tokens that need to be invalidated when this KernelFunction is
// destroyed
mutable std::vector<std::weak_ptr<KernelToken>> tokens_;
Why mutable?
Also why weak_ptr and not shared_ptr?
> why mutable

I think this is necessary in order to make `registerToken` const, which was in turn needed to allow `SafeKernelFunction` to take in a `const KernelFunction*`. Removing this would necessitate const_cast-ing the `annotatedKernel.kernel` in `getComputedKernelForDispatchKey`, wdyt?

> why weak_ptr and not shared_ptr

What's the benefit of `shared_ptr` over `weak_ptr` here? If we use `weak_ptr`, the `KernelToken` dies with the `SafeKernelFunction`, which I think is what we want to achieve here (?)
torch/library.py
Outdated
>>> kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU")
>>>
>>> # Use the kernel to call the operator
>>> op_handle = torch._C._dispatch_find_schema_or_throw("aten::add", "Tensor")
Can we store the handle on the object returned by get_kernel so the user doesn't have to do this by any chance?
Ok, I did that and made `call_boxed` access this within its implementation.
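For context, a hedged sketch of what the docstring example could reduce to after that change; whether `call_boxed` still takes the DispatchKeySet explicitly is discussed below, so treat this signature as an assumption rather than the final API:

```python
import torch

# Assumed post-change usage: the operator handle is stored on the object
# returned by get_kernel, so the user no longer looks it up themselves.
kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU")
a = torch.tensor([1.0, 2.0])
b = torch.tensor([3.0, 4.0])
result = kernel.call_boxed(torch._C.DispatchKeySet("CPU"), a, b)
```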
torch/library.py
Outdated
>>> op_handle = torch._C._dispatch_find_schema_or_throw("aten::add", "Tensor")
>>> a = torch.tensor([1.0, 2.0])
>>> b = torch.tensor([3.0, 4.0])
>>> result = kernel.call_boxed(op_handle, torch._C.DispatchKeySet("CPU"), a, b)
Why do we have a keyset here? Isn't this supposed to be just a kernel? We're not really going through full dispatch here right?
> We're not really going through full dispatch here right

That was my expectation, unless I'm missing something about boxed functions. I'm unsure what the right thing to do here is, though, as it seems `callBoxed` and `call` on `KernelFunction` always require the DispatchKeySet. It seems this is the calling convention (?) Is it possible that the KernelFunction we're getting here would still have to redispatch (?)
Adds `OperatorEntry::getComputedKernelForDispatchKey`, which returns the KernelFunction corresponding to `OperatorEntry.dispatchTable_[dispatch_ix]` for a given dispatch key.

- Specifically, it returns a `SafeKernelFunction` that holds a `KernelToken`. This `KernelToken` is registered to the `KernelFunction` in `OperatorEntry.kernels_` and will be invalidated when that `KernelFunction` is destructed (i.e. when the `AnnotatedKernel` that holds it is removed from `kernels_`, which happens when the corresponding impl is deregistered).
- `SafeKernelFunction` can be called via `callBoxed`; the validity of the token is checked before this happens.
- `SafeKernelFunction` is pybinded and `getComputedKernelForDispatchKey` is exposed to the frontend via `torch.library.get_kernel`.
Related to #155330
Stack from ghstack (oldest at bottom):
cc @albanD
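To round out the description, a hedged end-to-end sketch of how the utility and the token invalidation fit together; the custom op defined via `torch.library._scoped_library` and the exact behavior after deregistration are illustrative assumptions, not verbatim from this PR:

```python
import torch

# Register a throwaway CPU impl so there is something to deregister later.
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
    lib.define("my_add(Tensor a, Tensor b) -> Tensor")
    lib.impl("my_add", lambda a, b: a + b, "CPU")

    # Fetch the computed CPU kernel; per the discussion above, its repr
    # carries the debug string saying where the kernel was registered.
    kernel = torch.library.get_kernel(torch.ops.mylib.my_add.default, "CPU")
    print(kernel)

# Exiting the scope deregisters the impl, destroying the underlying
# KernelFunction and invalidating the token held by `kernel`; calling
# kernel.call_boxed(...) afterwards should raise rather than crash.
```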