From 012f2587f60622329f646a0dc75ca5b39be96f81 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 2 Jul 2023 13:22:04 +0200 Subject: [PATCH 1/4] Add function Eye to the Array API --- _unittests/onnx-numpy-skips.txt | 3 +-- _unittests/test_array_api.sh | 2 +- _unittests/ut_array_api/test_onnx_numpy.py | 20 ++++++++++++++++- onnx_array_api/array_api/__init__.py | 1 + onnx_array_api/array_api/_onnx_common.py | 2 ++ onnx_array_api/npx/npx_functions.py | 26 ++++++++++++++++++++++ 6 files changed, 50 insertions(+), 4 deletions(-) diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt index d0b47ab..2a09b2a 100644 --- a/_unittests/onnx-numpy-skips.txt +++ b/_unittests/onnx-numpy-skips.txt @@ -1,7 +1,6 @@ # API failures # see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt -array_api_tests/test_creation_functions.py::test_asarray_scalars -array_api_tests/test_creation_functions.py::test_arange +# uses __setitem__ array_api_tests/test_creation_functions.py::test_asarray_arrays array_api_tests/test_creation_functions.py::test_empty array_api_tests/test_creation_functions.py::test_empty_like diff --git a/_unittests/test_array_api.sh b/_unittests/test_array_api.sh index 0a003c1..43301de 100644 --- a/_unittests/test_array_api.sh +++ b/_unittests/test_array_api.sh @@ -1,4 +1,4 @@ export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy -pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_zeros_like || exit 1 +pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_eye || exit 1 # pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1 diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py index e96e324..8fa746b 100644 --- a/_unittests/ut_array_api/test_onnx_numpy.py +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -142,10 +142,28 @@ def test_as_array(self): self.assertEqual(r.dtype, DType(TensorProto.UINT64)) self.assertEqual(r.numpy(), 9223372036854775809) + def test_eye(self): + nr, nc = xp.asarray(4), xp.asarray(4) + expected = np.eye(nr.numpy(), nc.numpy()) + got = xp.eye(nr, nc) + self.assertEqualArray(expected, got.numpy()) + + def test_eye_nosquare(self): + nr, nc = xp.asarray(4), xp.asarray(5) + expected = np.eye(nr.numpy(), nc.numpy()) + got = xp.eye(nr, nc) + self.assertEqualArray(expected, got.numpy()) + + def test_eye_k(self): + nr = xp.asarray(4) + expected = np.eye(nr.numpy(), k=1) + got = xp.eye(nr, k=1) + self.assertEqualArray(expected, got.numpy()) + if __name__ == "__main__": # import logging # logging.basicConfig(level=logging.DEBUG) - # TestOnnxNumpy().test_as_array() + TestOnnxNumpy().test_eye() unittest.main(verbosity=2) diff --git a/onnx_array_api/array_api/__init__.py b/onnx_array_api/array_api/__init__.py index 6e4d712..1a305ca 100644 --- a/onnx_array_api/array_api/__init__.py +++ b/onnx_array_api/array_api/__init__.py @@ -17,6 +17,7 @@ "astype", "empty", "equal", + "eye", "full", "full_like", "isdtype", diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index 7c2e59e..36dc72c 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -1,6 +1,7 @@ from typing import Any, Optional import warnings import numpy as np +from onnx import TensorProto with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -19,6 +20,7 @@ from ..npx.npx_functions import ( abs as generic_abs, arange as generic_arange, + copy as copy_inline, full as generic_full, full_like as generic_full_like, ones as generic_ones, diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index 5c202f8..2c6451a 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -473,6 +473,32 @@ def expit( return var(x, op="Sigmoid") +@npxapi_inline +def eye( + n_rows: TensorType[ElemType.int64, "I"], + n_cols: OptTensorType[ElemType.int64, "I"] = None, + /, + *, + k: ParType[int] = 0, + dtype: ParType[DType] = DType(TensorProto.DOUBLE), +): + "See :func:`numpy.eye`." + if n_cols is None: + n_cols = n_rows + shape = cst(np.array([-1], dtype=np.int64)) + shape = var( + var(n_rows, shape, op="Reshape"), + var(n_cols, shape, op="Reshape"), + axis=0, + op="Concat", + ) + zero = zeros(shape, dtype=dtype) + res = var(zero, k=k, op="EyeLike") + if dtype is not None: + return var(res, to=dtype.code, op="Cast") + return res + + @npxapi_inline def full( shape: TensorType[ElemType.int64, "I", (None,)], From a169e100bc0d4260a73594b1bd862283b7c3f85e Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 2 Jul 2023 13:23:25 +0200 Subject: [PATCH 2/4] remove eye --- _unittests/onnx-numpy-skips.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt index 2a09b2a..0d3ae03 100644 --- a/_unittests/onnx-numpy-skips.txt +++ b/_unittests/onnx-numpy-skips.txt @@ -4,6 +4,5 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays array_api_tests/test_creation_functions.py::test_empty array_api_tests/test_creation_functions.py::test_empty_like -array_api_tests/test_creation_functions.py::test_eye array_api_tests/test_creation_functions.py::test_linspace array_api_tests/test_creation_functions.py::test_meshgrid From 57fb4c247951b515ef78b6b5fadd86cd8edbaf5a Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 2 Jul 2023 14:11:22 +0200 Subject: [PATCH 3/4] improve --- .../ut_array_api/test_hypothesis_array_api.py | 72 +++++++++++++++++++ onnx_array_api/array_api/_onnx_common.py | 19 +++++ onnx_array_api/npx/npx_functions.py | 4 +- onnx_array_api/npx/npx_jit_eager.py | 9 ++- 4 files changed, 100 insertions(+), 4 deletions(-) diff --git a/_unittests/ut_array_api/test_hypothesis_array_api.py b/_unittests/ut_array_api/test_hypothesis_array_api.py index fdf48f9..47bb38f 100644 --- a/_unittests/ut_array_api/test_hypothesis_array_api.py +++ b/_unittests/ut_array_api/test_hypothesis_array_api.py @@ -39,6 +39,7 @@ def sh(x): class TestHypothesisArraysApis(ExtTestCase): MAX_ARRAY_SIZE = 10000 + SQRT_MAX_ARRAY_SIZE = int(10000**0.5) VERSION = "2021.12" @classmethod @@ -138,9 +139,80 @@ def fctonx(x, kw): fctonx() self.assertEqual(len(args_onxp), len(args_np)) + def test_square_sizes_strategies(self): + dtypes = dict( + integer_dtypes=self.xps.integer_dtypes(), + uinteger_dtypes=self.xps.unsigned_integer_dtypes(), + floating_dtypes=self.xps.floating_dtypes(), + numeric_dtypes=self.xps.numeric_dtypes(), + boolean_dtypes=self.xps.boolean_dtypes(), + scalar_dtypes=self.xps.scalar_dtypes(), + ) + + dtypes_onnx = dict( + integer_dtypes=self.onxps.integer_dtypes(), + uinteger_dtypes=self.onxps.unsigned_integer_dtypes(), + floating_dtypes=self.onxps.floating_dtypes(), + numeric_dtypes=self.onxps.numeric_dtypes(), + boolean_dtypes=self.onxps.boolean_dtypes(), + scalar_dtypes=self.onxps.scalar_dtypes(), + ) + + for k, vnp in dtypes.items(): + vonxp = dtypes_onnx[k] + anp = self.xps.arrays(dtype=vnp, shape=shapes(self.xps)) + aonxp = self.onxps.arrays(dtype=vonxp, shape=shapes(self.onxps)) + self.assertNotEmpty(anp) + self.assertNotEmpty(aonxp) + + args_np = [] + + kws = array_api_kwargs(k=strategies.integers(), dtype=self.xps.numeric_dtypes()) + sqrt_sizes = strategies.integers(0, self.SQRT_MAX_ARRAY_SIZE) + ncs = strategies.none() | sqrt_sizes + + @given(n_rows=sqrt_sizes, n_cols=ncs, kw=kws) + def fctnp(n_rows, n_cols, kw): + base = np.asarray(0) + e = np.eye(n_rows, n_cols) + self.assertNotEmpty(e.dtype) + self.assertIsInstance(e, base.__class__) + e = np.eye(n_rows, n_cols, **kw) + self.assertNotEmpty(e.dtype) + self.assertIsInstance(e, base.__class__) + args_np.append((n_rows, n_cols, kw)) + + fctnp() + self.assertEqual(len(args_np), 100) + + args_onxp = [] + + kws = array_api_kwargs( + k=strategies.integers(), dtype=self.onxps.numeric_dtypes() + ) + sqrt_sizes = strategies.integers(0, self.SQRT_MAX_ARRAY_SIZE) + ncs = strategies.none() | sqrt_sizes + + @given(n_rows=sqrt_sizes, n_cols=ncs, kw=kws) + def fctonx(n_rows, n_cols, kw): + base = onxp.asarray(0) + e = onxp.eye(n_rows, n_cols) + self.assertIsInstance(e, base.__class__) + self.assertNotEmpty(e.dtype) + e = onxp.eye(n_rows, n_cols, **kw) + self.assertNotEmpty(e.dtype) + self.assertIsInstance(e, base.__class__) + args_onxp.append((n_rows, n_cols, kw)) + + fctonx() + self.assertEqual(len(args_onxp), len(args_np)) + if __name__ == "__main__": # cl = TestHypothesisArraysApis() # cl.setUpClass() # cl.test_scalar_strategies() + # import logging + + # logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index 36dc72c..6f31d30 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -21,6 +21,7 @@ abs as generic_abs, arange as generic_arange, copy as copy_inline, + eye as generic_eye, full as generic_full, full_like as generic_full_like, ones as generic_ones, @@ -187,6 +188,24 @@ def full( return generic_full(shape, fill_value=value, dtype=dtype, order=order) +def eye( + TEagerTensor: type, + n_rows: TensorType[ElemType.int64, "I"], + n_cols: OptTensorType[ElemType.int64, "I"] = None, + /, + *, + k: ParType[int] = 0, + dtype: ParType[DType] = DType(TensorProto.DOUBLE), +): + if isinstance(n_rows, int): + n_rows = TEagerTensor(np.array(n_rows, dtype=np.int64)) + if n_cols is None: + n_cols = n_rows + elif isinstance(n_cols, int): + n_cols = TEagerTensor(np.array(n_cols, dtype=np.int64)) + return generic_eye(n_rows, n_cols, k=k, dtype=dtype) + + def full_like( TEagerTensor: type, x: TensorType[ElemType.allowed, "T"], diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index 2c6451a..beb22b6 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -476,15 +476,13 @@ def expit( @npxapi_inline def eye( n_rows: TensorType[ElemType.int64, "I"], - n_cols: OptTensorType[ElemType.int64, "I"] = None, + n_cols: TensorType[ElemType.int64, "I"], /, *, k: ParType[int] = 0, dtype: ParType[DType] = DType(TensorProto.DOUBLE), ): "See :func:`numpy.eye`." - if n_cols is None: - n_cols = n_rows shape = cst(np.array([-1], dtype=np.int64)) shape = var( var(n_rows, shape, op="Reshape"), diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index e06c944..71799f9 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -510,11 +510,18 @@ def jit_call(self, *values, **kwargs): from ..plotting.text_plot import onnx_simple_text_plot text = onnx_simple_text_plot(self.onxs[key]) + + def catch_len(x): + try: + return len(x) + except TypeError: + return 0 + raise RuntimeError( f"Unable to run function for key={key!r}, " f"types={[type(x) for x in values]}, " f"dtypes={[getattr(x, 'dtype', type(x)) for x in values]}, " - f"shapes={[getattr(x, 'shape', len(x)) for x in values]}, " + f"shapes={[getattr(x, 'shape', catch_len(x)) for x in values]}, " f"kwargs={kwargs}, " f"self.input_to_kwargs_={self.input_to_kwargs_}, " f"f={self.f} from module {self.f.__module__!r} " From f53e727307248fc643cba36741d1fd2f97fe5327 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 2 Jul 2023 14:46:54 +0200 Subject: [PATCH 4/4] fix overflow --- onnx_array_api/npx/npx_graph_builder.py | 10 ++++++++++ onnx_array_api/reference/evaluator.py | 4 ---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/onnx_array_api/npx/npx_graph_builder.py b/onnx_array_api/npx/npx_graph_builder.py index e8e49a2..b5333b5 100644 --- a/onnx_array_api/npx/npx_graph_builder.py +++ b/onnx_array_api/npx/npx_graph_builder.py @@ -230,6 +230,11 @@ def make_node( new_kwargs[k] = v.value elif isinstance(v, DType): new_kwargs[k] = v.code + elif isinstance(v, int): + try: + new_kwargs[k] = int(np.array(v, dtype=np.int64)) + except OverflowError: + new_kwargs[k] = int(np.iinfo(np.int64).max) else: new_kwargs[k] = v @@ -246,6 +251,11 @@ def make_node( f"Unable to create node {op!r}, with inputs={inputs}, " f"outputs={outputs}, domain={domain!r}, new_kwargs={new_kwargs}." ) from e + except ValueError as e: + raise ValueError( + f"Unable to create node {op!r}, with inputs={inputs}, " + f"outputs={outputs}, domain={domain!r}, new_kwargs={new_kwargs}." + ) from e for p in protos: node.attribute.append(p) if attribute_protos is not None: diff --git a/onnx_array_api/reference/evaluator.py b/onnx_array_api/reference/evaluator.py index aa26127..77a9344 100644 --- a/onnx_array_api/reference/evaluator.py +++ b/onnx_array_api/reference/evaluator.py @@ -7,10 +7,6 @@ from .ops.op_cast_like import CastLike_15, CastLike_19 from .ops.op_constant_of_shape import ConstantOfShape -import onnx - -print(onnx.__file__) - logger = getLogger("onnx-array-api-eval")