diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt index a3eaa47..1eac9e2 100644 --- a/_unittests/onnx-numpy-skips.txt +++ b/_unittests/onnx-numpy-skips.txt @@ -6,7 +6,6 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays array_api_tests/test_creation_functions.py::test_empty array_api_tests/test_creation_functions.py::test_empty_like array_api_tests/test_creation_functions.py::test_eye -array_api_tests/test_creation_functions.py::test_full_like array_api_tests/test_creation_functions.py::test_linspace array_api_tests/test_creation_functions.py::test_meshgrid array_api_tests/test_creation_functions.py::test_zeros_like diff --git a/_unittests/ut_array_api/test_hypothesis_array_api.py b/_unittests/ut_array_api/test_hypothesis_array_api.py index e29af65..fdf48f9 100644 --- a/_unittests/ut_array_api/test_hypothesis_array_api.py +++ b/_unittests/ut_array_api/test_hypothesis_array_api.py @@ -140,7 +140,7 @@ def fctonx(x, kw): if __name__ == "__main__": - cl = TestHypothesisArraysApis() - cl.setUpClass() - cl.test_scalar_strategies() + # cl = TestHypothesisArraysApis() + # cl.setUpClass() + # cl.test_scalar_strategies() unittest.main(verbosity=2) diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py index 859c802..78f8872 100644 --- a/_unittests/ut_array_api/test_onnx_numpy.py +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -112,7 +112,25 @@ def test_ones_like_uint16(self): expected = np.array(1, dtype=np.uint16) self.assertEqualArray(expected, z.numpy()) + def test_full_like(self): + c = EagerTensor(np.array(False)) + expected = np.full_like(c.numpy(), fill_value=False) + mat = xp.full_like(c, fill_value=False) + matnp = mat.numpy() + self.assertEqual(matnp.shape, tuple()) + self.assertEqualArray(expected, matnp) + + def test_full_like_mx(self): + c = EagerTensor(np.array([], dtype=np.uint8)) + expected = np.full_like(c.numpy(), fill_value=0) + mat = xp.full_like(c, fill_value=0) + matnp = mat.numpy() + self.assertEqualArray(expected, matnp) + if __name__ == "__main__": - # TestOnnxNumpy().test_ones_like() + # import logging + + # logging.basicConfig(level=logging.DEBUG) + # TestOnnxNumpy().test_full_like_mx() unittest.main(verbosity=2) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index c449f2e..709ced3 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -246,9 +246,10 @@ jobs: architecture: 'x64' - script: gcc --version displayName: 'gcc version' - - script: | - brew update - displayName: 'brew update' + #- script: brew upgrade + # displayName: 'brew upgrade' + #- script: brew update + # displayName: 'brew update' - script: export displayName: 'export' - script: gcc --version diff --git a/onnx_array_api/array_api/__init__.py b/onnx_array_api/array_api/__init__.py index e1e09b8..bd762be 100644 --- a/onnx_array_api/array_api/__init__.py +++ b/onnx_array_api/array_api/__init__.py @@ -18,6 +18,7 @@ "empty", "equal", "full", + "full_like", "isdtype", "isfinite", "isinf", diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index 2a67f22..98a89f2 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -20,6 +20,7 @@ abs as generic_abs, arange as generic_arange, full as generic_full, + full_like as generic_full_like, ones as generic_ones, zeros as generic_zeros, ) @@ -177,6 +178,23 @@ def full( return generic_full(shape, fill_value=value, dtype=dtype, order=order) +def full_like( + TEagerTensor: type, + x: TensorType[ElemType.allowed, "T"], + /, + fill_value: ParType[Scalar] = None, + *, + dtype: OptParType[DType] = None, + order: OptParType[str] = "C", +) -> EagerTensor[TensorType[ElemType.allowed, "TR"]]: + if dtype is None: + if isinstance(fill_value, TEagerTensor): + dtype = fill_value.dtype + elif isinstance(x, TEagerTensor): + dtype = x.dtype + return generic_full_like(x, fill_value=fill_value, dtype=dtype, order=order) + + def ones( TEagerTensor: type, shape: EagerTensor[TensorType[ElemType.int64, "I", (None,)]], diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index c0f0a7b..8a886b2 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -275,9 +275,9 @@ def astype( if dtype is int: to = DType(TensorProto.INT64) elif dtype is float: - to = DType(TensorProto.FLOAT64) + to = DType(TensorProto.DOUBLE) elif dtype is bool: - to = DType(TensorProto.FLOAT64) + to = DType(TensorProto.BOOL) elif dtype is str: to = DType(TensorProto.STRING) else: @@ -511,6 +511,49 @@ def full( return var(shape, value=value, op="ConstantOfShape") +@npxapi_inline +def full_like( + x: TensorType[ElemType.allowed, "T"], + /, + *, + fill_value: ParType[Scalar] = None, + dtype: OptParType[DType] = None, + order: OptParType[str] = "C", +) -> TensorType[ElemType.numerics, "T"]: + """ + Implements :func:`numpy.zeros`. + """ + if order != "C": + raise RuntimeError(f"order={order!r} != 'C' not supported.") + if fill_value is None: + raise TypeError("fill_value cannot be None.") + if dtype is None: + if isinstance(fill_value, bool): + dtype = DType(TensorProto.BOOL) + elif isinstance(fill_value, int): + dtype = DType(TensorProto.INT64) + elif isinstance(fill_value, float): + dtype = DType(TensorProto.DOUBLE) + else: + raise TypeError( + f"Unexpected type {type(fill_value)} for fill_value={fill_value!r} " + f"and dtype={dtype!r}." + ) + if isinstance(fill_value, (float, int, bool)): + value = make_tensor( + name="cst", data_type=dtype.code, dims=[1], vals=[fill_value] + ) + else: + raise NotImplementedError( + f"Unexpected type ({type(fill_value)} for fill_value={fill_value!r}." + ) + + v = var(x.shape, value=value, op="ConstantOfShape") + if dtype is None: + return var(v, x, op="CastLike") + return v + + @npxapi_inline def floor( x: TensorType[ElemType.numerics, "T"], / diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index b49d7ce..e06c944 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -58,6 +58,7 @@ def info( kwargs: Optional[Dict[str, Any]] = None, key: Optional[Tuple[Any, ...]] = None, onx: Optional[ModelProto] = None, + output: Optional[Any] = None, ): """ Logs a status. @@ -93,6 +94,8 @@ def info( "" if args is None else str(args), "" if kwargs is None else str(kwargs), ) + if output is not None: + logger.debug("==== [%s]", output) def status(self, me: str) -> str: """ @@ -517,7 +520,7 @@ def jit_call(self, *values, **kwargs): f"f={self.f} from module {self.f.__module__!r} " f"onnx=\n---\n{text}\n---\n{self.onxs[key]}" ) from e - self.info("-", "jit_call") + self.info("-", "jit_call", output=res) return res @@ -737,11 +740,13 @@ def __call__(self, *args, already_eager=False, **kwargs): try: res = self.f(*values, **kwargs) except (AttributeError, TypeError) as e: - inp1 = ", ".join(map(str, map(type, args))) - inp2 = ", ".join(map(str, map(type, values))) + inp1 = ", ".join(map(str, map(lambda a: type(a).__name__, args))) + inp2 = ", ".join(map(str, map(lambda a: type(a).__name__, values))) raise TypeError( - f"Unexpected types, input types are {inp1} " - f"and {inp2}, kwargs={kwargs}." + f"Unexpected types, input types are args=[{inp1}], " + f"values=[{inp2}], kwargs={kwargs}. " + f"(values = self._preprocess_constants(args)) " + f"args={args}, values={values}" ) from e if isinstance(res, EagerTensor) or ( diff --git a/onnx_array_api/npx/npx_numpy_tensors.py b/onnx_array_api/npx/npx_numpy_tensors.py index cfc90f3..a106b95 100644 --- a/onnx_array_api/npx/npx_numpy_tensors.py +++ b/onnx_array_api/npx/npx_numpy_tensors.py @@ -4,7 +4,6 @@ from onnx import ModelProto, TensorProto from ..reference import ExtendedReferenceEvaluator from .._helpers import np_dtype_to_tensor_dtype -from .npx_numpy_tensors_ops import ConstantOfShape from .npx_tensors import EagerTensor, JitTensor from .npx_types import DType, TensorType @@ -36,7 +35,7 @@ def __init__( onx: ModelProto, f: Callable, ): - self.ref = ExtendedReferenceEvaluator(onx, new_ops=[ConstantOfShape]) + self.ref = ExtendedReferenceEvaluator(onx) self.input_names = input_names self.tensor_class = tensor_class self._f = f diff --git a/onnx_array_api/npx/npx_types.py b/onnx_array_api/npx/npx_types.py index fe7b287..54cc618 100644 --- a/onnx_array_api/npx/npx_types.py +++ b/onnx_array_api/npx/npx_types.py @@ -68,7 +68,7 @@ def __eq__(self, dt: "DType") -> bool: if dt is bool: return self.code_ == TensorProto.BOOL if dt is float: - return self.code_ == TensorProto.FLOAT64 + return self.code_ == TensorProto.DOUBLE if isinstance(dt, list): return False if dt in ElemType.numpy_map: diff --git a/onnx_array_api/reference/evaluator.py b/onnx_array_api/reference/evaluator.py index 737b15d..aa26127 100644 --- a/onnx_array_api/reference/evaluator.py +++ b/onnx_array_api/reference/evaluator.py @@ -1,9 +1,18 @@ +from logging import getLogger from typing import Any, Dict, List, Optional, Union from onnx import FunctionProto, ModelProto from onnx.defs import get_schema from onnx.reference import ReferenceEvaluator from onnx.reference.op_run import OpRun from .ops.op_cast_like import CastLike_15, CastLike_19 +from .ops.op_constant_of_shape import ConstantOfShape + +import onnx + +print(onnx.__file__) + + +logger = getLogger("onnx-array-api-eval") class ExtendedReferenceEvaluator(ReferenceEvaluator): @@ -24,6 +33,7 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator): default_ops = [ CastLike_15, CastLike_19, + ConstantOfShape, ] @staticmethod @@ -88,3 +98,10 @@ def __init__( new_ops=new_ops, **kwargs, ) + + def _log(self, level: int, pattern: str, *args: List[Any]) -> None: + if level < self.verbose: + new_args = [self._log_arg(a) for a in args] + print(pattern % tuple(new_args)) + else: + logger.debug(pattern, *args) diff --git a/onnx_array_api/npx/npx_numpy_tensors_ops.py b/onnx_array_api/reference/ops/op_constant_of_shape.py similarity index 78% rename from onnx_array_api/npx/npx_numpy_tensors_ops.py rename to onnx_array_api/reference/ops/op_constant_of_shape.py index b4639ae..33308af 100644 --- a/onnx_array_api/npx/npx_numpy_tensors_ops.py +++ b/onnx_array_api/reference/ops/op_constant_of_shape.py @@ -1,12 +1,18 @@ import numpy as np - from onnx.reference.op_run import OpRun class ConstantOfShape(OpRun): @staticmethod def _process(value): - cst = value[0] if isinstance(value, np.ndarray) else value + cst = value[0] if isinstance(value, np.ndarray) and value.size > 0 else value + if isinstance(value, np.ndarray): + if len(value.shape) == 0: + cst = value + elif value.size > 0: + cst = value.ravel()[0] + else: + raise ValueError(f"Unexpected fill_value={value!r}") if isinstance(cst, bool): cst = np.bool_(cst) elif isinstance(cst, int):