
Implement __torch_function__ to let Tensor-like objects override torch functions #24015

@rgommers


🚀 Feature

Note: this is one of the features discussed in gh-22402; see that issue for full context.

Implement a __torch_function__ method on Tensor (public) and a torch_function_dispatch decorator (private). The method contains a dispatch mechanism that can hand a call off to Tensor-like objects, which can then handle the torch.somefunction call however they see fit. In other words, every torch function that takes at least one tensor input parameter becomes overridable.
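A minimal pure-Python sketch of how such a decorator could work, assuming it follows NumPy's NEP 18 design. Only the names __torch_function__ and torch_function_dispatch come from this issue; the decorator body, the `Tensor` stand-in class (not torch.Tensor), the `TensorLike` duck type and the `__torch_function__(self, func, types, args, kwargs)` signature (copied from NumPy's `__array_function__`) are illustrative assumptions, not PyTorch's implementation.

```python
import functools


class Tensor:
    """Minimal stand-in for torch.Tensor (hypothetical, illustration only)."""

    def __init__(self, data):
        self.data = list(data)


def torch_function_dispatch(dispatcher):
    """Sketch of a dispatch decorator modeled on NumPy's NEP 18 machinery.

    `dispatcher` takes the same arguments as the decorated function and
    returns the arguments that should be checked for __torch_function__.
    """

    def decorator(implementation):
        @functools.wraps(implementation)
        def public_api(*args, **kwargs):
            # Keep the first argument of each type that defines
            # __torch_function__; plain Tensors take the default path.
            overloaded, seen = [], set()
            for arg in dispatcher(*args, **kwargs):
                arg_type = type(arg)
                if arg_type is Tensor or arg_type in seen:
                    continue
                if hasattr(arg_type, "__torch_function__"):
                    seen.add(arg_type)
                    overloaded.append(arg)
            if not overloaded:
                return implementation(*args, **kwargs)
            types = tuple(type(arg) for arg in overloaded)
            # Give each Tensor-like argument a chance to handle the call.
            for arg in overloaded:
                result = arg.__torch_function__(public_api, types, args, kwargs)
                if result is not NotImplemented:
                    return result
            raise TypeError(
                f"no implementation of {public_api.__name__} for types {types}"
            )

        return public_api

    return decorator


def _dot_dispatcher(a, b):
    # Both operands may be Tensor-like, so both are relevant.
    return (a, b)


@torch_function_dispatch(_dot_dispatcher)
def dot(a, b):
    """Default implementation, used when all inputs are plain Tensors."""
    return sum(x * y for x, y in zip(a.data, b.data))


class TensorLike:
    """A duck-typed Tensor that overrides `dot` and nothing else."""

    def __init__(self, data):
        self.data = list(data)

    def __torch_function__(self, func, types, args, kwargs):
        if func is dot:
            return ("handled by TensorLike", func.__name__)
        return NotImplemented


if __name__ == "__main__":
    print(dot(Tensor([1, 2, 3]), Tensor([4, 5, 6])))  # 32 (default path)
    print(dot(Tensor([1, 2]), TensorLike([3, 4])))     # ('handled by TensorLike', 'dot')
```

As in NEP 18, `func` is the public function object itself, so an override can compare it by identity, and if every Tensor-like argument returns NotImplemented the call fails with TypeError rather than silently falling back.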

This should add no overhead when the inputs are Tensors, and sub-microsecond overhead when they are Tensor-like.

The mechanism is completely analogous to NumPy's __array_ufunc__ and __array_function__ methods.

Motivation

Functions in the torch namespace need to become overridable. One concrete user, mentioned by @ezyang, is NestedTensor, and there will be others. See also gh-22402.

Plan

First, build a prototype and apply it to only a few functions in the main torch namespace, for review and evaluation. Pick functions whose signatures differ from each other, e.g. max, dot and svd.

Use a toy Tensor duck-type class, along the lines of DiagonalArray in https://numpy.org/devdocs/user/basics.dispatch.html, to implement a working example that overrides the chosen functions.
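A sketch of such a toy class, following the DiagonalArray and HANDLED_FUNCTIONS registry pattern from the NumPy dispatch guide. Everything here is hypothetical: `overridable` is a deliberately simplified stand-in for torch_function_dispatch, `toy_max` and `toy_dot` stand in for torch.max and torch.dot with dense 2-D Python lists as the default representation, and `DiagonalTensor` is an illustrative duck type.

```python
import functools


def overridable(func):
    """Hypothetical, simplified stand-in for torch_function_dispatch.

    Each positional argument whose type defines __torch_function__ gets a
    chance to handle the call; otherwise the dense implementation runs.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        for arg in args:
            handler = getattr(type(arg), "__torch_function__", None)
            if handler is not None:
                result = handler(arg, wrapper, (type(arg),), args, kwargs)
                if result is not NotImplemented:
                    return result
        return func(*args, **kwargs)

    return wrapper


@overridable
def toy_max(x):
    """Dense fallback: x is a 2-D list of numbers."""
    return max(v for row in x for v in row)


@overridable
def toy_dot(a, b):
    """Dense fallback: matrix product of two 2-D lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


# Registry of overrides, as in the NumPy dispatch guide.
HANDLED_FUNCTIONS = {}


def implements(torch_function):
    """Register a DiagonalTensor override for an overridable function."""
    def decorator(impl):
        HANDLED_FUNCTIONS[torch_function] = impl
        return impl
    return decorator


class DiagonalTensor:
    """An N x N matrix with `value` on the diagonal and 0 elsewhere."""

    def __init__(self, n, value):
        self._n = n
        self._value = value

    def __repr__(self):
        return f"DiagonalTensor(N={self._n}, value={self._value})"

    def to_dense(self):
        return [[self._value if i == j else 0 for j in range(self._n)]
                for i in range(self._n)]

    def __torch_function__(self, func, types, args, kwargs):
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


@implements(toy_max)
def _diagonal_max(x):
    # Off-diagonal entries are 0, so they only matter when n > 1.
    return x._value if x._n == 1 else max(x._value, 0)


@implements(toy_dot)
def _diagonal_dot(a, b):
    if isinstance(a, DiagonalTensor) and isinstance(b, DiagonalTensor):
        if a._n != b._n:
            raise ValueError("size mismatch")
        # The product of two same-size diagonal matrices is diagonal.
        return DiagonalTensor(a._n, a._value * b._value)
    # Mixed inputs: densify, then call the undecorated default implementation.
    dense = [x.to_dense() if isinstance(x, DiagonalTensor) else x for x in (a, b)]
    return toy_dot.__wrapped__(*dense)
```

For example, `toy_dot(DiagonalTensor(3, 2), DiagonalTensor(3, 5))` stays in the compact representation and returns `DiagonalTensor(N=3, value=10)`, while `toy_dot(DiagonalTensor(2, 3), [[1, 2], [3, 4]])` falls back to the dense product `[[3, 6], [9, 12]]`.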

  • First make it work for the chosen functions (parts can be in Python).
  • Then make it fast: reuse the existing type checks to ensure zero overhead for func(Tensor, ...). This all needs to be in C++.
  • Once that's good, expand coverage to the whole API.
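The "make it fast" step hinges on rejecting plain Tensors with a cheap exact-type check before any __torch_function__ lookup. A pure-Python sketch of that idea, plus a rough way to measure what the wrapper costs; `Tensor`, `overridable`, `TensorLike` and `overhead_per_call_us` are hypothetical names, and in Python the wrapper call itself still adds overhead, which is why the plan moves this into C++. Timings depend entirely on the machine and are not results from this issue.

```python
import functools
import timeit


class Tensor:
    """Stand-in for torch.Tensor (hypothetical, illustration only)."""

    def __init__(self, data):
        self.data = list(data)


def overridable(implementation):
    """Hypothetical dispatch wrapper with an exact-type fast path."""

    @functools.wraps(implementation)
    def wrapper(*args, **kwargs):
        # Fast path: when every argument is exactly Tensor (not a subclass
        # or duck type), skip the __torch_function__ lookups entirely.
        if all(type(arg) is Tensor for arg in args):
            return implementation(*args, **kwargs)
        # Slow path: let Tensor-like arguments handle the call.
        for arg in args:
            handler = getattr(type(arg), "__torch_function__", None)
            if handler is not None:
                result = handler(arg, wrapper, (type(arg),), args, kwargs)
                if result is not NotImplemented:
                    return result
        return implementation(*args, **kwargs)

    return wrapper


@overridable
def dot(a, b):
    return sum(x * y for x, y in zip(a.data, b.data))


class TensorLike:
    """Duck type that claims every overridable call it sees."""

    def __init__(self, data):
        self.data = list(data)

    def __torch_function__(self, func, types, args, kwargs):
        return f"TensorLike handled {func.__name__}"


def overhead_per_call_us(number=100_000):
    """Estimate the extra cost (microseconds) of the wrapper per call."""
    a, b = Tensor([1.0, 2.0]), Tensor([3.0, 4.0])
    raw = dot.__wrapped__  # the undecorated implementation
    direct = timeit.timeit(lambda: raw(a, b), number=number)
    dispatched = timeit.timeit(lambda: dot(a, b), number=number)
    return (dispatched - direct) / number * 1e6


if __name__ == "__main__":
    print(f"fast-path overhead: {overhead_per_call_us():.3f} us/call")
```

The same measurement, run against the C++ version, is one way to check the zero-overhead and sub-microsecond targets stated above.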

Labels: feature, high priority, module: numpy, triaged
