🚀 Feature
Note, this is one of the features discussed in gh-22402; see that issue for full context.

Implement a `__torch_function__` method on `Tensor` (public) and a `torch_function_dispatch` decorator (private). The method contains a dispatch mechanism that can be used to dispatch to `Tensor`-like objects, which can then handle the `torch.somefunction` call the way they see fit. That is, torch functions that take at least one tensor input parameter become overridable.

It should be possible to do this with no additional overhead when the input is a `Tensor`, and with sub-microsecond overhead when it is `Tensor`-like. The mechanism is completely analogous to NumPy's `__array_ufunc__` and `__array_function__` methods.
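
For illustration, here is a minimal sketch of how the private `torch_function_dispatch` decorator could work, modeled on NumPy's `array_function_dispatch`. The dispatcher convention, the `__torch_function__` call signature, and the `dot`/`_dot_dispatcher` names below are assumptions for the sake of the example, not a settled API:

```python
import functools

import torch

def torch_function_dispatch(dispatcher):
    """Wrap a torch function so Tensor-like arguments can override it (sketch)."""
    def decorator(public_api):
        @functools.wraps(public_api)
        def wrapper(*args, **kwargs):
            # The dispatcher extracts the arguments relevant for dispatch
            # (typically the tensor-like positional arguments).
            for arg in dispatcher(*args, **kwargs):
                # Plain Tensors take the normal path; anything else that
                # defines __torch_function__ gets a chance to handle the call.
                if not isinstance(arg, torch.Tensor) and hasattr(arg, '__torch_function__'):
                    result = arg.__torch_function__(public_api, args, kwargs)
                    if result is not NotImplemented:
                        return result
            return public_api(*args, **kwargs)
        return wrapper
    return decorator


# Hypothetical usage on a torch-level function:
def _dot_dispatcher(input, other):
    return (input, other)

@torch_function_dispatch(_dot_dispatcher)
def dot(input, other):
    return torch.dot(input, other)
```

When all relevant arguments are plain `Tensor`s, the wrapper falls straight through to the original function, which is what makes the zero-overhead goal plausible once the check is moved to C++.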
Motivation
torch functions need to become overridable. One concrete user that @ezyang mentioned is NestedTensor, and there will be others. See also gh-22402.
Plan
First build a prototype and apply it to only a couple of functions in the main `torch` namespace (to review/evaluate). Make sure they have different signatures from each other, e.g. `max`, `dot`, `svd`.
Use a toy `Tensor` ducktype class, along the lines of DiagonalArray in https://numpy.org/devdocs/user/basics.dispatch.html, to implement a working example with the functions that are overridden.
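
As a concrete (hypothetical) ducktype, the sketch below follows the DiagonalArray pattern: a `ScalarTensor` class that registers overrides in a `HANDLED_FUNCTIONS` table and answers `__torch_function__` for the functions it knows about. The class, the `implements()` helper, and the protocol signature are illustrative assumptions consistent with the sketch above:

```python
import torch

HANDLED_FUNCTIONS = {}

class ScalarTensor:
    """Behaves like value * torch.eye(n) without materializing the matrix."""

    def __init__(self, n, value):
        self._n = n
        self._value = value

    def tensor(self):
        # Fall back to a dense Tensor when needed.
        return self._value * torch.eye(self._n)

    def __torch_function__(self, func, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


def implements(torch_function):
    """Register a ScalarTensor override for the given torch function."""
    def decorator(impl):
        HANDLED_FUNCTIONS[torch_function] = impl
        return impl
    return decorator


@implements(torch.max)
def scalar_tensor_max(t):
    # The largest entry of value * eye(n) is value, or 0.0 when the diagonal
    # value is negative and n > 1 (because of the off-diagonal zeros).
    return max(t._value, 0.0) if t._n > 1 else t._value
```

With dispatch in place, `torch.max(ScalarTensor(5, 2.0))` would route to `scalar_tensor_max` instead of failing on the non-Tensor input.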
- First make it work (parts can be in Python) with the couple of functions chosen.
- Then make it fast: reuse the existing checks to ensure zero overhead on `func(Tensor, ...)`. This needs to be entirely in C++; see the sketch after this list.
- Once that's good, expand coverage to the whole API.
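
As a rough illustration of that fast path (written in Python only; as noted above, the real check needs to live in C++ so `func(Tensor, ...)` keeps zero additional overhead), the override scan would only run when some argument is a non-Tensor that defines `__torch_function__`. The helper name here is hypothetical:

```python
import torch

def needs_torch_function_dispatch(relevant_args):
    """True only if some argument is a non-Tensor that defines __torch_function__."""
    return any(
        not isinstance(a, torch.Tensor) and hasattr(a, '__torch_function__')
        for a in relevant_args
    )

# Inside the wrapper from the earlier sketch, the override loop would run
# only when needs_torch_function_dispatch(...) is True; otherwise the call
# goes straight to the existing implementation.
```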