From 791eb37fa9fc018abb9a89c4464bec65f5e9516d Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 18 Jun 2023 10:46:25 +0200 Subject: [PATCH 1/5] Supports function full for the Array API --- _unittests/onnx-numpy-skips.txt | 2 +- _unittests/test_array_api.sh | 2 +- _unittests/ut_array_api/test_onnx_numpy.py | 10 +++++++ onnx_array_api/array_api/onnx_numpy.py | 30 ++++++++++++++++++++- onnx_array_api/npx/npx_functions.py | 30 ++++++++++++++++++++- onnx_array_api/npx/npx_graph_builder.py | 2 +- onnx_array_api/npx/npx_jit_eager.py | 14 +++++----- onnx_array_api/npx/npx_numpy_tensors_ops.py | 2 ++ onnx_array_api/npx/npx_tensors.py | 2 +- onnx_array_api/npx/npx_types.py | 28 ++++++++++++++++--- onnx_array_api/npx/npx_var.py | 16 ++++++----- 11 files changed, 117 insertions(+), 21 deletions(-) diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt index eef3e70..3beafc6 100644 --- a/_unittests/onnx-numpy-skips.txt +++ b/_unittests/onnx-numpy-skips.txt @@ -5,7 +5,7 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays array_api_tests/test_creation_functions.py::test_empty array_api_tests/test_creation_functions.py::test_empty_like array_api_tests/test_creation_functions.py::test_eye -array_api_tests/test_creation_functions.py::test_full +# array_api_tests/test_creation_functions.py::test_full array_api_tests/test_creation_functions.py::test_full_like array_api_tests/test_creation_functions.py::test_linspace array_api_tests/test_creation_functions.py::test_meshgrid diff --git a/_unittests/test_array_api.sh b/_unittests/test_array_api.sh index cb32fe4..c75a61b 100644 --- a/_unittests/test_array_api.sh +++ b/_unittests/test_array_api.sh @@ -1,4 +1,4 @@ export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy -# pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_arrays || exit 1 +pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_full || exit 1 # pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1 \ No newline at end of file diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py index bd79ecf..100ed2a 100644 --- a/_unittests/ut_array_api/test_onnx_numpy.py +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -19,6 +19,16 @@ def test_zeros(self): a = xp.absolute(mat) self.assertEqualArray(np.absolute(mat.numpy()), a.numpy()) + def test_full(self): + c = EagerTensor(np.array([4, 5], dtype=np.int64)) + mat = xp.full(c, fill_value=5, dtype=xp.int64) + matnp = mat.numpy() + self.assertEqual(matnp.shape, (4, 5)) + self.assertNotEmpty(matnp[0, 0]) + a = xp.absolute(mat) + self.assertEqualArray(np.absolute(mat.numpy()), a.numpy()) + if __name__ == "__main__": + TestOnnxNumpy().test_full() unittest.main(verbosity=2) diff --git a/onnx_array_api/array_api/onnx_numpy.py b/onnx_array_api/array_api/onnx_numpy.py index 2cd4bfd..4825bd6 100644 --- a/onnx_array_api/array_api/onnx_numpy.py +++ b/onnx_array_api/array_api/onnx_numpy.py @@ -16,10 +16,11 @@ reshape, take, ) +from ..npx.npx_functions import full as generic_full from ..npx.npx_functions import ones as generic_ones from ..npx.npx_functions import zeros as generic_zeros from ..npx.npx_numpy_tensors import EagerNumpyTensor -from ..npx.npx_types import DType, ElemType, TensorType, OptParType +from ..npx.npx_types import DType, ElemType, TensorType, OptParType, ParType, Scalar from ._onnx_common import template_asarray from . import _finalize_array_api @@ -31,6 +32,7 @@ "astype", "empty", "equal", + "full", "isdtype", "isfinite", "isnan", @@ -103,6 +105,32 @@ def zeros( return generic_zeros(shape, dtype=dtype, order=order) +def full( + shape: TensorType[ElemType.int64, "I", (None,)], + fill_value: ParType[Scalar] = None, + dtype: OptParType[DType] = DType(TensorProto.FLOAT), + order: OptParType[str] = "C", +) -> TensorType[ElemType.numerics, "T"]: + if fill_value is None: + raise AttributeError("fill_value cannot be None") + value = fill_value + if isinstance(shape, tuple): + return generic_full( + EagerNumpyTensor(np.array(shape, dtype=np.int64)), + fill_value=value, + dtype=dtype, + order=order, + ) + if isinstance(shape, int): + return generic_full( + EagerNumpyTensor(np.array([shape], dtype=np.int64)), + fill_value=value, + dtype=dtype, + order=order, + ) + return generic_full(shape, fill_value=value, dtype=dtype, order=order) + + def _finalize(): """ Adds common attributes to Array API defined in this modules diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index 29a4481..c223f0d 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -15,6 +15,7 @@ SequenceType, TensorType, TupleType, + Scalar, ) from .npx_var import Var @@ -22,7 +23,7 @@ def _cstv(x): if isinstance(x, Var): return x - if isinstance(x, (int, float, np.ndarray)): + if isinstance(x, (int, float, bool, np.ndarray)): return cst(x) raise TypeError(f"Unexpected constant type {type(x)}.") @@ -376,6 +377,33 @@ def expit(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics return var(x, op="Sigmoid") +@npxapi_inline +def full( + shape: TensorType[ElemType.int64, "I", (None,)], + dtype: OptParType[DType] = DType(TensorProto.FLOAT), + fill_value: ParType[Scalar] = None, + order: OptParType[str] = "C", +) -> TensorType[ElemType.numerics, "T"]: + """ + Implements :func:`numpy.zeros`. + """ + if order != "C": + raise RuntimeError(f"order={order!r} != 'C' not supported.") + if fill_value is None: + raise AttributeError("fill_value cannot be None.") + if dtype is None: + dtype = DType(TensorProto.FLOAT) + if isinstance(fill_value, (float, int, bool)): + value = make_tensor( + name="cst", data_type=dtype.code, dims=[1], vals=[fill_value] + ) + else: + raise NotImplementedError( + f"Unexpected type ({type(fill_value)} for fill_value={fill_value!r}." + ) + return var(shape, value=value, op="ConstantOfShape") + + @npxapi_inline def floor(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]: "See :func:`numpy.floor`." diff --git a/onnx_array_api/npx/npx_graph_builder.py b/onnx_array_api/npx/npx_graph_builder.py index d41b91c..ff02843 100644 --- a/onnx_array_api/npx/npx_graph_builder.py +++ b/onnx_array_api/npx/npx_graph_builder.py @@ -784,7 +784,7 @@ def to_onnx( node_inputs.append(input_name) continue - if isinstance(i, (int, float)): + if isinstance(i, (int, float, bool)): ni = np.array(i) c = Cst(ni) input_name = self._unique(var._prefix) diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index 35ff9af..5f30d30 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -131,7 +131,7 @@ def make_key(*values, **kwargs): for iv, v in enumerate(values): if isinstance(v, (Var, EagerTensor, JitTensor)): res.append(v.key) - elif isinstance(v, (int, float, DType)): + elif isinstance(v, (int, float, bool, DType)): res.append(v) elif isinstance(v, slice): res.append(("slice", v.start, v.stop, v.step)) @@ -153,7 +153,7 @@ def make_key(*values, **kwargs): ) if kwargs: for k, v in sorted(kwargs.items()): - if isinstance(v, (int, float, str, type, DType)): + if isinstance(v, (int, float, str, type, bool, DType)): res.append(k) res.append(v) elif isinstance(v, tuple): @@ -543,12 +543,12 @@ def _preprocess_constants(self, *args): elif isinstance(n, np.ndarray): new_args.append(self.tensor_class(n)) modified = True - elif isinstance(n, (int, float)): + elif isinstance(n, (int, float, bool)): new_args.append(self.tensor_class(np.array(n))) modified = True elif isinstance(n, DType): new_args.append(n) - elif n in (int, float): + elif n in (int, float, bool): # usually used to cast new_args.append(n) elif n is None: @@ -586,6 +586,7 @@ def __call__(self, *args, already_eager=False, **kwargs): EagerTensor, Cst, int, + bool, float, tuple, slice, @@ -616,12 +617,13 @@ def __call__(self, *args, already_eager=False, **kwargs): else: # tries to call the version try: - res = self.f(*values) + res = self.f(*values, **kwargs) except (AttributeError, TypeError) as e: inp1 = ", ".join(map(str, map(type, args))) inp2 = ", ".join(map(str, map(type, values))) raise TypeError( - f"Unexpected types, input types are {inp1} " f"and {inp2}." + f"Unexpected types, input types are {inp1} " + f"and {inp2}, kwargs={kwargs}." ) from e if isinstance(res, EagerTensor) or ( diff --git a/onnx_array_api/npx/npx_numpy_tensors_ops.py b/onnx_array_api/npx/npx_numpy_tensors_ops.py index 5278019..c9cae2f 100644 --- a/onnx_array_api/npx/npx_numpy_tensors_ops.py +++ b/onnx_array_api/npx/npx_numpy_tensors_ops.py @@ -11,6 +11,8 @@ def _process(value): cst = np.int64(cst) elif isinstance(cst, float): cst = np.float64(cst) + elif isinstance(cst, bool): + cst = np.bool_(cst) elif cst is None: cst = np.float32(0) if not isinstance( diff --git a/onnx_array_api/npx/npx_tensors.py b/onnx_array_api/npx/npx_tensors.py index b0e92c2..9286ae2 100644 --- a/onnx_array_api/npx/npx_tensors.py +++ b/onnx_array_api/npx/npx_tensors.py @@ -133,7 +133,7 @@ def _generic_method_operator(self, method_name, *args: Any, **kwargs: Any) -> An for a in args: if isinstance(a, np.ndarray): new_args.append(self.__class__(a.astype(self.dtype.np_dtype))) - elif isinstance(a, (int, float)): + elif isinstance(a, (int, float, bool)): new_args.append( self.__class__(np.array([a]).astype(self.dtype.np_dtype)) ) diff --git a/onnx_array_api/npx/npx_types.py b/onnx_array_api/npx/npx_types.py index 6063e64..0f7f6dc 100644 --- a/onnx_array_api/npx/npx_types.py +++ b/onnx_array_api/npx/npx_types.py @@ -292,6 +292,19 @@ def get_set_name(cls, dtypes): return None +class Scalar: + """ + Defines a scalar. + """ + + def __init__(self, value: Union[float, int, bool]): + self.value = value + + def __repr__(self): + "usual" + return f"Scalar({self.value!r})" + + class ParType(WrapperType): """ Defines a parameter type. @@ -300,11 +313,18 @@ class ParType(WrapperType): :param optional: is optional or not """ - map_names = {int: "int", float: "float", str: "str", DType: "DType"} + map_names = { + int: "int", + float: "float", + str: "str", + DType: "DType", + bool: "bool", + Scalar: "Scalar", + } @classmethod def __class_getitem__(cls, dtype): - if isinstance(dtype, (int, float)): + if isinstance(dtype, (int, float, bool)): msg = str(dtype) else: msg = getattr(dtype, "__name__", str(dtype)) @@ -331,6 +351,8 @@ def onnx_type(cls): return AttributeProto.INT if cls.dtype == float: return AttributeProto.FLOAT + if cls.dtype == bool: + return AttributeProto.BOOL if cls.dtype == str: return AttributeProto.STRING raise RuntimeError( @@ -347,7 +369,7 @@ class OptParType(ParType): @classmethod def __class_getitem__(cls, dtype): - if isinstance(dtype, (int, float)): + if isinstance(dtype, (int, float, bool)): msg = str(dtype) else: msg = dtype.__name__ diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py index 2759f4c..3f5e090 100644 --- a/onnx_array_api/npx/npx_var.py +++ b/onnx_array_api/npx/npx_var.py @@ -12,7 +12,7 @@ class Par: Defines a named parameter. :param name: parameter name - :param dtype: parameter type (int, str, float) + :param dtype: parameter type (bool, int, str, float) :param value: value of the parameter if known :param parent_op: node type it belongs to """ @@ -233,7 +233,7 @@ def __call__(self, new_values): def _setitem1_where(self, index, new_values): cst, var = Var.get_cst_var() - if isinstance(new_values, (int, float)): + if isinstance(new_values, (int, float, bool)): new_values = np.array(new_values) if isinstance(new_values, np.ndarray): value = var(cst(new_values), self.parent, op="CastLike") @@ -446,7 +446,7 @@ def _get_vars(self): cst = Var.get_cst_var()[0] replacement_cst[id(i)] = cst(i) continue - if isinstance(i, (int, float)): + if isinstance(i, (int, float, bool)): cst = Var.get_cst_var()[0] replacement_cst[id(i)] = cst(np.array(i)) continue @@ -595,13 +595,13 @@ def __iter__(self): def _binary_op(self, ov: "Var", op_name: str, **kwargs) -> "Var": var = Var.get_cst_var()[1] - if isinstance(ov, (int, float, np.ndarray, Cst)): + if isinstance(ov, (int, float, bool, np.ndarray, Cst)): return var(self.self_var, var(ov, self.self_var, op="CastLike"), op=op_name) return var(self.self_var, ov, op=op_name, **kwargs) def _binary_op_right(self, ov: "Var", op_name: str, **kwargs) -> "Var": var = Var.get_cst_var()[1] - if isinstance(ov, (int, float, np.ndarray, Cst)): + if isinstance(ov, (int, float, bool, np.ndarray, Cst)): return var(var(ov, self.self_var, op="CastLike"), self.self_var, op=op_name) return var(ov, self.self_var, op=op_name, **kwargs) @@ -1112,10 +1112,14 @@ def __init__(self, cst: Any): Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity") elif isinstance(cst, float): Var.__init__(self, np.array(cst, dtype=np.float32), op="Identity") + elif isinstance(cst, bool): + Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity") elif isinstance(cst, list): if all(map(lambda t: isinstance(t, int), cst)): Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity") - elif all(map(lambda t: isinstance(t, (float, int)), cst)): + elif all(map(lambda t: isinstance(t, bool), cst)): + Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity") + elif all(map(lambda t: isinstance(t, (float, int, bool)), cst)): Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity") else: raise ValueError( From 7796b45bf67fe2ad4fa9f111ac60f73046e1ddd0 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 18 Jun 2023 11:18:31 +0200 Subject: [PATCH 2/5] improvments --- _unittests/ut_array_api/test_onnx_numpy.py | 10 +++++++++- onnx_array_api/array_api/_onnx_common.py | 2 +- onnx_array_api/array_api/onnx_numpy.py | 4 ++-- onnx_array_api/npx/npx_core_api.py | 2 +- onnx_array_api/npx/npx_functions.py | 17 +++++++++++++---- onnx_array_api/npx/npx_jit_eager.py | 1 + onnx_array_api/npx/npx_numpy_tensors_ops.py | 6 +++--- onnx_array_api/npx/npx_var.py | 12 ++++++------ 8 files changed, 36 insertions(+), 18 deletions(-) diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py index 100ed2a..55b2d94 100644 --- a/_unittests/ut_array_api/test_onnx_numpy.py +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -28,7 +28,15 @@ def test_full(self): a = xp.absolute(mat) self.assertEqualArray(np.absolute(mat.numpy()), a.numpy()) + def test_full_bool(self): + c = EagerTensor(np.array([4, 5], dtype=np.int64)) + mat = xp.full(c, fill_value=False) + matnp = mat.numpy() + self.assertEqual(matnp.shape, (4, 5)) + self.assertNotEmpty(matnp[0, 0]) + self.assertEqualArray(matnp, np.full((4, 5), False)) + if __name__ == "__main__": - TestOnnxNumpy().test_full() + TestOnnxNumpy().test_full_bool() unittest.main(verbosity=2) diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index 6553137..f832b72 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -44,7 +44,7 @@ def template_asarray( except OverflowError: v = TEagerTensor(np.asarray(a, dtype=np.uint64)) elif isinstance(a, float): - v = TEagerTensor(np.array(a, dtype=np.float32)) + v = TEagerTensor(np.array(a, dtype=np.float64)) elif isinstance(a, bool): v = TEagerTensor(np.array(a, dtype=np.bool_)) elif isinstance(a, str): diff --git a/onnx_array_api/array_api/onnx_numpy.py b/onnx_array_api/array_api/onnx_numpy.py index 4825bd6..9f50d3f 100644 --- a/onnx_array_api/array_api/onnx_numpy.py +++ b/onnx_array_api/array_api/onnx_numpy.py @@ -108,11 +108,11 @@ def zeros( def full( shape: TensorType[ElemType.int64, "I", (None,)], fill_value: ParType[Scalar] = None, - dtype: OptParType[DType] = DType(TensorProto.FLOAT), + dtype: OptParType[DType] = None, order: OptParType[str] = "C", ) -> TensorType[ElemType.numerics, "T"]: if fill_value is None: - raise AttributeError("fill_value cannot be None") + raise TypeError("fill_value cannot be None") value = fill_value if isinstance(shape, tuple): return generic_full( diff --git a/onnx_array_api/npx/npx_core_api.py b/onnx_array_api/npx/npx_core_api.py index 05cb0bb..548a40a 100644 --- a/onnx_array_api/npx/npx_core_api.py +++ b/onnx_array_api/npx/npx_core_api.py @@ -169,7 +169,7 @@ def wrapper(*inputs, **kwargs): new_inputs.append(i) elif isinstance(i, (int, float)): new_inputs.append( - np.array([i], dtype=np.int64 if isinstance(i, int) else np.float32) + np.array([i], dtype=np.int64 if isinstance(i, int) else np.float64) ) elif isinstance(i, str): new_inputs.append(Input(i)) diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index c223f0d..ab923b7 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -380,19 +380,28 @@ def expit(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics @npxapi_inline def full( shape: TensorType[ElemType.int64, "I", (None,)], - dtype: OptParType[DType] = DType(TensorProto.FLOAT), + dtype: OptParType[DType] = None, fill_value: ParType[Scalar] = None, order: OptParType[str] = "C", ) -> TensorType[ElemType.numerics, "T"]: """ - Implements :func:`numpy.zeros`. + Implements :func:`numpy.full`. """ if order != "C": raise RuntimeError(f"order={order!r} != 'C' not supported.") if fill_value is None: - raise AttributeError("fill_value cannot be None.") + raise TypeError("fill_value cannot be None.") if dtype is None: - dtype = DType(TensorProto.FLOAT) + if isinstance(fill_value, bool): + dtype = DType(TensorProto.BOOL) + elif isinstance(fill_value, int): + dtype = DType(TensorProto.INT64) + elif isinstance(fill_value, float): + dtype = DType(TensorProto.DOUBLE) + else: + raise TypeError( + f"Unexpected type {type(fill_value)} for fill_value={fill_value!r}." + ) if isinstance(fill_value, (float, int, bool)): value = make_tensor( name="cst", data_type=dtype.code, dims=[1], vals=[fill_value] diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index 5f30d30..bfb87fe 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -155,6 +155,7 @@ def make_key(*values, **kwargs): for k, v in sorted(kwargs.items()): if isinstance(v, (int, float, str, type, bool, DType)): res.append(k) + res.append(type(v)) res.append(v) elif isinstance(v, tuple): newv = [] diff --git a/onnx_array_api/npx/npx_numpy_tensors_ops.py b/onnx_array_api/npx/npx_numpy_tensors_ops.py index c9cae2f..b4639ae 100644 --- a/onnx_array_api/npx/npx_numpy_tensors_ops.py +++ b/onnx_array_api/npx/npx_numpy_tensors_ops.py @@ -7,12 +7,12 @@ class ConstantOfShape(OpRun): @staticmethod def _process(value): cst = value[0] if isinstance(value, np.ndarray) else value - if isinstance(cst, int): + if isinstance(cst, bool): + cst = np.bool_(cst) + elif isinstance(cst, int): cst = np.int64(cst) elif isinstance(cst, float): cst = np.float64(cst) - elif isinstance(cst, bool): - cst = np.bool_(cst) elif cst is None: cst = np.float32(0) if not isinstance( diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py index 3f5e090..a4802e3 100644 --- a/onnx_array_api/npx/npx_var.py +++ b/onnx_array_api/npx/npx_var.py @@ -1108,17 +1108,17 @@ class Cst(Var): def __init__(self, cst: Any): if isinstance(cst, np.ndarray): Var.__init__(self, cst, op="Identity") + elif isinstance(cst, bool): + Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity") elif isinstance(cst, int): Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity") elif isinstance(cst, float): - Var.__init__(self, np.array(cst, dtype=np.float32), op="Identity") - elif isinstance(cst, bool): - Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity") + Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity") elif isinstance(cst, list): - if all(map(lambda t: isinstance(t, int), cst)): - Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity") - elif all(map(lambda t: isinstance(t, bool), cst)): + if all(map(lambda t: isinstance(t, bool), cst)): Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity") + elif all(map(lambda t: isinstance(t, (int, bool)), cst)): + Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity") elif all(map(lambda t: isinstance(t, (float, int, bool)), cst)): Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity") else: From 870ba4b8990f6f73730ebfc7d15eb27c717f9b51 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 18 Jun 2023 11:38:13 +0200 Subject: [PATCH 3/5] fix keys by adding types --- _unittests/test_array_api.sh | 2 +- _unittests/ut_array_api/test_onnx_numpy.py | 18 +++++++++++++++++- _unittests/ut_npx/test_npx.py | 5 +++-- onnx_array_api/_helpers.py | 2 +- onnx_array_api/array_api/onnx_numpy.py | 7 +++---- onnx_array_api/npx/npx_functions.py | 8 ++++---- onnx_array_api/npx/npx_jit_eager.py | 4 +++- 7 files changed, 32 insertions(+), 14 deletions(-) diff --git a/_unittests/test_array_api.sh b/_unittests/test_array_api.sh index c75a61b..1de8dfb 100644 --- a/_unittests/test_array_api.sh +++ b/_unittests/test_array_api.sh @@ -1,4 +1,4 @@ export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy -pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_full || exit 1 +pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_ones || exit 1 # pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1 \ No newline at end of file diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py index 55b2d94..4cb7544 100644 --- a/_unittests/ut_array_api/test_onnx_numpy.py +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -19,6 +19,22 @@ def test_zeros(self): a = xp.absolute(mat) self.assertEqualArray(np.absolute(mat.numpy()), a.numpy()) + def test_zeros_none(self): + c = EagerTensor(np.array([4, 5], dtype=np.int64)) + mat = xp.zeros(c) + matnp = mat.numpy() + self.assertEqual(matnp.shape, (4, 5)) + self.assertNotEmpty(matnp[0, 0]) + self.assertEqualArray(matnp, np.zeros((4, 5))) + + def test_ones_none(self): + c = EagerTensor(np.array([4, 5], dtype=np.int64)) + mat = xp.ones(c) + matnp = mat.numpy() + self.assertEqual(matnp.shape, (4, 5)) + self.assertNotEmpty(matnp[0, 0]) + self.assertEqualArray(matnp, np.ones((4, 5))) + def test_full(self): c = EagerTensor(np.array([4, 5], dtype=np.int64)) mat = xp.full(c, fill_value=5, dtype=xp.int64) @@ -38,5 +54,5 @@ def test_full_bool(self): if __name__ == "__main__": - TestOnnxNumpy().test_full_bool() + TestOnnxNumpy().test_zeros_none() unittest.main(verbosity=2) diff --git a/_unittests/ut_npx/test_npx.py b/_unittests/ut_npx/test_npx.py index 93f2b5e..17b5863 100644 --- a/_unittests/ut_npx/test_npx.py +++ b/_unittests/ut_npx/test_npx.py @@ -710,8 +710,8 @@ def impl( keys = list(sorted(f.onxs)) self.assertIsInstance(f.onxs[keys[0]], ModelProto) k = keys[-1] - self.assertEqual(len(k), 3) - self.assertEqual(k[1:], ("axis", 0)) + self.assertEqual(len(k), 4) + self.assertEqual(k[1:], ("axis", int, 0)) def test_numpy_topk(self): f = topk(Input("X"), Input("K")) @@ -2416,6 +2416,7 @@ def compute_labels(X, centers, use_sqrt=False): (DType(TensorProto.DOUBLE), 2), (DType(TensorProto.DOUBLE), 2), "use_sqrt", + bool, True, ) self.assertEqual(f.available_versions, [key]) diff --git a/onnx_array_api/_helpers.py b/onnx_array_api/_helpers.py index 6191c92..f9808ca 100644 --- a/onnx_array_api/_helpers.py +++ b/onnx_array_api/_helpers.py @@ -39,7 +39,7 @@ def np_dtype_to_tensor_dtype(dtype: Any): elif dtype is int: dt = TensorProto.INT64 elif dtype is float: - dt = TensorProto.FLOAT64 + dt = TensorProto.DOUBLE else: raise KeyError(f"Unable to guess type for dtype={dtype}.") return dt diff --git a/onnx_array_api/array_api/onnx_numpy.py b/onnx_array_api/array_api/onnx_numpy.py index 9f50d3f..425418f 100644 --- a/onnx_array_api/array_api/onnx_numpy.py +++ b/onnx_array_api/array_api/onnx_numpy.py @@ -3,7 +3,6 @@ """ from typing import Any, Optional import numpy as np -from onnx import TensorProto from ..npx.npx_functions import ( all, abs, @@ -60,7 +59,7 @@ def asarray( def ones( shape: TensorType[ElemType.int64, "I", (None,)], - dtype: OptParType[DType] = DType(TensorProto.FLOAT), + dtype: OptParType[DType] = None, order: OptParType[str] = "C", ) -> TensorType[ElemType.numerics, "T"]: if isinstance(shape, tuple): @@ -78,7 +77,7 @@ def ones( def empty( shape: TensorType[ElemType.int64, "I", (None,)], - dtype: OptParType[DType] = DType(TensorProto.FLOAT), + dtype: OptParType[DType] = None, order: OptParType[str] = "C", ) -> TensorType[ElemType.numerics, "T"]: raise RuntimeError( @@ -89,7 +88,7 @@ def empty( def zeros( shape: TensorType[ElemType.int64, "I", (None,)], - dtype: OptParType[DType] = DType(TensorProto.FLOAT), + dtype: OptParType[DType] = None, order: OptParType[str] = "C", ) -> TensorType[ElemType.numerics, "T"]: if isinstance(shape, tuple): diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index ab923b7..98e37f4 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -501,7 +501,7 @@ def matmul( @npxapi_inline def ones( shape: TensorType[ElemType.int64, "I", (None,)], - dtype: OptParType[DType] = DType(TensorProto.FLOAT), + dtype: OptParType[DType] = None, order: OptParType[str] = "C", ) -> TensorType[ElemType.numerics, "T"]: """ @@ -510,7 +510,7 @@ def ones( if order != "C": raise RuntimeError(f"order={order!r} != 'C' not supported.") if dtype is None: - dtype = DType(TensorProto.FLOAT) + dtype = DType(TensorProto.DOUBLE) return var( shape, value=make_tensor(name="one", data_type=dtype.code, dims=[1], vals=[1]), @@ -711,7 +711,7 @@ def where( @npxapi_inline def zeros( shape: TensorType[ElemType.int64, "I", (None,)], - dtype: OptParType[DType] = DType(TensorProto.FLOAT), + dtype: OptParType[DType] = None, order: OptParType[str] = "C", ) -> TensorType[ElemType.numerics, "T"]: """ @@ -720,7 +720,7 @@ def zeros( if order != "C": raise RuntimeError(f"order={order!r} != 'C' not supported.") if dtype is None: - dtype = DType(TensorProto.FLOAT) + dtype = DType(TensorProto.DOUBLE) return var( shape, value=make_tensor(name="zero", data_type=dtype.code, dims=[1], vals=[0]), diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index bfb87fe..c222f01 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -132,6 +132,7 @@ def make_key(*values, **kwargs): if isinstance(v, (Var, EagerTensor, JitTensor)): res.append(v.key) elif isinstance(v, (int, float, bool, DType)): + res.append(type(v)) res.append(v) elif isinstance(v, slice): res.append(("slice", v.start, v.stop, v.step)) @@ -170,7 +171,8 @@ def make_key(*values, **kwargs): newv.append(t) res.append(tuple(newv)) elif v is None and k in {"dtype"}: - continue + res.append(k) + res.append(v) else: raise TypeError( f"Type {type(v)} is not yet supported, " From 3f14ff10088d99e13f0700938d2e259a3717c4bd Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 18 Jun 2023 11:46:13 +0200 Subject: [PATCH 4/5] fix unit tests --- _unittests/ut_array_api/test_array_apis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_unittests/ut_array_api/test_array_apis.py b/_unittests/ut_array_api/test_array_apis.py index c72700c..9a8dd7c 100644 --- a/_unittests/ut_array_api/test_array_apis.py +++ b/_unittests/ut_array_api/test_array_apis.py @@ -13,7 +13,7 @@ class TestArraysApis(ExtTestCase): def test_zeros_numpy_1(self): c = xpn.zeros(1) d = c.numpy() - self.assertEqualArray(np.array([0], dtype=np.float32), d) + self.assertEqualArray(np.array([0], dtype=np.float64), d) def test_zeros_ort_1(self): c = xpo.zeros(1) From b41c59564557b06b30d71a2db7e3cf928702ed96 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 18 Jun 2023 11:56:17 +0200 Subject: [PATCH 5/5] ci --- _unittests/onnx-numpy-skips.txt | 1 - _unittests/test_array_api.sh | 2 +- azure-pipelines.yml | 10 +++++----- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt index 3beafc6..62de43f 100644 --- a/_unittests/onnx-numpy-skips.txt +++ b/_unittests/onnx-numpy-skips.txt @@ -5,7 +5,6 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays array_api_tests/test_creation_functions.py::test_empty array_api_tests/test_creation_functions.py::test_empty_like array_api_tests/test_creation_functions.py::test_eye -# array_api_tests/test_creation_functions.py::test_full array_api_tests/test_creation_functions.py::test_full_like array_api_tests/test_creation_functions.py::test_linspace array_api_tests/test_creation_functions.py::test_meshgrid diff --git a/_unittests/test_array_api.sh b/_unittests/test_array_api.sh index 1de8dfb..9464ee6 100644 --- a/_unittests/test_array_api.sh +++ b/_unittests/test_array_api.sh @@ -1,4 +1,4 @@ export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy -pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_ones || exit 1 +pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_scalars || exit 1 # pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1 \ No newline at end of file diff --git a/azure-pipelines.yml b/azure-pipelines.yml index ca24462..c449f2e 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -48,7 +48,7 @@ jobs: vmImage: 'ubuntu-latest' strategy: matrix: - Python310-Linux: + Python311-Linux: python.version: '3.11' maxParallel: 3 @@ -96,7 +96,7 @@ jobs: strategy: matrix: Python310-Linux: - python.version: '3.11' + python.version: '3.10' maxParallel: 3 steps: @@ -149,7 +149,7 @@ jobs: vmImage: 'ubuntu-latest' strategy: matrix: - Python310-Linux: + Python311-Linux: python.version: '3.11' maxParallel: 3 @@ -202,7 +202,7 @@ jobs: vmImage: 'windows-latest' strategy: matrix: - Python310-Windows: + Python311-Windows: python.version: '3.11' maxParallel: 3 @@ -235,7 +235,7 @@ jobs: vmImage: 'macOS-latest' strategy: matrix: - Python310-Mac: + Python311-Mac: python.version: '3.11' maxParallel: 3