diff --git a/_unittests/ut_array_api/test_hypothesis_array_api.py b/_unittests/ut_array_api/test_hypothesis_array_api.py index 8a854e0..e29af65 100644 --- a/_unittests/ut_array_api/test_hypothesis_array_api.py +++ b/_unittests/ut_array_api/test_hypothesis_array_api.py @@ -2,6 +2,7 @@ import warnings from os import getenv from functools import reduce +import numpy as np from operator import mul from hypothesis import given from onnx_array_api.ext_test_case import ExtTestCase @@ -89,24 +90,49 @@ def test_scalar_strategies(self): args_np = [] + xx = self.xps.arrays(dtype=dtypes["integer_dtypes"], shape=shapes(self.xps)) + kws = array_api_kwargs(dtype=strategies.none() | self.xps.scalar_dtypes()) + @given( - x=self.xps.arrays(dtype=dtypes["integer_dtypes"], shape=shapes(self.xps)), - kw=array_api_kwargs(dtype=strategies.none() | self.xps.scalar_dtypes()), + x=xx, + kw=kws, ) - def fct(x, kw): + def fctnp(x, kw): + asa1 = np.asarray(x) + asa2 = np.asarray(x, **kw) + self.assertEqual(asa1.shape, asa2.shape) args_np.append((x, kw)) - fct() + fctnp() self.assertEqual(len(args_np), 100) args_onxp = [] xshape = shapes(self.onxps) xx = self.onxps.arrays(dtype=dtypes_onnx["integer_dtypes"], shape=xshape) - kw = array_api_kwargs(dtype=strategies.none() | self.onxps.scalar_dtypes()) + kws = array_api_kwargs(dtype=strategies.none() | self.onxps.scalar_dtypes()) - @given(x=xx, kw=kw) + @given(x=xx, kw=kws) def fctonx(x, kw): + asa = np.asarray(x.numpy()) + try: + asp = onxp.asarray(x) + except Exception as e: + raise AssertionError(f"asarray fails with x={x!r}, asp={asa!r}.") from e + try: + self.assertEqualArray(asa, asp.numpy()) + except AssertionError as e: + raise AssertionError( + f"x={x!r} kw={kw!r} asa={asa!r}, asp={asp!r}" + ) from e + if kw: + try: + asp2 = onxp.asarray(x, **kw) + except Exception as e: + raise AssertionError( + f"asarray fails with x={x!r}, kw={kw!r}, asp={asa!r}." + ) from e + self.assertEqual(asp.shape, asp2.shape) args_onxp.append((x, kw)) fctonx() diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index 8f71455..2a67f22 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -1,5 +1,10 @@ from typing import Any, Optional +import warnings import numpy as np + +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from numpy.array_api._array_object import Array from ..npx.npx_types import ( DType, ElemType, @@ -77,6 +82,10 @@ def asarray( v = TEagerTensor(np.array(a, dtype=np.str_)) elif isinstance(a, list): v = TEagerTensor(np.array(a)) + elif isinstance(a, np.ndarray): + v = TEagerTensor(a) + elif isinstance(a, Array): + v = TEagerTensor(np.asarray(a)) else: raise RuntimeError(f"Unexpected type {type(a)} for the first input.") if dtype is not None: diff --git a/onnx_array_api/npx/npx_numpy_tensors.py b/onnx_array_api/npx/npx_numpy_tensors.py index ba10d79..cfc90f3 100644 --- a/onnx_array_api/npx/npx_numpy_tensors.py +++ b/onnx_array_api/npx/npx_numpy_tensors.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Callable, List, Optional, Tuple import numpy as np from onnx import ModelProto, TensorProto @@ -221,13 +222,18 @@ def __bool__(self): if self.shape == (0,): return False if len(self.shape) != 0: - raise ValueError( - f"Conversion to bool only works for scalar, not for {self!r}." + warnings.warn( + f"Conversion to bool only works for scalar, not for {self!r}, " + f"bool(...)={bool(self._tensor)}." ) + try: + return bool(self._tensor) + except ValueError as e: + raise ValueError(f"Unable to convert {self} to bool.") from e return bool(self._tensor) def __int__(self): - "Implicit conversion to bool." + "Implicit conversion to int." if len(self.shape) != 0: raise ValueError( f"Conversion to bool only works for scalar, not for {self!r}." @@ -249,7 +255,7 @@ def __int__(self): return int(self._tensor) def __float__(self): - "Implicit conversion to bool." + "Implicit conversion to float." if len(self.shape) != 0: raise ValueError( f"Conversion to bool only works for scalar, not for {self!r}." @@ -261,11 +267,24 @@ def __float__(self): DType(TensorProto.BFLOAT16), }: raise TypeError( - f"Conversion to int only works for float scalar, " + f"Conversion to float only works for float scalar, " f"not for dtype={self.dtype}." ) return float(self._tensor) + def __iter__(self): + """ + The :epkg:`Array API` does not define this function (2022/12). + This method raises an exception with a better error message. + """ + warnings.warn( + f"Iterators are not implemented in the generic case. " + f"Every function using them cannot be converted into ONNX " + f"(tensors - {type(self)})." + ) + for row in self._tensor: + yield self.__class__(row) + class JitNumpyTensor(NumpyTensor, JitTensor): """ diff --git a/onnx_array_api/npx/npx_tensors.py b/onnx_array_api/npx/npx_tensors.py index 8c954c2..0e561cb 100644 --- a/onnx_array_api/npx/npx_tensors.py +++ b/onnx_array_api/npx/npx_tensors.py @@ -35,8 +35,9 @@ def __iter__(self): This method raises an exception with a better error message. """ raise ArrayApiError( - "Iterators are not implemented in the generic case. " - "Every function using them cannot be converted into ONNX." + f"Iterators are not implemented in the generic case. " + f"Every function using them cannot be converted into ONNX " + f"(tensors - {type(self)})." ) @staticmethod diff --git a/onnx_array_api/npx/npx_types.py b/onnx_array_api/npx/npx_types.py index f9029f8..fe7b287 100644 --- a/onnx_array_api/npx/npx_types.py +++ b/onnx_array_api/npx/npx_types.py @@ -59,12 +59,16 @@ def __eq__(self, dt: "DType") -> bool: return False if dt.__class__ is DType: return self.code_ == dt.code_ - if isinstance(dt, (int, bool, str)): + if isinstance(dt, (int, bool, str, float)): return False + if dt is int: + return self.code_ == TensorProto.INT64 if dt is str: return self.code_ == TensorProto.STRING if dt is bool: return self.code_ == TensorProto.BOOL + if dt is float: + return self.code_ == TensorProto.FLOAT64 if isinstance(dt, list): return False if dt in ElemType.numpy_map: diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py index 3341e46..3e77cc5 100644 --- a/onnx_array_api/npx/npx_var.py +++ b/onnx_array_api/npx/npx_var.py @@ -607,8 +607,9 @@ def __iter__(self): This method raises an exception with a better error message. """ raise ArrayApiError( - "Iterators are not implemented in the generic case. " - "Every function using them cannot be converted into ONNX." + f"Iterators are not implemented in the generic case. " + f"Every function using them cannot be converted into ONNX " + f"(Var - {type(self)})." ) def _binary_op(self, ov: "Var", op_name: str, **kwargs) -> "Var": diff --git a/requirements-dev.txt b/requirements-dev.txt index 4cc0562..07fd7c3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,7 +3,7 @@ black coverage flake8 furo -hypothesis<6.80.0 +hypothesis isort joblib lightgbm