diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 4e5aeb5..e807b02 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,4 +4,5 @@ Change Logs
 0.2.0
 +++++
 
-* :pr:`3`: fixes Array API with onnxruntime
+* :pr:`17`: implements the Array API
+* :pr:`3`: fixes Array API with onnxruntime and scikit-learn
diff --git a/_doc/api/array_api.rst b/_doc/api/array_api.rst
new file mode 100644
index 0000000..f07716a
--- /dev/null
+++ b/_doc/api/array_api.rst
@@ -0,0 +1,7 @@
+onnx_array_api.array_api
+========================
+
+.. toctree::
+
+    array_api_numpy
+    array_api_ort
diff --git a/_doc/api/array_api_numpy.rst b/_doc/api/array_api_numpy.rst
new file mode 100644
index 0000000..f57089a
--- /dev/null
+++ b/_doc/api/array_api_numpy.rst
@@ -0,0 +1,5 @@
+onnx_array_api.array_api.onnx_numpy
+===================================
+
+.. automodule:: onnx_array_api.array_api.onnx_numpy
+    :members:
diff --git a/_doc/api/array_api_ort.rst b/_doc/api/array_api_ort.rst
new file mode 100644
index 0000000..cc21311
--- /dev/null
+++ b/_doc/api/array_api_ort.rst
@@ -0,0 +1,5 @@
+onnx_array_api.array_api.onnx_ort
+=================================
+
+.. automodule:: onnx_array_api.array_api.onnx_ort
+    :members:
diff --git a/_doc/api/index.rst b/_doc/api/index.rst
index 7750a5b..75c0aa4 100644
--- a/_doc/api/index.rst
+++ b/_doc/api/index.rst
@@ -6,6 +6,7 @@ API
 .. toctree::
     :maxdepth: 1
 
+    array_api
     npx_functions
     npx_var
     npx_jit
diff --git a/_doc/api/npx_annot.rst b/_doc/api/npx_annot.rst
index d7e46e3..43de2d7 100644
--- a/_doc/api/npx_annot.rst
+++ b/_doc/api/npx_annot.rst
@@ -1,29 +1,54 @@
+=============
 npx.npx_types
 =============
 
+DType
+=====
+
+.. autoclass:: onnx_array_api.npx.npx_types.DType
+    :members:
+
 Annotations
-+++++++++++
+===========
+
+ElemType
+++++++++
 
 .. autoclass:: onnx_array_api.npx.npx_types.ElemType
     :members:
 
+ParType
++++++++
+
 .. autoclass:: onnx_array_api.npx.npx_types.ParType
     :members:
 
+OptParType
+++++++++++
+
 .. autoclass:: onnx_array_api.npx.npx_types.OptParType
     :members:
 
+TensorType
+++++++++++
+
 .. autoclass:: onnx_array_api.npx.npx_types.TensorType
     :members:
 
+SequenceType
+++++++++++++
+
 .. autoclass:: onnx_array_api.npx.npx_types.SequenceType
     :members:
 
+TupleType
++++++++++
+
 .. autoclass:: onnx_array_api.npx.npx_types.TupleType
     :members:
 
 Shortcuts
-+++++++++
+=========
 
 .. autoclass:: onnx_array_api.npx.npx_types.Bool
diff --git a/_unittests/test_array_api.sh b/_unittests/test_array_api.sh
new file mode 100644
index 0000000..b32ee41
--- /dev/null
+++ b/_unittests/test_array_api.sh
@@ -0,0 +1,2 @@
+export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
+pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_zeros
\ No newline at end of file
diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py
new file mode 100644
index 0000000..30e2ca2
--- /dev/null
+++ b/_unittests/ut_array_api/test_onnx_numpy.py
@@ -0,0 +1,20 @@
+import unittest
+import numpy as np
+from onnx_array_api.ext_test_case import ExtTestCase
+from onnx_array_api.array_api import onnx_numpy as xp
+from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
+
+
+class TestOnnxNumpy(ExtTestCase):
+    def test_abs(self):
+        c = EagerNumpyTensor(np.array([4, 5], dtype=np.int64))
+        mat = xp.zeros(c, dtype=xp.int64)
+        matnp = mat.numpy()
+        self.assertEqual(matnp.shape, (4, 5))
+        self.assertNotEmpty(matnp[0, 0])
+        a = xp.absolute(mat)
+        self.assertEqualArray(np.absolute(mat.numpy()), a.numpy())
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
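The new test drives `zeros` and `absolute` through the `onnx_numpy` namespace. A minimal sketch of the same eager flow outside the test harness (the values are illustrative, not taken from the PR):

```python
import numpy as np
from onnx_array_api.array_api import onnx_numpy as xp
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor

# the shape is itself a tensor, as in test_abs above
shape = EagerNumpyTensor(np.array([2, 3], dtype=np.int64))
mat = xp.zeros(shape, dtype=xp.float32)
print(mat.numpy().shape)           # expected: (2, 3)

# asarray accepts python scalars and wraps them into eager tensors
v = xp.asarray(-5)                 # int -> int64 tensor
print(xp.absolute(v).numpy())      # expected: 5
```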
diff --git a/_unittests/ut_npx/test_npx.py b/_unittests/ut_npx/test_npx.py
index f550896..c9ee35f 100644
--- a/_unittests/ut_npx/test_npx.py
+++ b/_unittests/ut_npx/test_npx.py
@@ -29,6 +29,7 @@
     npxapi_inline,
 )
 from onnx_array_api.npx.npx_functions import absolute as absolute_inline
+from onnx_array_api.npx.npx_functions import all as all_inline
 from onnx_array_api.npx.npx_functions import arange as arange_inline
 from onnx_array_api.npx.npx_functions import arccos as arccos_inline
 from onnx_array_api.npx.npx_functions import arccosh as arccosh_inline
@@ -50,6 +51,7 @@
 from onnx_array_api.npx.npx_functions import det as det_inline
 from onnx_array_api.npx.npx_functions import dot as dot_inline
 from onnx_array_api.npx.npx_functions import einsum as einsum_inline
+from onnx_array_api.npx.npx_functions import equal as equal_inline
 from onnx_array_api.npx.npx_functions import erf as erf_inline
 from onnx_array_api.npx.npx_functions import exp as exp_inline
 from onnx_array_api.npx.npx_functions import expand_dims as expand_dims_inline
@@ -95,6 +97,7 @@
 from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
 from onnx_array_api.npx.npx_types import (
     Bool,
+    DType,
     Float32,
     Float64,
     Int64,
@@ -127,18 +130,25 @@ def test_tensor(self):
         self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
         self.assertEmpty(dt.shape)
         self.assertEqual(dt.type_name(), "TensorType['float32']")
+
         dt = TensorType["float32"]
         self.assertEqual(len(dt.dtypes), 1)
         self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
         self.assertEqual(dt.type_name(), "TensorType['float32']")
+
         dt = TensorType[np.float32]
         self.assertEqual(len(dt.dtypes), 1)
         self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
         self.assertEqual(dt.type_name(), "TensorType['float32']")
         self.assertEmpty(dt.shape)
 
+        dt = TensorType[np.str_]
+        self.assertEqual(len(dt.dtypes), 1)
+        self.assertEqual(dt.dtypes[0].dtype, ElemType.str_)
+        self.assertEqual(dt.type_name(), "TensorType[strings]")
+        self.assertEmpty(dt.shape)
+
         self.assertRaise(lambda: TensorType[None], TypeError)
-        self.assertRaise(lambda: TensorType[np.str_], TypeError)
         self.assertRaise(lambda: TensorType[{np.float32, np.str_}], TypeError)
 
     def test_superset(self):
@@ -1155,6 +1165,16 @@ def test_astype(self):
         got = ref.run(None, {"A": x})
         self.assertEqualArray(z, got[0])
 
+    def test_astype_dtype(self):
+        f = absolute_inline(copy_inline(Input("A")).astype(DType(7)))
+        self.assertIsInstance(f, Var)
+        onx = f.to_onnx(constraints={"A": Float64[None]})
+        x = np.array([[-5.4, 6.6]], dtype=np.float64)
+        z = np.abs(x.astype(np.int64))
+        ref = ReferenceEvaluator(onx)
+        got = ref.run(None, {"A": x})
+        self.assertEqualArray(z, got[0])
+
     def test_astype_int(self):
         f = absolute_inline(copy_inline(Input("A")).astype(1))
         self.assertIsInstance(f, Var)
@@ -1413,6 +1433,9 @@ def test_einsum(self):
             lambda x, y: np.einsum(equation, x, y),
         )
 
+    def test_equal(self):
+        self.common_test_inline_bin(equal_inline, np.equal)
+
     @unittest.skipIf(scipy is None, reason="scipy is not installed.")
     def test_erf(self):
         self.common_test_inline(erf_inline, scipy.special.erf)
@@ -1460,7 +1483,17 @@ def test_hstack(self):
     def test_identity(self):
         f = identity_inline(2, dtype=np.float64)
         onx = f.to_onnx(constraints={(0, False): Float64[None]})
-        z = np.identity(2)
+        self.assertIn('name: "dtype"', str(onx))
+        z = np.identity(2).astype(np.float64)
+        ref = ReferenceEvaluator(onx)
+        got = ref.run(None, {})
+        self.assertEqualArray(z, got[0])
+
+    def test_identity_uint8(self):
+        f = identity_inline(2, dtype=np.uint8)
+        onx = f.to_onnx(constraints={(0, False): Float64[None]})
+        self.assertIn('name: "dtype"', str(onx))
+        z = np.identity(2).astype(np.uint8)
         ref = ReferenceEvaluator(onx)
         got = ref.run(None, {})
         self.assertEqualArray(z, got[0])
@@ -2318,7 +2351,7 @@ def compute_labels(X, centers):
         self.assertEqual(f.n_versions, 1)
         self.assertEqual(len(f.available_versions), 1)
         self.assertEqual(f.available_versions, [((np.float64, 2), (np.float64, 2))])
-        key = ((np.dtype("float64"), 2), (np.dtype("float64"), 2))
+        key = ((DType(TensorProto.DOUBLE), 2), (DType(TensorProto.DOUBLE), 2))
         onx = f.get_onnx(key)
         self.assertIsInstance(onx, ModelProto)
         self.assertRaise(lambda: f.get_onnx(2), ValueError)
@@ -2379,7 +2412,12 @@ def compute_labels(X, centers, use_sqrt=False):
         self.assertEqualArray(got[1], dist)
         self.assertEqual(f.n_versions, 1)
         self.assertEqual(len(f.available_versions), 1)
-        key = ((np.dtype("float64"), 2), (np.dtype("float64"), 2), "use_sqrt", True)
+        key = (
+            (DType(TensorProto.DOUBLE), 2),
+            (DType(TensorProto.DOUBLE), 2),
+            "use_sqrt",
+            True,
+        )
         self.assertEqual(f.available_versions, [key])
         onx = f.get_onnx(key)
         self.assertIsInstance(onx, ModelProto)
@@ -2452,7 +2490,52 @@ def test_take(self):
         got = ref.run(None, {"A": data, "B": indices})
         self.assertEqualArray(y, got[0])
 
+    def test_numpy_all(self):
+        data = np.array([[1, 0], [1, 1]]).astype(np.bool_)
+        y = np.all(data, axis=1)
+
+        f = all_inline(Input("A"), axis=1)
+        self.assertIsInstance(f, Var)
+        onx = f.to_onnx(constraints={"A": Bool[None]})
+        ref = ReferenceEvaluator(onx)
+        got = ref.run(None, {"A": data})
+        self.assertEqualArray(y, got[0])
+
+    def test_numpy_all_empty(self):
+        data = np.zeros((0,), dtype=np.bool_)
+        y = np.all(data)
+
+        f = all_inline(Input("A"))
+        self.assertIsInstance(f, Var)
+        onx = f.to_onnx(constraints={"A": Bool[None]})
+        ref = ReferenceEvaluator(onx)
+        got = ref.run(None, {"A": data})
+        self.assertEqualArray(y, got[0])
+
+    @unittest.skipIf(True, reason="ReduceMin does not support shape[axis] == 0")
+    def test_numpy_all_empty_axis_0(self):
+        data = np.zeros((0, 1), dtype=np.bool_)
+        y = np.all(data, axis=0)
+
+        f = all_inline(Input("A"), axis=0)
+        self.assertIsInstance(f, Var)
+        onx = f.to_onnx(constraints={"A": Bool[None]})
+        ref = ReferenceEvaluator(onx)
+        got = ref.run(None, {"A": data})
+        self.assertEqualArray(y, got[0])
+    def test_numpy_all_empty_axis_1(self):
+        data = np.zeros((0, 1), dtype=np.bool_)
+        y = np.all(data, axis=1)
+
+        f = all_inline(Input("A"), axis=1)
+        self.assertIsInstance(f, Var)
+        onx = f.to_onnx(constraints={"A": Bool[None]})
+        ref = ReferenceEvaluator(onx)
+        got = ref.run(None, {"A": data})
+        self.assertEqualArray(y, got[0])
+
 
 if __name__ == "__main__":
-    TestNpx().test_take()
+    # TestNpx().test_numpy_all_empty_axis_0()
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_npx/test_sklearn_array_api.py b/_unittests/ut_npx/test_sklearn_array_api.py
index 016a170..79120a9 100644
--- a/_unittests/ut_npx/test_sklearn_array_api.py
+++ b/_unittests/ut_npx/test_sklearn_array_api.py
@@ -4,7 +4,7 @@
 from onnx.defs import onnx_opset_version
 from sklearn import config_context, __version__ as sklearn_version
 from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
-from onnx_array_api.ext_test_case import ExtTestCase
+from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings
 from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
 
 
@@ -16,6 +16,7 @@ class TestSklearnArrayAPI(ExtTestCase):
         Version(sklearn_version) <= Version("1.2.2"),
         reason="reshape ArrayAPI not followed",
     )
+    @ignore_warnings(DeprecationWarning)
    def test_sklearn_array_api_linear_discriminant(self):
         X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
         y = np.array([1, 1, 1, 2, 2, 2])
@@ -26,6 +27,8 @@ def test_sklearn_array_api_linear_discriminant(self):
         new_x = EagerNumpyTensor(X)
         self.assertStartsWith("EagerNumpyTensor(array([[", repr(new_x))
         with config_context(array_api_dispatch=True):
+            # This fails with scikit-learn <= 1.2.2 because the Array API
+            # is not strictly followed there.
             got = ana.predict(new_x)
         self.assertEqualArray(expected, got.numpy())
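For reference, the scikit-learn path the test exercises looks roughly like this; `array_api_dispatch` makes the estimator call `__array_namespace__` on the input tensor (a sketch, assuming scikit-learn > 1.2.2):

```python
import numpy as np
from sklearn import config_context
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor

X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
y = np.array([1, 1, 1, 2, 2, 2])
ana = LinearDiscriminantAnalysis().fit(X, y)

with config_context(array_api_dispatch=True):
    # predict now runs on EagerNumpyTensor through the onnx Array API
    got = ana.predict(EagerNumpyTensor(X))
print(got.numpy())
```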
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index aa1a59b..defe983 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -36,7 +36,7 @@ jobs:
       python -m pip install . -v -v -v
     displayName: 'install wheel'
   - script: |
-      python -m pytest -v
+      python -m pytest
     displayName: 'Runs Unit Tests'
   - task: PublishPipelineArtifact@0
     inputs:
@@ -87,9 +87,56 @@ jobs:
       black --diff .
     displayName: 'Black'
   - script: |
-      python -m pytest -v
+      python -m pytest
     displayName: 'Runs Unit Tests'
 
+- job: 'TestLinuxArrayApi'
+  pool:
+    vmImage: 'ubuntu-latest'
+  strategy:
+    matrix:
+      Python311-Linux:
+        python.version: '3.11'
+    maxParallel: 3
+
+  steps:
+  - task: UsePythonVersion@0
+    inputs:
+      versionSpec: '$(python.version)'
+      architecture: 'x64'
+  - script: sudo apt-get update
+    displayName: 'AptGet Update'
+  - script: python -m pip install --upgrade pip setuptools wheel
+    displayName: 'Install tools'
+  - script: pip install -r requirements.txt
+    displayName: 'Install Requirements'
+  - script: python setup.py install
+    displayName: 'Install onnx_array_api'
+  - script: |
+      git clone https://github.com/data-apis/array-api-tests.git
+    displayName: 'clone array-api-tests'
+  - script: |
+      cd array-api-tests
+      git submodule update --init --recursive
+      cd ..
+    displayName: 'get submodules for array-api-tests'
+  - script: pip install -r array-api-tests/requirements.txt
+    displayName: 'Install Requirements dev'
+  - script: |
+      export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
+      cd array-api-tests
+    displayName: 'Set API'
+  - script: |
+      export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
+      cd array-api-tests
+      python -m pytest -x array_api_tests/test_creation_functions.py::test_zeros
+    displayName: "test_creation_functions.py::test_zeros"
+  #- script: |
+  #    export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
+  #    cd array-api-tests
+  #    python -m pytest -x array_api_tests
+  #  displayName: "all tests"
+
 - job: 'TestLinux'
   pool:
     vmImage: 'ubuntu-latest'
@@ -130,7 +177,7 @@ jobs:
       black --diff .
     displayName: 'Black'
   - script: |
-      python -m pytest -v
+      python -m pytest
     displayName: 'Runs Unit Tests'
   - script: |
       python -u setup.py bdist_wheel
@@ -166,7 +213,7 @@ jobs:
   - script: pip install onnxmltools --no-deps
     displayName: 'Install onnxmltools'
   - script: |
-      python -m pytest -v
+      python -m pytest
     displayName: 'Runs Unit Tests'
   - script: |
       python -u setup.py bdist_wheel
@@ -216,7 +263,7 @@ jobs:
   - script: pip install onnxmltools --no-deps
     displayName: 'Install onnxmltools'
   - script: |
-      python -m pytest -v -v
+      python -m pytest
     displayName: 'Runs Unit Tests'
   - script: |
       python -u setup.py bdist_wheel
diff --git a/onnx_array_api/array_api/__init__.py b/onnx_array_api/array_api/__init__.py
new file mode 100644
index 0000000..e13b184
--- /dev/null
+++ b/onnx_array_api/array_api/__init__.py
@@ -0,0 +1,19 @@
+from onnx import TensorProto
+from ..npx.npx_types import DType
+
+
+def _finalize_array_api(module):
+    module.float16 = DType(TensorProto.FLOAT16)
+    module.float32 = DType(TensorProto.FLOAT)
+    module.float64 = DType(TensorProto.DOUBLE)
+    module.int8 = DType(TensorProto.INT8)
+    module.int16 = DType(TensorProto.INT16)
+    module.int32 = DType(TensorProto.INT32)
+    module.int64 = DType(TensorProto.INT64)
+    module.uint8 = DType(TensorProto.UINT8)
+    module.uint16 = DType(TensorProto.UINT16)
+    module.uint32 = DType(TensorProto.UINT32)
+    module.uint64 = DType(TensorProto.UINT64)
+    module.bfloat16 = DType(TensorProto.BFLOAT16)
+    setattr(module, "bool", DType(TensorProto.BOOL))
+    setattr(module, "str", DType(TensorProto.STRING))
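Once `_finalize_array_api` has run on a module, the standard dtype members are plain `DType` instances; a quick check of what it attaches (a sketch):

```python
from onnx import TensorProto
from onnx_array_api.array_api import onnx_numpy as xp
from onnx_array_api.npx.npx_types import DType

assert xp.float32 == DType(TensorProto.FLOAT)   # code 1
assert xp.int64 == DType(TensorProto.INT64)     # code 7
print(xp.float64.np_dtype)                      # float64
```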
diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py
new file mode 100644
index 0000000..8d136c4
--- /dev/null
+++ b/onnx_array_api/array_api/_onnx_common.py
@@ -0,0 +1,50 @@
+from typing import Any, Optional
+import numpy as np
+from ..npx.npx_types import DType
+from ..npx.npx_array_api import BaseArrayApi
+from ..npx.npx_functions import (
+    copy as copy_inline,
+)
+
+
+def template_asarray(
+    TEagerTensor: type,
+    a: Any,
+    dtype: Optional[DType] = None,
+    order: Optional[str] = None,
+    like: Any = None,
+    copy: bool = False,
+) -> Any:
+    """
+    Converts anything into an array.
+    """
+    if order not in ("C", None):
+        raise NotImplementedError(f"asarray is not implemented for order={order!r}.")
+    if like is not None:
+        raise NotImplementedError(
+            f"asarray is not implemented for like != None (type={type(like)})."
+        )
+    if isinstance(a, BaseArrayApi):
+        if copy:
+            if dtype is None:
+                return copy_inline(a)
+            return copy_inline(a).astype(dtype=dtype)
+        if dtype is None:
+            return a
+        return a.astype(dtype=dtype)
+
+    # bool must be checked before int because bool is a subclass of int
+    if isinstance(a, bool):
+        v = TEagerTensor(np.array(a, dtype=np.bool_))
+    elif isinstance(a, int):
+        v = TEagerTensor(np.array(a, dtype=np.int64))
+    elif isinstance(a, float):
+        v = TEagerTensor(np.array(a, dtype=np.float32))
+    elif isinstance(a, str):
+        v = TEagerTensor(np.array(a, dtype=np.str_))
+    else:
+        raise RuntimeError(f"Unexpected type {type(a)} for the first input.")
+    if dtype is not None:
+        vt = v.astype(dtype)
+    else:
+        vt = v
+    return vt
diff --git a/onnx_array_api/array_api/onnx_numpy.py b/onnx_array_api/array_api/onnx_numpy.py
new file mode 100644
index 0000000..79b339d
--- /dev/null
+++ b/onnx_array_api/array_api/onnx_numpy.py
@@ -0,0 +1,70 @@
+"""
+Array API valid for an :class:`EagerNumpyTensor`.
+"""
+from typing import Any, Optional
+import numpy as np
+from onnx import TensorProto
+from ..npx.npx_functions import (
+    all,
+    abs,
+    absolute,
+    astype,
+    equal,
+    isdtype,
+    reshape,
+    take,
+)
+from ..npx.npx_functions import zeros as generic_zeros
+from ..npx.npx_numpy_tensors import EagerNumpyTensor
+from ..npx.npx_types import DType, ElemType, TensorType, OptParType
+from ._onnx_common import template_asarray
+from . import _finalize_array_api
+
+__all__ = [
+    "abs",
+    "absolute",
+    "all",
+    "asarray",
+    "astype",
+    "equal",
+    "isdtype",
+    "reshape",
+    "take",
+    "zeros",
+]
+
+
+def asarray(
+    a: Any,
+    dtype: Optional[DType] = None,
+    order: Optional[str] = None,
+    like: Any = None,
+    copy: bool = False,
+) -> EagerNumpyTensor:
+    """
+    Converts anything into an array.
+    """
+    return template_asarray(
+        EagerNumpyTensor, a, dtype=dtype, order=order, like=like, copy=copy
+    )
+
+
+def zeros(
+    shape: TensorType[ElemType.int64, "I", (None,)],
+    dtype: OptParType[DType] = DType(TensorProto.FLOAT),
+    order: OptParType[str] = "C",
+) -> TensorType[ElemType.numerics, "T"]:
+    if isinstance(shape, tuple):
+        return generic_zeros(
+            EagerNumpyTensor(np.array(shape, dtype=np.int64)), dtype=dtype, order=order
+        )
+    return generic_zeros(shape, dtype=dtype, order=order)
+
+
+def _finalize():
+    from . import onnx_numpy
+
+    _finalize_array_api(onnx_numpy)
+
+
+_finalize()
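`zeros` in `onnx_numpy` accepts either an eager tensor or a plain tuple for the shape; the tuple branch simply wraps it into an int64 tensor first. A short sketch:

```python
from onnx_array_api.array_api import onnx_numpy as xp

m = xp.zeros((2, 3), dtype=xp.float64)  # tuple wrapped into an int64 tensor
print(m.numpy())                        # 2x3 array of 0.0
```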
diff --git a/onnx_array_api/array_api/onnx_ort.py b/onnx_array_api/array_api/onnx_ort.py
new file mode 100644
index 0000000..505efdf
--- /dev/null
+++ b/onnx_array_api/array_api/onnx_ort.py
@@ -0,0 +1,54 @@
+"""
+Array API valid for an :class:`EagerOrtTensor`.
+"""
+from typing import Optional, Any
+from ..ort.ort_tensors import EagerOrtTensor
+from ..npx.npx_types import DType
+from ..npx.npx_functions import (
+    all,
+    abs,
+    absolute,
+    astype,
+    equal,
+    isdtype,
+    reshape,
+    take,
+)
+from ._onnx_common import template_asarray
+from . import _finalize_array_api
+
+__all__ = [
+    "all",
+    "abs",
+    "absolute",
+    "asarray",
+    "astype",
+    "equal",
+    "isdtype",
+    "reshape",
+    "take",
+]
+
+
+def asarray(
+    a: Any,
+    dtype: Optional[DType] = None,
+    order: Optional[str] = None,
+    like: Any = None,
+    copy: bool = False,
+) -> EagerOrtTensor:
+    """
+    Converts anything into an array.
+    """
+    return template_asarray(
+        EagerOrtTensor, a, dtype=dtype, order=order, like=like, copy=copy
+    )
+
+
+def _finalize():
+    from . import onnx_ort
+
+    _finalize_array_api(onnx_ort)
+
+
+_finalize()
diff --git a/onnx_array_api/npx/npx_array_api.py b/onnx_array_api/npx/npx_array_api.py
index d5b2096..58968ae 100644
--- a/onnx_array_api/npx/npx_array_api.py
+++ b/onnx_array_api/npx/npx_array_api.py
@@ -20,15 +20,9 @@ class BaseArrayApi:
 
     def __array_namespace__(self, api_version: Optional[str] = None):
         """
-        Returns the module holding all the available functions.
+        This method must be overloaded.
         """
-        if api_version is None or api_version == "2022.12":
-            from onnx_array_api.npx import npx_functions
-
-            return npx_functions
-        raise ValueError(
-            f"Unable to return an implementation for api_version={api_version!r}."
-        )
+        raise NotImplementedError("Method '__array_namespace__' must be implemented.")
 
     def generic_method(self, method_name, *args: Any, **kwargs: Any) -> Any:
         raise NotImplementedError(
@@ -134,7 +128,7 @@ def T(self) -> "BaseArrayApi":
         return self.generic_method("T")
 
     def astype(self, dtype: Any) -> "BaseArrayApi":
-        return self.generic_method("astype", dtype)
+        return self.generic_method("astype", dtype=dtype)
 
     @property
     def shape(self) -> "BaseArrayApi":
diff --git a/onnx_array_api/npx/npx_core_api.py b/onnx_array_api/npx/npx_core_api.py
index cc3802a..05cb0bb 100644
--- a/onnx_array_api/npx/npx_core_api.py
+++ b/onnx_array_api/npx/npx_core_api.py
@@ -5,7 +5,7 @@
 from onnx import FunctionProto, ModelProto, NodeProto
 
 from .npx_tensors import EagerTensor
-from .npx_types import ElemType, OptParType, ParType, TupleType
+from .npx_types import DType, ElemType, OptParType, ParType, TupleType
 from .npx_var import Cst, Input, ManyIdentity, Par, Var
 
 
@@ -74,7 +74,7 @@ def _process_parameter(fn, sig, k, v, new_pars, inline):
                 parent_op=(fn.__module__, fn.__name__, 0),
             )
             return
-    if isinstance(v, (int, float, str, tuple)):
+    if isinstance(v, (int, float, str, tuple, DType)):
         if inline:
             new_pars[k] = v
         else:
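With `BaseArrayApi.__array_namespace__` now abstract, each eager tensor class returns its own module, which is how `asarray` and friends resolve per backend. A sketch of the lookup:

```python
import numpy as np
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor

t = EagerNumpyTensor(np.array([1.0], dtype=np.float32))
xp = t.__array_namespace__()  # -> onnx_array_api.array_api.onnx_numpy
print(xp.__name__)
```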
+ """ + # size = var(x, op="Size") + # empty = var(size, cst(np.array(0, dtype=np.int64)), op="Equal") + + # z = make_tensor_value_info("Z", TensorProto.BOOL, [1]) + # g1 = make_graph([make_node("Constant", [], ["Z"], value_bool=[True])], [], [z]) + + xi = var(x, op="Cast", to=TensorProto.INT64) + + if axis is None: + new_shape = cst(np.array([-1], dtype=np.int64)) + xifl = var(xi, new_shape, op="Reshape") + # in case xifl is empty, we need to add one element + one = cst(np.array([1], dtype=np.int64)) + xifl1 = var(xifl, one, op="Concat", axis=0) + red = xifl1.min(keepdims=keepdims) + else: + if isinstance(axis, int): + axis = [axis] + if isinstance(axis, (tuple, list)): + axis = cst(np.array(axis, dtype=np.int64)) + red = xi.min(axis, keepdims=keepdims) + return var(red, cst(1), op="Equal") + + @npxapi_inline def arccos(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]: "See :func:`numpy.arccos`." @@ -159,30 +192,9 @@ def arctanh( return var(x, op="Atanh") -def asarray( - a: Any, - dtype: Any = None, - order: Optional[str] = None, - like: Any = None, - copy: bool = False, -): - """ - Converts anything into an array. - """ - if dtype is not None: - raise RuntimeError("Method 'astype' should be used to change the type.") - if order is not None: - raise NotImplementedError(f"order={order!r} not implemented.") - if isinstance(a, BaseArrayApi): - if copy: - return a.__class__(a, copy=copy) - return a - raise NotImplementedError(f"asarray not implemented for type {type(a)}.") - - @npxapi_inline def astype( - a: TensorType[ElemType.numerics, "T1"], dtype: OptParType[int] = 1 + a: TensorType[ElemType.numerics, "T1"], dtype: OptParType[DType] = 1 ) -> TensorType[ElemType.numerics, "T2"]: """ Cast an array. @@ -335,6 +347,14 @@ def einsum( return var(*x, op="Einsum", equation=equation) +@npxapi_inline +def equal( + x: TensorType[ElemType.allowed, "T"], y: TensorType[ElemType.allowed, "T"] +) -> TensorType[ElemType.bool_, "T1"]: + "See :func:`numpy.isnan`." + return var(x, y, op="Equal") + + @npxapi_inline def erf(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]: "See :func:`scipy.special.erf`." @@ -382,18 +402,20 @@ def hstack( @npxapi_inline -def copy(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]: +def copy(x: TensorType[ElemType.allowed, "T"]) -> TensorType[ElemType.allowed, "T"]: "Makes a copy." return var(x, op="Identity") @npxapi_inline -def identity(n: ParType[int], dtype=None) -> TensorType[ElemType.numerics, "T"]: +def identity( + n: ParType[int], dtype: OptParType[DType] = None +) -> TensorType[ElemType.numerics, "T"]: "Makes a copy." - val = np.array([n, n], dtype=np.int64) - shape = cst(val) model = var( - shape, op="ConstantOfShape", value=from_array(np.array([0], dtype=np.int64)) + cst(np.array([n, n], dtype=np.int64)), + op="ConstantOfShape", + value=from_array(np.array([0], dtype=np.int64)), ) v = var(model, dtype=dtype, op="EyeLike") return v @@ -401,17 +423,22 @@ def identity(n: ParType[int], dtype=None) -> TensorType[ElemType.numerics, "T"]: @npxapi_no_inline def isdtype( - dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]] + dtype: ParType[DType], kind: Union[DType, str, Tuple[Union[DType, str], ...]] ) -> bool: """ See :epkg:`BaseArrayAPI:isdtype`. This function is not converted into an onnx graph. 
""" + if isinstance(dtype, DType): + dti = tensor_dtype_to_np_dtype(dtype.code) + return np_array_api.isdtype(dti, kind) + if isinstance(dtype, int): + raise TypeError(f"Unexpected type {type(dtype)}.") return np_array_api.isdtype(dtype, kind) @npxapi_inline -def isnan(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.bool_, "T"]: +def isnan(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.bool_, "T1"]: "See :func:`numpy.isnan`." return var(x, op="IsNaN") @@ -625,3 +652,23 @@ def where( ) -> TensorType[ElemType.numerics, "T"]: "See :func:`numpy.where`." return var(cond, x, y, op="Where") + + +@npxapi_inline +def zeros( + shape: TensorType[ElemType.int64, "I", (None,)], + dtype: OptParType[DType] = DType(TensorProto.FLOAT), + order: OptParType[str] = "C", +) -> TensorType[ElemType.numerics, "T"]: + """ + Implements :func:`numpy.zeros`. + """ + if order != "C": + raise RuntimeError(f"order={order!r} != 'C' not supported.") + if dtype is None: + dtype = DType(TensorProto.FLOAT) + return var( + shape, + value=make_tensor(name="zero", data_type=dtype.code, dims=[1], vals=[0]), + op="ConstantOfShape", + ) diff --git a/onnx_array_api/npx/npx_graph_builder.py b/onnx_array_api/npx/npx_graph_builder.py index 92c1412..ec91b91 100644 --- a/onnx_array_api/npx/npx_graph_builder.py +++ b/onnx_array_api/npx/npx_graph_builder.py @@ -38,6 +38,7 @@ rename_in_onnx_graph, ) from .npx_types import ( + DType, ElemType, OptParType, ParType, @@ -226,6 +227,8 @@ def make_node( protos.append(att) elif v.value is not None: new_kwargs[k] = v.value + elif isinstance(v, DType): + new_kwargs[k] = v.code else: new_kwargs[k] = v @@ -337,7 +340,7 @@ def _io( if tensor_type.shape is None: type_proto = TypeProto() tensor_type_proto = type_proto.tensor_type - tensor_type_proto.elem_type = tensor_type.dtypes[0].dtype + tensor_type_proto.elem_type = tensor_type.dtypes[0].dtype.code value_info_proto = ValueInfoProto() value_info_proto.name = name # tensor_type_proto.shape.dim.extend([]) @@ -348,7 +351,7 @@ def _io( # with fixed rank. This can be changed here and in methods # `make_key`. 
 
 
 @npxapi_inline
-def isnan(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.bool_, "T"]:
+def isnan(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.bool_, "T1"]:
     "See :func:`numpy.isnan`."
     return var(x, op="IsNaN")
 
@@ -625,3 +652,23 @@ def where(
 ) -> TensorType[ElemType.numerics, "T"]:
     "See :func:`numpy.where`."
     return var(cond, x, y, op="Where")
+
+
+@npxapi_inline
+def zeros(
+    shape: TensorType[ElemType.int64, "I", (None,)],
+    dtype: OptParType[DType] = DType(TensorProto.FLOAT),
+    order: OptParType[str] = "C",
+) -> TensorType[ElemType.numerics, "T"]:
+    """
+    Implements :func:`numpy.zeros`.
+    """
+    if order != "C":
+        raise RuntimeError(f"order={order!r} != 'C' not supported.")
+    if dtype is None:
+        dtype = DType(TensorProto.FLOAT)
+    return var(
+        shape,
+        value=make_tensor(name="zero", data_type=dtype.code, dims=[1], vals=[0]),
+        op="ConstantOfShape",
+    )
diff --git a/onnx_array_api/npx/npx_graph_builder.py b/onnx_array_api/npx/npx_graph_builder.py
index 92c1412..ec91b91 100644
--- a/onnx_array_api/npx/npx_graph_builder.py
+++ b/onnx_array_api/npx/npx_graph_builder.py
@@ -38,6 +38,7 @@
     rename_in_onnx_graph,
 )
 from .npx_types import (
+    DType,
     ElemType,
     OptParType,
     ParType,
@@ -226,6 +227,8 @@ def make_node(
                 protos.append(att)
             elif v.value is not None:
                 new_kwargs[k] = v.value
+            elif isinstance(v, DType):
+                new_kwargs[k] = v.code
             else:
                 new_kwargs[k] = v
 
@@ -337,7 +340,7 @@ def _io(
         if tensor_type.shape is None:
             type_proto = TypeProto()
             tensor_type_proto = type_proto.tensor_type
-            tensor_type_proto.elem_type = tensor_type.dtypes[0].dtype
+            tensor_type_proto.elem_type = tensor_type.dtypes[0].dtype.code
             value_info_proto = ValueInfoProto()
             value_info_proto.name = name
             # tensor_type_proto.shape.dim.extend([])
@@ -348,7 +351,7 @@ def _io(
             # with fixed rank. This can be changed here and in methods
             # `make_key`.
             shape = [None for _ in tensor_type.shape]
-        info = make_tensor_value_info(name, tensor_type.dtypes[0].dtype, shape)
+        info = make_tensor_value_info(name, tensor_type.dtypes[0].dtype.code, shape)
         # check_value_info fails if the shape is left undefined
         check_value_info(info, self.check_context)
         return info
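In `make_node`, a `DType` keyword is unwrapped to its integer onnx code before the attribute is built, so downstream helpers see a plain int. The EyeLike node produced by `identity(2, dtype=np.float64)` effectively becomes (a sketch using onnx.helper directly):

```python
from onnx import TensorProto
from onnx.helper import make_node

# DType(TensorProto.DOUBLE).code == 11 is what ends up in the attribute
node = make_node("EyeLike", ["zeros_2x2"], ["eye"], dtype=TensorProto.DOUBLE)
print(node.attribute[0].i)  # 11
```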
diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py
index 6b6bfca..85b52d4 100644
--- a/onnx_array_api/npx/npx_jit_eager.py
+++ b/onnx_array_api/npx/npx_jit_eager.py
@@ -5,7 +5,7 @@
 import numpy as np
 
 from .npx_tensors import EagerTensor, JitTensor
-from .npx_types import TensorType
+from .npx_types import DType, TensorType
 from .npx_var import Cst, Input, Var
 
 logger = getLogger("onnx-array-api")
@@ -131,7 +131,7 @@ def make_key(*values, **kwargs):
         for iv, v in enumerate(values):
             if isinstance(v, (Var, EagerTensor, JitTensor)):
                 res.append(v.key)
-            elif isinstance(v, (int, float)):
+            elif isinstance(v, (int, float, DType)):
                 res.append(v)
             elif isinstance(v, slice):
                 res.append(("slice", v.start, v.stop, v.step))
@@ -153,7 +153,7 @@ def make_key(*values, **kwargs):
         )
         if kwargs:
             for k, v in sorted(kwargs.items()):
-                if isinstance(v, (int, float, str, type)):
+                if isinstance(v, (int, float, str, type, DType)):
                     res.append(k)
                     res.append(v)
                 elif isinstance(v, tuple):
@@ -168,6 +168,8 @@ def make_key(*values, **kwargs):
                         else:
                             newv.append(t)
                     res.append(tuple(newv))
+                elif v is None and k in {"dtype"}:
+                    continue
                 else:
                     raise TypeError(
                         f"Type {type(v)} is not yet supported, "
@@ -193,6 +195,12 @@ def to_jit(self, *values, **kwargs):
         constraints = {}
         new_kwargs = {}
         for i, (v, iname) in enumerate(zip(values, names)):
+            if i < len(annot_values) and not isinstance(annot_values[i], type):
+                raise TypeError(
+                    f"annotation {i} is not a type but is {annot_values[i]!r} "
+                    f"for function {self.f} "
+                    f"from module {self.f.__module__!r}."
+                )
             if isinstance(v, (EagerTensor, JitTensor)) and (
                 i >= len(annot_values) or issubclass(annot_values[i], TensorType)
             ):
@@ -250,7 +258,7 @@ def to_jit(self, *values, **kwargs):
             kwargs = new_kwargs
         else:
             kwargs = kwargs.copy()
-            kwargs.update(kwargs)
+            kwargs.update(new_kwargs)
 
         var = self.f(*inputs, **kwargs)
 
@@ -336,7 +344,13 @@ def jit_call(self, *values, **kwargs):
             self.info("+", "jit_call")
         if self.input_to_kwargs_ is None:
             # No jitting was ever called.
-            onx, fct = self.to_jit(*values, **kwargs)
+            try:
+                onx, fct = self.to_jit(*values, **kwargs)
+            except Exception as e:
+                raise RuntimeError(
+                    f"ERROR with self.f={self.f}, "
+                    f"values={values!r}, kwargs={kwargs!r}"
+                ) from e
             if self.input_to_kwargs_ is None:
                 raise RuntimeError(
                     f"Attribute 'input_to_kwargs_' should be set for "
@@ -520,6 +534,8 @@ def _preprocess_constants(self, *args):
             elif isinstance(n, (int, float)):
                 new_args.append(self.tensor_class(np.array(n)))
                 modified = True
+            elif isinstance(n, DType):
+                new_args.append(n)
             elif n in (int, float):
                 # usually used to cast
                 new_args.append(n)
@@ -554,7 +570,17 @@ def __call__(self, *args, already_eager=False, **kwargs):
                 lambda t: t is not None
                 and not isinstance(
                     t,
-                    (EagerTensor, Cst, int, float, tuple, slice, type, np.ndarray),
+                    (
+                        EagerTensor,
+                        Cst,
+                        int,
+                        float,
+                        tuple,
+                        slice,
+                        type,
+                        np.ndarray,
+                        DType,
+                    ),
                 ),
                 args,
             )
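`make_key` now keeps `DType` values in the JIT cache key and silently drops a `dtype=None` keyword, so both spellings reuse the same compiled graph. The key shape matches what the updated tests assert, e.g.:

```python
from onnx import TensorProto
from onnx_array_api.npx.npx_types import DType

# a two-argument float64 function over rank-2 tensors is cached under
key = ((DType(TensorProto.DOUBLE), 2), (DType(TensorProto.DOUBLE), 2))
```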
+ ) class JitNumpyTensor(NumpyTensor, JitTensor): diff --git a/onnx_array_api/npx/npx_numpy_tensors_ops.py b/onnx_array_api/npx/npx_numpy_tensors_ops.py new file mode 100644 index 0000000..5278019 --- /dev/null +++ b/onnx_array_api/npx/npx_numpy_tensors_ops.py @@ -0,0 +1,46 @@ +import numpy as np + +from onnx.reference.op_run import OpRun + + +class ConstantOfShape(OpRun): + @staticmethod + def _process(value): + cst = value[0] if isinstance(value, np.ndarray) else value + if isinstance(cst, int): + cst = np.int64(cst) + elif isinstance(cst, float): + cst = np.float64(cst) + elif cst is None: + cst = np.float32(0) + if not isinstance( + cst, + ( + np.float16, + np.float32, + np.float64, + np.int64, + np.int32, + np.int16, + np.int8, + np.uint64, + np.uint32, + np.uint16, + np.uint8, + np.bool_, + ), + ): + raise TypeError(f"value must be a real not {type(cst)}") + return cst + + def _run(self, data, value=None): + cst = self._process(value) + try: + res = np.full(tuple(data), cst) + except TypeError as e: + raise RuntimeError( + f"Unable to create a constant of shape " + f"{data!r} with value {cst!r} " + f"(raw value={value!r})." + ) from e + return (res,) diff --git a/onnx_array_api/npx/npx_tensors.py b/onnx_array_api/npx/npx_tensors.py index 136def5..e1e4b21 100644 --- a/onnx_array_api/npx/npx_tensors.py +++ b/onnx_array_api/npx/npx_tensors.py @@ -1,8 +1,9 @@ -from typing import Any +from typing import Any, Union import numpy as np from onnx.helper import np_dtype_to_tensor_dtype +from .npx_types import DType, ElemType, ParType, TensorType from .npx_array_api import BaseArrayApi, ArrayApiError @@ -73,7 +74,9 @@ def _getitem_impl_var(obj, index, method_name=None): return meth(obj, index) @staticmethod - def _astype_impl(x, dtype: int = None, method_name=None): + def _astype_impl( + x: TensorType[ElemType.allowed, "T1"], dtype: ParType[DType], method_name=None + ) -> TensorType[ElemType.allowed, "T2"]: # avoids circular imports. if dtype is None: raise ValueError("dtype cannot be None.") @@ -131,9 +134,11 @@ def _generic_method_operator(self, method_name, *args: Any, **kwargs: Any) -> An new_args = [] for a in args: if isinstance(a, np.ndarray): - new_args.append(self.__class__(a.astype(self.dtype))) + new_args.append(self.__class__(a.astype(self.dtype.np_dtype))) elif isinstance(a, (int, float)): - new_args.append(self.__class__(np.array([a]).astype(self.dtype))) + new_args.append( + self.__class__(np.array([a]).astype(self.dtype.np_dtype)) + ) else: new_args.append(a) @@ -179,18 +184,17 @@ def _np_dtype_to_tensor_dtype(dtype): dtype = np.dtype("float64") return np_dtype_to_tensor_dtype(dtype) - def _generic_method_astype(self, method_name, *args: Any, **kwargs: Any) -> Any: + def _generic_method_astype( + self, method_name, dtype: Union[DType, "Var"], **kwargs: Any + ) -> Any: # avoids circular imports. 
diff --git a/onnx_array_api/npx/npx_tensors.py b/onnx_array_api/npx/npx_tensors.py
index 136def5..e1e4b21 100644
--- a/onnx_array_api/npx/npx_tensors.py
+++ b/onnx_array_api/npx/npx_tensors.py
@@ -1,8 +1,9 @@
-from typing import Any
+from typing import Any, Union
 
 import numpy as np
 from onnx.helper import np_dtype_to_tensor_dtype
 
+from .npx_types import DType, ElemType, ParType, TensorType
 from .npx_array_api import BaseArrayApi, ArrayApiError
 
 
@@ -73,7 +74,9 @@ def _getitem_impl_var(obj, index, method_name=None):
         return meth(obj, index)
 
     @staticmethod
-    def _astype_impl(x, dtype: int = None, method_name=None):
+    def _astype_impl(
+        x: TensorType[ElemType.allowed, "T1"], dtype: ParType[DType], method_name=None
+    ) -> TensorType[ElemType.allowed, "T2"]:
         # avoids circular imports.
         if dtype is None:
             raise ValueError("dtype cannot be None.")
@@ -131,9 +134,11 @@ def _generic_method_operator(self, method_name, *args: Any, **kwargs: Any) -> An
         new_args = []
         for a in args:
             if isinstance(a, np.ndarray):
-                new_args.append(self.__class__(a.astype(self.dtype)))
+                new_args.append(self.__class__(a.astype(self.dtype.np_dtype)))
             elif isinstance(a, (int, float)):
-                new_args.append(self.__class__(np.array([a]).astype(self.dtype)))
+                new_args.append(
+                    self.__class__(np.array([a]).astype(self.dtype.np_dtype))
+                )
             else:
                 new_args.append(a)
 
@@ -179,18 +184,17 @@ def _np_dtype_to_tensor_dtype(dtype):
             dtype = np.dtype("float64")
         return np_dtype_to_tensor_dtype(dtype)
 
-    def _generic_method_astype(self, method_name, *args: Any, **kwargs: Any) -> Any:
+    def _generic_method_astype(
+        self, method_name, dtype: Union[DType, "Var"], **kwargs: Any
+    ) -> Any:
         # avoids circular imports.
         from .npx_jit_eager import eager_onnx
         from .npx_var import Var
 
-        if len(args) != 1:
-            raise ValueError(f"astype takes only one argument not {len(args)}.")
-
         dtype = (
-            args[0]
-            if isinstance(args[0], (int, Var))
-            else self._np_dtype_to_tensor_dtype(args[0])
+            dtype
+            if isinstance(dtype, (DType, Var))
+            else self._np_dtype_to_tensor_dtype(dtype)
         )
         eag = eager_onnx(EagerTensor._astype_impl, self.__class__, bypass_eager=True)
         res = eag(self, dtype, method_name=method_name, already_eager=True, **kwargs)
diff --git a/onnx_array_api/npx/npx_types.py b/onnx_array_api/npx/npx_types.py
index a38d53f..aa335bd 100644
--- a/onnx_array_api/npx/npx_types.py
+++ b/onnx_array_api/npx/npx_types.py
@@ -1,7 +1,8 @@
 from typing import Any, Tuple, Union
 
 import numpy as np
-from onnx import AttributeProto
+from onnx import AttributeProto, TensorProto
+from onnx.helper import np_dtype_to_tensor_dtype, tensor_dtype_to_np_dtype
 
 
 class WrapperType:
@@ -14,9 +15,80 @@
 
 class DType(WrapperType):
     """
-    Annotated type for dtype.
+    Element type of tensors following the :epkg:`Array API`.
+
+    :param code: element type based on the onnx definition
     """
 
+    __slots__ = ["code_"]
+
+    def __init__(self, code: int):
+        self.code_ = code
+
+    def __repr__(self) -> str:
+        "usual"
+        return f"DType({self.code_})"
+
+    def __str__(self) -> str:
+        "usual"
+        return f"DT{self.code_}"
+
+    def __hash__(self) -> int:
+        return self.code_
+
+    @property
+    def code(self) -> int:
+        return self.code_
+
+    @property
+    def np_dtype(self) -> "np.dtype":
+        return tensor_dtype_to_np_dtype(self.code_)
+
+    def __eq__(self, dt: "DType") -> bool:
+        "Compares two types."
+        if dt.__class__ is DType:
+            return self.code_ == dt.code_
+        if isinstance(dt, (int, bool, str)):
+            return False
+        if dt is str:
+            return self.code_ == TensorProto.STRING
+        if dt is bool:
+            return self.code_ == TensorProto.BOOL
+        if dt in ElemType.numpy_map:
+            dti = ElemType.numpy_map[dt]
+            return self.code_ == dti.code_
+        try:
+            dti = np_dtype_to_tensor_dtype(dt)
+        except KeyError:
+            raise TypeError(f"dt must be DType not {type(dt)} - {dt!r}.")
+        return self.code_ == dti
+
+    def __lt__(self, dt: "DType") -> bool:
+        "Compares two types."
+        if dt.__class__ is DType:
+            return self.code_ < dt.code_
+        if isinstance(dt, int):
+            raise TypeError(f"dt must be DType not {type(dt)}.")
+        try:
+            dti = np_dtype_to_tensor_dtype(dt)
+        except KeyError:
+            raise TypeError(f"dt must be DType not {type(dt)} - {dt}.")
+        return self.code_ < dti
+
+    @classmethod
+    def type_name(cls) -> str:
+        "Returns its full name."
+        raise NotImplementedError()
+
+
+class _DType2(DType):
+    "Wraps an integer into a different type."
+
+    pass
+
+
+class _DTypes(DType):
+    "Wraps an integer into a different type."
 
     pass
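`DType` compares structurally: against another `DType` by code, against the python types `bool` and `str`, and against numpy dtypes through `ElemType.numpy_map`; string instances never match. A few equalities the implementation above guarantees:

```python
import numpy as np
from onnx import TensorProto
from onnx_array_api.npx.npx_types import DType

assert DType(TensorProto.DOUBLE) == DType(11)
assert DType(TensorProto.DOUBLE) == np.float64   # via ElemType.numpy_map
assert DType(TensorProto.BOOL) == bool           # the python type itself
assert DType(TensorProto.STRING) != "str"        # plain strings never match
```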
""" - allowed = set(range(1, 17)) + allowed = set(DType(i) for i in range(1, 17)) ints = { ElemTypeCstInner.int8, @@ -85,13 +158,15 @@ class ElemTypeCstSet(ElemTypeCstInner): ElemTypeCstInner.float64, } + strings = {ElemTypeCstInner.str_} + @staticmethod def combined(type_set): "Combines all types into a single integer by using power of 2." s = 0 for dt in type_set: - s += 1 << dt - return s + s += 1 << dt.code + return _DTypes(s) class ElemTypeCst(ElemTypeCstSet): @@ -99,45 +174,47 @@ class ElemTypeCst(ElemTypeCstSet): Combination of element types. """ - Undefined = 0 - Bool = 1 << ElemTypeCstInner.bool_ - Int8 = 1 << ElemTypeCstInner.int8 - Int16 = 1 << ElemTypeCstInner.int16 - Int32 = 1 << ElemTypeCstInner.int32 - Int64 = 1 << ElemTypeCstInner.int64 - UInt8 = 1 << ElemTypeCstInner.uint8 - UInt16 = 1 << ElemTypeCstInner.uint16 - UInt32 = 1 << ElemTypeCstInner.uint32 - UInt64 = 1 << ElemTypeCstInner.uint64 - BFloat16 = 1 << ElemTypeCstInner.bfloat16 - Float16 = 1 << ElemTypeCstInner.float16 - Float32 = 1 << ElemTypeCstInner.float32 - Float64 = 1 << ElemTypeCstInner.float64 - Complex64 = 1 << ElemTypeCstInner.complex64 - Complex128 = 1 << ElemTypeCstInner.complex128 + Undefined = _DType2(0) + Bool = _DType2(1 << ElemTypeCstInner.bool_.code) + Int8 = _DType2(1 << ElemTypeCstInner.int8.code) + Int16 = _DType2(1 << ElemTypeCstInner.int16.code) + Int32 = _DType2(1 << ElemTypeCstInner.int32.code) + Int64 = _DType2(1 << ElemTypeCstInner.int64.code) + UInt8 = _DType2(1 << ElemTypeCstInner.uint8.code) + UInt16 = _DType2(1 << ElemTypeCstInner.uint16.code) + UInt32 = _DType2(1 << ElemTypeCstInner.uint32.code) + UInt64 = _DType2(1 << ElemTypeCstInner.uint64.code) + BFloat16 = _DType2(1 << ElemTypeCstInner.bfloat16.code) + Float16 = _DType2(1 << ElemTypeCstInner.float16.code) + Float32 = _DType2(1 << ElemTypeCstInner.float32.code) + Float64 = _DType2(1 << ElemTypeCstInner.float64.code) + Complex64 = _DType2(1 << ElemTypeCstInner.complex64.code) + Complex128 = _DType2(1 << ElemTypeCstInner.complex128.code) + String = _DType2(1 << ElemTypeCstInner.str_.code) Numerics = ElemTypeCstSet.combined(ElemTypeCstSet.numerics) Floats = ElemTypeCstSet.combined(ElemTypeCstSet.floats) Ints = ElemTypeCstSet.combined(ElemTypeCstSet.ints) + Strings = ElemTypeCstSet.combined(ElemTypeCstSet.strings) class ElemType(ElemTypeCst): """ Allowed element type based on numpy dtypes. 
@@ -150,24 +227,24 @@ class ElemType(ElemTypeCst):
         **{
             getattr(np, att): getattr(ElemTypeCst, att)
             for att in dir(ElemTypeCst)
-            if isinstance(getattr(ElemTypeCst, att), int) and hasattr(np, att)
+            if isinstance(getattr(ElemTypeCst, att), DType) and hasattr(np, att)
         },
         **{
             np.dtype(att): getattr(ElemTypeCst, att)
             for att in dir(ElemTypeCst)
-            if isinstance(getattr(ElemTypeCst, att), int) and hasattr(np, att)
+            if isinstance(getattr(ElemTypeCst, att), DType) and hasattr(np, att)
         },
     }
 
     __slots__ = ["dtype"]
 
     @classmethod
-    def __class_getitem__(cls, dtype: Union[str, int]):
+    def __class_getitem__(cls, dtype: Union[str, DType]):
         if isinstance(dtype, str):
             dtype = ElemType.names_int[dtype]
         elif dtype in ElemType.numpy_map:
             dtype = ElemType.numpy_map[dtype]
-        elif dtype == 0:
+        elif dtype == DType(0):
             pass
         elif dtype not in ElemType.allowed:
             raise ValueError(f"Unexpected dtype {dtype} not in {ElemType.allowed}.")
@@ -197,7 +274,10 @@ def get_set_name(cls, dtypes):
                 tt.append(dt.dtype)
         dtypes = set(tt)
         for d in dir(cls):
-            if dtypes == getattr(cls, d):
+            att = getattr(cls, d)
+            if not isinstance(att, set):
+                continue
+            if dtypes == att:
                 return d
         return None
 
@@ -210,7 +290,7 @@ class ParType(WrapperType):
     :param optional: is optional or not
     """
 
-    map_names = {int: "int", float: "float", str: "str"}
+    map_names = {int: "int", float: "float", str: "str", DType: "DType"}
 
     @classmethod
     def __class_getitem__(cls, dtype):
@@ -333,7 +413,7 @@ def __class_getitem__(cls, *args):
             if isinstance(a, tuple):
                 shape = a
                 continue
-            if isinstance(a, int):
+            if isinstance(a, DType):
                 if dtypes is not None:
                     raise TypeError(f"Unexpected type {type(a)} in {args}.")
                 dtypes = (a,)
@@ -363,7 +443,7 @@ def __class_getitem__(cls, *args):
                 check.append(dt)
             elif dt in ElemType.allowed:
                 check.append(ElemType[dt])
-            elif isinstance(dt, int):
+            elif isinstance(dt, DType):
                 check.append(ElemType[dt])
             else:
                 raise TypeError(f"Unexpected type {type(dt)} in {dtypes}, args={args}.")
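`TensorType` subscripting accepts the new `DType` objects anywhere an integer code was accepted before; together with strings and numpy dtypes this keeps the three spellings below equivalent (a sketch; the `type_name` values follow the updated `test_tensor`):

```python
import numpy as np
from onnx import TensorProto
from onnx_array_api.npx.npx_types import DType, TensorType

t1 = TensorType["float32"]
t2 = TensorType[np.float32]
t3 = TensorType[DType(TensorProto.FLOAT)]
print(t1.type_name(), t2.type_name(), t3.type_name())
# all three: TensorType['float32']
```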
diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py
index c67e0ff..ae5b732 100644
--- a/onnx_array_api/npx/npx_var.py
+++ b/onnx_array_api/npx/npx_var.py
@@ -6,7 +6,7 @@
 
 from .npx_array_api import BaseArrayApi, ArrayApiError
 from .npx_constants import DEFAULT_OPSETS, ONNX_DOMAIN
-from .npx_types import ElemType, OptParType, ParType, TensorType, TupleType
+from .npx_types import DType, ElemType, OptParType, ParType, TensorType, TupleType
 
 
 class Par:
@@ -276,7 +276,7 @@ def __init__(
         op: Union[
             Callable, str, Tuple[str, str], FunctionProto, ModelProto, NodeProto
         ] = None,
-        dtype: type = None,
+        dtype: Union[type, DType] = None,
         inline: bool = False,
         n_var_outputs: Optional[int] = 1,
         input_indices: Optional[List[int]] = None,
@@ -298,11 +298,11 @@ def __init__(
         self.onnx_op_kwargs = kwargs
         self._prefix = None
 
-        if hasattr(dtype, "type_name"):
-            self.dtype = dtype
-        elif isinstance(dtype, int):
+        if isinstance(dtype, DType):
             # regular parameter
             self.onnx_op_kwargs["dtype"] = dtype
+        elif hasattr(dtype, "type_name"):
+            self.dtype = dtype
         elif dtype is None:
             self.dtype = None
         else:
diff --git a/onnx_array_api/ort/ort_tensors.py b/onnx_array_api/ort/ort_tensors.py
index 63bc378..ead834d 100644
--- a/onnx_array_api/ort/ort_tensors.py
+++ b/onnx_array_api/ort/ort_tensors.py
@@ -3,7 +3,6 @@
 import numpy as np
 from onnx import ModelProto, TensorProto
 from onnx.defs import onnx_opset_version
-from onnx.helper import tensor_dtype_to_np_dtype
 from onnxruntime import InferenceSession, RunOptions, get_available_providers
 from onnxruntime.capi._pybind_state import OrtDevice as C_OrtDevice
 from onnxruntime.capi._pybind_state import OrtMemType
@@ -11,7 +10,7 @@
 from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument
 
 from ..npx.npx_tensors import EagerTensor, JitTensor
-from ..npx.npx_types import TensorType
+from ..npx.npx_types import DType, TensorType
 
 
 class OrtTensor:
@@ -152,9 +151,9 @@ def shape(self) -> Tuple[int, ...]:
         return self._tensor.shape()
 
     @property
-    def dtype(self) -> Any:
+    def dtype(self) -> DType:
         "Returns the element type of this tensor."
-        return tensor_dtype_to_np_dtype(self._tensor.element_type())
+        return DType(self._tensor.element_type())
 
     @property
     def key(self) -> Any:
@@ -234,7 +233,17 @@ class EagerOrtTensor(OrtTensor, OrtCommon, EagerTensor):
     Defines a value for :epkg:`onnxruntime` as a backend.
     """
 
-    pass
+    def __array_namespace__(self, api_version: Optional[str] = None):
+        """
+        Returns the module holding all the available functions.
+        """
+        if api_version is None or api_version == "2022.12":
+            from onnx_array_api.array_api import onnx_ort
+
+            return onnx_ort
+        raise ValueError(
+            f"Unable to return an implementation for api_version={api_version!r}."
+        )
 
 
 class JitOrtTensor(OrtTensor, OrtCommon, JitTensor):
diff --git a/onnx_array_api/plotting/_helper.py b/onnx_array_api/plotting/_helper.py
index 69ea987..48e65d9 100644
--- a/onnx_array_api/plotting/_helper.py
+++ b/onnx_array_api/plotting/_helper.py
@@ -11,6 +11,7 @@
 )
 from onnx.helper import tensor_dtype_to_np_dtype
 from onnx.numpy_helper import to_array
+from ..npx.npx_types import DType
 
 
 class Graph:
@@ -44,7 +45,7 @@ def __init__(
         self.shape = shape
 
     @property
-    def dtype(self) -> Any:
+    def dtype(self) -> DType:
         return self.values.dtype
diff --git a/pyproject.toml b/pyproject.toml
index 832a027..9ef84cc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,6 +3,7 @@
 report_level = "INFO"
 ignore_directives = [
     "autoclass",
     "autofunction",
+    "automodule",
     "gdot",
     "image-sg",
     "runpython",
@@ -29,10 +30,13 @@ max-complexity = 10
 [tool.ruff.per-file-ignores]
 "_doc/examples/plot_first_example.py" = ["E402", "F811"]
 "_doc/examples/plot_onnxruntime.py" = ["E402", "F811"]
-"onnx_array_api/profiling.py" = ["E731"]
+"onnx_array_api/array_api/onnx_numpy.py" = ["F821"]
+"onnx_array_api/array_api/onnx_ort.py" = ["F821"]
 "onnx_array_api/npx/__init__.py" = ["F401", "F403"]
 "onnx_array_api/npx/npx_functions.py" = ["F821"]
 "onnx_array_api/npx/npx_functions_test.py" = ["F821"]
+"onnx_array_api/npx/npx_tensors.py" = ["F821"]
 "onnx_array_api/npx/npx_var.py" = ["F821"]
+"onnx_array_api/profiling.py" = ["E731"]
 "_unittests/ut_npx/test_npx.py" = ["F821"]