Commit

Fix #203
justinchuby authored Sep 18, 2024
1 parent 8bab0f3 · commit a9aba9e
Showing 2 changed files with 81 additions and 9 deletions.
18 changes: 13 additions & 5 deletions src/torch_onnx/_core.py
@@ -63,6 +63,9 @@
     torch.int64: ir.DataType.INT64,
     torch.int8: ir.DataType.INT8,
     torch.uint8: ir.DataType.UINT8,
+    torch.uint16: ir.DataType.UINT16,
+    torch.uint32: ir.DataType.UINT32,
+    torch.uint64: ir.DataType.UINT64,
 }
 _BLUE = "\033[96m"
 _END = "\033[0m"
@@ -97,20 +100,25 @@ def __init__(self, tensor: torch.Tensor, name: str | None = None):
             tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype), name=name
         )

-    def __array__(self, dtype: Any = None) -> np.ndarray:
+    # numpy() calls __array__ in ir.Tensor
+    def numpy(self) -> np.ndarray:
         self.raw: torch.Tensor
         if self.dtype == ir.DataType.BFLOAT16:
-            return self.raw.view(torch.uint16).numpy(force=True).__array__(dtype)
+            return self.raw.view(torch.uint16).numpy(force=True)
         if self.dtype in {
             ir.DataType.FLOAT8E4M3FN,
             ir.DataType.FLOAT8E4M3FNUZ,
             ir.DataType.FLOAT8E5M2,
             ir.DataType.FLOAT8E5M2FNUZ,
         }:
             # TODO: Use ml_dtypes
-            return self.raw.view(torch.uint8).numpy(force=True).__array__(dtype)
-        return self.raw.numpy(force=True).__array__(dtype)
+            return self.raw.view(torch.uint8).numpy(force=True)
+        return self.raw.numpy(force=True)
+
+    def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray:
+        del copy  # Unused, but needed for the signature
+        if dtype is None:
+            return self.numpy()
+        return self.numpy().__array__(dtype)

     def tobytes(self) -> bytes:
         # Implement tobytes to support native PyTorch types so we can use types like bloat16
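Note: a minimal usage sketch (not part of the commit) of the behavior this change produces, assuming torch >= 2.3 (which introduced torch.uint16) and this repo's torch_onnx on the path. bfloat16 has no native NumPy dtype, so numpy() reinterprets the bits as uint16, and the new __array__ signature tolerates the copy= keyword that NumPy 2 may pass:

import numpy as np
import torch
from torch_onnx import _core

t = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
tt = _core.TorchTensor(t)

arr = np.array(tt)  # NumPy dispatches to TorchTensor.__array__
print(arr.dtype)    # uint16: same 16-bit payload as the bfloat16 data
print(tt.numpy().tobytes() == tt.tobytes())  # True: byte-identical views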
72 changes: 68 additions & 4 deletions src/torch_onnx/_core_test.py
@@ -1,10 +1,13 @@
-# ruff: noqa: UP037
-from __future__ import annotations
+# Owner(s): ["module: onnx"]
+"""Unit tests for the _core module."""

-import unittest
+from __future__ import annotations

+import numpy as np
 import torch
-from onnxscript import FLOAT
+from torch.testing._internal import common_utils

 import torch_onnx
 from torch_onnx import _core
@@ -27,7 +30,68 @@
 u64 = torch.uint64


-class ExportedProgramToIrTest(unittest.TestCase):
+@common_utils.instantiate_parametrized_tests
+class TorchTensorTest(common_utils.TestCase):
+    @common_utils.parametrize(
+        "dtype, np_dtype",
+        [
+            (torch.bfloat16, np.uint16),
+            (torch.bool, np.bool_),
+            (torch.complex128, np.complex128),
+            (torch.complex64, np.complex64),
+            (torch.float16, np.float16),
+            (torch.float32, np.float32),
+            (torch.float64, np.float64),
+            (torch.float8_e4m3fn, np.uint8),
+            (torch.float8_e4m3fnuz, np.uint8),
+            (torch.float8_e5m2, np.uint8),
+            (torch.float8_e5m2fnuz, np.uint8),
+            (torch.int16, np.int16),
+            (torch.int32, np.int32),
+            (torch.int64, np.int64),
+            (torch.int8, np.int8),
+            (torch.uint16, np.uint16),
+            (torch.uint32, np.uint32),
+            (torch.uint64, np.uint64),
+            (torch.uint8, np.uint8),
+        ],
+    )
+    def test_numpy_returns_correct_dtype(self, dtype: torch.dtype, np_dtype):
+        tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype))
+        self.assertEqual(tensor.numpy().dtype, np_dtype)
+        self.assertEqual(tensor.__array__().dtype, np_dtype)
+        self.assertEqual(np.array(tensor).dtype, np_dtype)
+
+    @common_utils.parametrize(
+        "dtype",
+        [
+            (torch.bfloat16),
+            (torch.bool),
+            (torch.complex128),
+            (torch.complex64),
+            (torch.float16),
+            (torch.float32),
+            (torch.float64),
+            (torch.float8_e4m3fn),
+            (torch.float8_e4m3fnuz),
+            (torch.float8_e5m2),
+            (torch.float8_e5m2fnuz),
+            (torch.int16),
+            (torch.int32),
+            (torch.int64),
+            (torch.int8),
+            (torch.uint16),
+            (torch.uint32),
+            (torch.uint64),
+            (torch.uint8),
+        ],
+    )
+    def test_tobytes(self, dtype: torch.dtype):
+        tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype))
+        self.assertEqual(tensor.tobytes(), tensor.numpy().tobytes())
+
+
+class ExportedProgramToIrTest(common_utils.TestCase):
     def test_output_metadata_with_tuple_outputs(self):
         class GraphModule(torch.nn.Module):
             def forward(
@@ -85,4 +149,4 @@ def forward(self, arg0_1: "f32[3, 5, 5]"):


 if __name__ == "__main__":
-    unittest.main()
+    common_utils.run_tests()
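Note: the tests above use PyTorch's internal test harness instead of plain unittest. A stripped-down sketch of the same parametrize pattern (class and test names here are illustrative; torch.testing._internal.common_utils ships with PyTorch but is not a public API):

import torch
from torch.testing._internal import common_utils


@common_utils.instantiate_parametrized_tests
class ExampleTest(common_utils.TestCase):
    # The class decorator expands each @parametrize-d method into one
    # concrete test per parameter value, so failures name the exact dtype.
    @common_utils.parametrize("dtype", [torch.float16, torch.float32])
    def test_dtype_roundtrip(self, dtype: torch.dtype):
        t = torch.tensor([1], dtype=dtype)
        self.assertEqual(t.dtype, dtype)


if __name__ == "__main__":
    common_utils.run_tests()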
