diff --git a/docarray/typing/tensor/torch_tensor.py b/docarray/typing/tensor/torch_tensor.py index 32d1cbff7b..dc5b0587c0 100644 --- a/docarray/typing/tensor/torch_tensor.py +++ b/docarray/typing/tensor/torch_tensor.py @@ -130,7 +130,7 @@ def validate( field: 'ModelField', config: 'BaseConfig', ) -> T: - if isinstance(value, TorchTensor): + if isinstance(value, cls): return cast(T, value) elif isinstance(value, torch.Tensor): return cls._docarray_from_native(value) @@ -195,10 +195,7 @@ def _docarray_from_native(cls: Type[T], value: torch.Tensor) -> T: :param value: the native `torch.Tensor` :return: a `TorchTensor` """ - if cls.__unparametrizedcls__: # This is not None if the tensor is parametrized - value.__class__ = cls.__unparametrizedcls__ # type: ignore - else: - value.__class__ = cls + value.__class__ = cls return cast(T, value) @classmethod @@ -254,11 +251,11 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): # this tells torch to treat all of our custom tensors just like # torch.Tensor's. Otherwise, torch will complain that it doesn't # know how to handle our custom tensor type. - docarray_torch_tensors = TorchTensor.__subclasses__() + docarray_torch_tensors = cls.__subclasses__() + [cls] types_ = tuple( torch.Tensor if t in docarray_torch_tensors else t for t in types ) - return super().__torch_function__(func, types_, args, kwargs) + return torch.Tensor.__torch_function__(func, types_, args, kwargs) @classmethod def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T: