Status: Open
Labels: module: __torch_function__, module: dynamic shapes, triaged
Description
Context
@soumith and I are investigating how to productionize torchdim. There are roughly two approaches:
- thread torchdim's new "Dim" object through C++
- re-implement torchdim using the __torch_function__ mechanism
The second is lower effort (adding a new type to PyTorch is really difficult), but it has slower eager-mode performance. We would also need to change how __torch_function__ works: currently, an object that implements __torch_function__ can only be passed in place of a Tensor argument to a torch.* operator. The rest of this issue assumes we take the second approach.
Thoughts? @zdevito @ezyang @albanD
Repro
import torch

class Dim:
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)

dim = Dim()
x = torch.randn(3)
torch.sum(x, dim)

gives:
TypeError: sum() received an invalid combination of arguments - got (Tensor, Dim), but expected one of
* (Tensor input, *, torch.dtype dtype)
* (Tensor input, tuple of ints dim, bool keepdim, *, torch.dtype dtype, Tensor out)
* (Tensor input, tuple of names dim, bool keepdim, *, torch.dtype dtype, Tensor out)
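A simplified pure-Python model of why the repro fails may help. It assumes, as described above, that __torch_function__ overrides are only looked up in Tensor-typed parameter slots. ToySchema, FakeTensor and toy_dispatch are hypothetical names for illustration; PyTorch's real argument parsing happens in C++ and is considerably more involved.

```python
from dataclasses import dataclass
from typing import Any, Sequence, Tuple


@dataclass(frozen=True)
class ToySchema:
    """Hypothetical schema: one type tag per positional parameter."""
    name: str
    param_types: Tuple[str, ...]  # e.g. ("Tensor", "int[]")


class FakeTensor:
    """Stand-in for torch.Tensor so this sketch runs without PyTorch."""


class Dim:
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return f"{func} handled by Dim"


def has_override(arg: Any) -> bool:
    return hasattr(type(arg), "__torch_function__")


def toy_dispatch(schema: ToySchema, args: Sequence[Any]) -> str:
    # Step 1: look for overrides, but only in Tensor-typed slots (today's rule).
    overloaded = [
        arg for ptype, arg in zip(schema.param_types, args)
        if ptype == "Tensor" and has_override(arg)
    ]
    if overloaded:
        types = tuple(type(a) for a in overloaded)
        return type(overloaded[0]).__torch_function__(schema.name, types, tuple(args), {})
    # Step 2: no override found, so every argument must match its declared type.
    for ptype, arg in zip(schema.param_types, args):
        if ptype == "Tensor" and not isinstance(arg, FakeTensor):
            raise TypeError(f"{schema.name}() expected Tensor, got {type(arg).__name__}")
        if ptype == "int[]" and not isinstance(arg, (int, tuple)):
            raise TypeError(f"{schema.name}() expected ints, got {type(arg).__name__}")
    return f"{schema.name} ran natively"


SUM = ToySchema("sum", ("Tensor", "int[]"))
```

In this model, toy_dispatch(SUM, (FakeTensor(), Dim())) raises TypeError, mirroring the repro, because the Dim sits in a non-Tensor slot and is never checked for an override; the same Dim placed in the Tensor slot would be routed to Dim.__torch_function__.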
Pitch
- Objects implementing __torch_function__ could be marked as either "tensor-like" or "anything goes". "Tensor-like" objects would behave as they do today: they may only be passed in place of Tensor arguments to torch APIs. "Anything goes" objects may be passed in place of any argument.
- At the extreme, this amounts to adding a schema to every torch API. For the codegenned torch.* APIs this is not difficult (we already have typing information); the remaining torch.* ops would need manual intervention.
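A minimal sketch of how the proposed marking could affect override lookup, under stated assumptions: the class attribute torch_function_scope and the TENSOR_LIKE / ANYTHING_GOES values are hypothetical names, not an existing PyTorch API, and only positional arguments are scanned. Objects that do not opt in keep today's tensor-like behavior.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Tuple

# Hypothetical scope values; illustrative names, not an existing PyTorch API.
TENSOR_LIKE = "tensor_like"
ANYTHING_GOES = "anything_goes"


@dataclass(frozen=True)
class ToySchema:
    """Hypothetical schema: one type tag per positional parameter."""
    name: str
    param_types: Tuple[str, ...]


def override_scope(arg: Any) -> Optional[str]:
    """Return the (hypothetical) scope of arg's __torch_function__, or None."""
    cls = type(arg)
    if not hasattr(cls, "__torch_function__"):
        return None
    # Classes that don't opt in keep today's behavior: tensor-like.
    return getattr(cls, "torch_function_scope", TENSOR_LIKE)


def find_overloaded_args(schema: ToySchema, args: Sequence[Any]) -> List[Any]:
    """Collect overrides slot by slot; positional arguments only, for brevity."""
    found = []
    for ptype, arg in zip(schema.param_types, args):
        scope = override_scope(arg)
        if scope == ANYTHING_GOES or (scope == TENSOR_LIKE and ptype == "Tensor"):
            found.append(arg)
    return found


class TensorSubclassLike:
    # No scope attribute, so it is treated as tensor-like.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented


class Dim:
    torch_function_scope = ANYTHING_GOES  # hypothetical opt-in flag

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented


SUM = ToySchema("sum", ("Tensor", "int[]"))
```

In this toy, the per-slot loop over schema.param_types is where a schema becomes necessary: to honor "anything goes" objects, each op has to expose which argument occupies which position, which is why ops without typing information would presumably need the manual intervention mentioned above.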