diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 9d8d98d..b5e9d88 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.2.0
 +++++
 
+* :pr:`24`: add ExtendedReferenceEvaluator to support scenarios for the Array API that onnx does not support
 * :pr:`22`: support OrtValue in function :func:`ort_profile`
 * :pr:`17`: implements ArrayAPI
 * :pr:`3`: fixes Array API with onnxruntime and scikit-learn
diff --git a/_doc/api/index.rst b/_doc/api/index.rst
index 475fad6..a95b2f4 100644
--- a/_doc/api/index.rst
+++ b/_doc/api/index.rst
@@ -15,4 +15,5 @@ API
     onnx_tools
     ort
     plotting
+    reference
     tools
diff --git a/_doc/api/reference.rst b/_doc/api/reference.rst
new file mode 100644
index 0000000..acbf90a
--- /dev/null
+++ b/_doc/api/reference.rst
@@ -0,0 +1,7 @@
+reference
+=========
+
+ExtendedReferenceEvaluator
+++++++++++++++++++++++++++
+
+.. autoclass:: onnx_array_api.reference.ExtendedReferenceEvaluator
diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt
index 9a04400..a3eaa47 100644
--- a/_unittests/onnx-numpy-skips.txt
+++ b/_unittests/onnx-numpy-skips.txt
@@ -9,6 +9,4 @@ array_api_tests/test_creation_functions.py::test_eye
 array_api_tests/test_creation_functions.py::test_full_like
 array_api_tests/test_creation_functions.py::test_linspace
 array_api_tests/test_creation_functions.py::test_meshgrid
-# Issue with CastLike and bfloat16 on onnx <= 1.15.0
-# array_api_tests/test_creation_functions.py::test_ones_like
 array_api_tests/test_creation_functions.py::test_zeros_like
diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py
index 9e3efb7..859c802 100644
--- a/_unittests/ut_array_api/test_onnx_numpy.py
+++ b/_unittests/ut_array_api/test_onnx_numpy.py
@@ -1,8 +1,7 @@
 import sys
 import unittest
-from packaging.version import Version
 import numpy as np
-from onnx import TensorProto, __version__ as onnx_ver
+from onnx import TensorProto
 from onnx_array_api.ext_test_case import ExtTestCase
 from onnx_array_api.array_api import onnx_numpy as xp
 from onnx_array_api.npx.npx_types import DType
@@ -99,10 +98,6 @@ def test_arange_int00(self):
         expected = expected.astype(np.int64)
         self.assertEqualArray(matnp, expected)
 
-    @unittest.skipIf(
-        Version(onnx_ver) < Version("1.15.0"),
-        reason="Reference implementation of CastLike is bugged.",
-    )
     def test_ones_like_uint16(self):
         x = EagerTensor(np.array(0, dtype=np.uint16))
         y = np.ones_like(x.numpy())
diff --git a/_unittests/ut_reference/test_backend_extended_reference_evaluator.py b/_unittests/ut_reference/test_backend_extended_reference_evaluator.py
new file mode 100644
index 0000000..4bc0927
--- /dev/null
+++ b/_unittests/ut_reference/test_backend_extended_reference_evaluator.py
@@ -0,0 +1,239 @@
+import os
+import platform
+import unittest
+from typing import Any
+import numpy
+import onnx.backend.base
+import onnx.backend.test
+import onnx.shape_inference
+import onnx.version_converter
+from onnx import ModelProto
+from onnx.backend.base import Device, DeviceType
+from onnx.defs import onnx_opset_version
+from onnx_array_api.reference import ExtendedReferenceEvaluator
+
+
+class ExtendedReferenceEvaluatorBackendRep(onnx.backend.base.BackendRep):
+    def __init__(self, session):
+        self._session = session
+
+    def run(self, inputs, **kwargs):
+        if isinstance(inputs, numpy.ndarray):
+            inputs = [inputs]
+        if isinstance(inputs, list):
+            if len(inputs) == len(self._session.input_names):
+                feeds = dict(zip(self._session.input_names, inputs))
+            else:
+                feeds = {}
+                pos_inputs = 0
+                for inp, tshape in zip(
+                    self._session.input_names, self._session.input_types
+                ):
+                    shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim)
+                    if shape == inputs[pos_inputs].shape:
+                        feeds[inp] = inputs[pos_inputs]
+                        pos_inputs += 1
+                        if pos_inputs >= len(inputs):
+                            break
+        elif isinstance(inputs, dict):
+            feeds = inputs
+        else:
+            raise TypeError(f"Unexpected input type {type(inputs)!r}.")
+        outs = self._session.run(None, feeds)
+        return outs
+
+
+class ExtendedReferenceEvaluatorBackend(onnx.backend.base.Backend):
+    @classmethod
+    def is_opset_supported(cls, model):  # pylint: disable=unused-argument
+        return True, ""
+
+    @classmethod
+    def supports_device(cls, device: str) -> bool:
+        d = Device(device)
+        return d.type == DeviceType.CPU  # type: ignore[no-any-return]
+
+    @classmethod
+    def create_inference_session(cls, model):
+        return ExtendedReferenceEvaluator(model)
+
+    @classmethod
+    def prepare(
+        cls, model: Any, device: str = "CPU", **kwargs: Any
+    ) -> ExtendedReferenceEvaluatorBackendRep:
+        # if isinstance(model, ExtendedReferenceEvaluatorBackendRep):
+        #     return model
+        if isinstance(model, ExtendedReferenceEvaluator):
+            return ExtendedReferenceEvaluatorBackendRep(model)
+        if isinstance(model, (str, bytes, ModelProto)):
+            inf = cls.create_inference_session(model)
+            return cls.prepare(inf, device, **kwargs)
+        raise TypeError(f"Unexpected type {type(model)} for model.")
+
+    @classmethod
+    def run_model(cls, model, inputs, device=None, **kwargs):
+        rep = cls.prepare(model, device, **kwargs)
+        return rep.run(inputs, **kwargs)
+
+    @classmethod
+    def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
+        raise NotImplementedError("Unable to run the model node by node.")
+
+
+backend_test = onnx.backend.test.BackendTest(
+    ExtendedReferenceEvaluatorBackend, __name__
+)
+
+if os.getenv("APPVEYOR"):
+    backend_test.exclude("(test_vgg19|test_zfnet)")
+if platform.architecture()[0] == "32bit":
+    backend_test.exclude("(test_vgg19|test_zfnet|test_bvlc_alexnet)")
+if platform.system() == "Windows":
+    backend_test.exclude("test_sequence_model")
+
+if onnx_opset_version() < 21:
+    backend_test.exclude(
+        "(test_averagepool_2d_dilations"
+        "|test_if*"
+        "|test_loop*"
+        "|test_scan*"
+        "|test_sequence_map*"
+        ")"
+    )
+
+if onnx_opset_version() < 19:
+    backend_test.exclude(
+        "(test_argm[ai][nx]_default_axis_example"
+        "|test_argm[ai][nx]_default_axis_random"
+        "|test_argm[ai][nx]_keepdims_example"
+        "|test_argm[ai][nx]_keepdims_random"
+        "|test_argm[ai][nx]_negative_axis_keepdims_example"
+        "|test_argm[ai][nx]_negative_axis_keepdims_random"
+        "|test_argm[ai][nx]_no_keepdims_example"
+        "|test_argm[ai][nx]_no_keepdims_random"
+        "|test_col2im_pads"
+        "|test_gru_batchwise"
+        "|test_gru_defaults"
+        "|test_gru_seq_length"
+        "|test_gru_with_initial_bias"
+        "|test_layer_normalization_2d_axis1_expanded"
+        "|test_layer_normalization_2d_axis_negative_1_expanded"
+        "|test_layer_normalization_3d_axis1_epsilon_expanded"
+        "|test_layer_normalization_3d_axis2_epsilon_expanded"
+        "|test_layer_normalization_3d_axis_negative_1_epsilon_expanded"
+        "|test_layer_normalization_3d_axis_negative_2_epsilon_expanded"
+        "|test_layer_normalization_4d_axis1_expanded"
+        "|test_layer_normalization_4d_axis2_expanded"
+        "|test_layer_normalization_4d_axis3_expanded"
+        "|test_layer_normalization_4d_axis_negative_1_expanded"
+        "|test_layer_normalization_4d_axis_negative_2_expanded"
+        "|test_layer_normalization_4d_axis_negative_3_expanded"
+        "|test_layer_normalization_default_axis_expanded"
"|test_logsoftmax_large_number_expanded" + "|test_lstm_batchwise" + "|test_lstm_defaults" + "|test_lstm_with_initial_bias" + "|test_lstm_with_peepholes" + "|test_mvn" + "|test_mvn_expanded" + "|test_softmax_large_number_expanded" + "|test_operator_reduced_mean" + "|test_operator_reduced_mean_keepdim)" + ) + +# The following tests are not supported. +backend_test.exclude( + "(test_gradient" + "|test_if_opt" + "|test_loop16_seq_none" + "|test_range_float_type_positive_delta_expanded" + "|test_range_int32_type_negative_delta_expanded" + "|test_scan_sum)" +) + +if onnx_opset_version() < 21: + # The following tests are using types not supported by NumPy. + # They could be if method to_array is extended to support custom + # types the same as the reference implementation does + # (see onnx.reference.op_run.to_array_extended). + backend_test.exclude( + "(test_cast_FLOAT_to_BFLOAT16" + "|test_cast_BFLOAT16_to_FLOAT" + "|test_cast_BFLOAT16_to_FLOAT" + "|test_castlike_BFLOAT16_to_FLOAT" + "|test_castlike_FLOAT_to_BFLOAT16" + "|test_castlike_FLOAT_to_BFLOAT16_expanded" + "|test_cast_no_saturate_" + "|_to_FLOAT8" + "|_FLOAT8" + "|test_quantizelinear_e4m3fn" + "|test_quantizelinear_e5m2" + ")" + ) + + # Disable test about float 8 + backend_test.exclude( + "(test_castlike_BFLOAT16*" + "|test_cast_BFLOAT16*" + "|test_cast_no_saturate*" + "|test_cast_FLOAT_to_FLOAT8*" + "|test_cast_FLOAT16_to_FLOAT8*" + "|test_cast_FLOAT8_to_*" + "|test_castlike_BFLOAT16*" + "|test_castlike_no_saturate*" + "|test_castlike_FLOAT_to_FLOAT8*" + "|test_castlike_FLOAT16_to_FLOAT8*" + "|test_castlike_FLOAT8_to_*" + "|test_quantizelinear_e*)" + ) + +# The following tests are too slow with the reference implementation (Conv). +backend_test.exclude( + "(test_bvlc_alexnet" + "|test_densenet121" + "|test_inception_v1" + "|test_inception_v2" + "|test_resnet50" + "|test_shufflenet" + "|test_squeezenet" + "|test_vgg19" + "|test_zfnet512)" +) + +# The following tests cannot pass because they consists in generating random number. +backend_test.exclude("(test_bernoulli)") + +if onnx_opset_version() < 21: + # The following tests fail due to a bug in the backend test comparison. + backend_test.exclude( + "(test_cast_FLOAT_to_STRING|test_castlike_FLOAT_to_STRING|test_strnorm)" + ) + + # The following tests fail due to a shape mismatch. + backend_test.exclude( + "(test_center_crop_pad_crop_axes_hwc_expanded|test_lppool_2d_dilations)" + ) + + # The following tests fail due to a type mismatch. + backend_test.exclude("(test_eyelike_without_dtype)") + +# The following tests fail due to discrepancies (small but still higher than 1e-7). 
+backend_test.exclude("test_adam_multiple") # 1e-2 + + +# import all test cases at global scope to make them visible to python.unittest +globals().update(backend_test.test_cases) + +if __name__ == "__main__": + res = unittest.main(verbosity=2, exit=False) + tests_run = res.result.testsRun + errors = len(res.result.errors) + skipped = len(res.result.skipped) + unexpected_successes = len(res.result.unexpectedSuccesses) + expected_failures = len(res.result.expectedFailures) + print("---------------------------------") + print( + f"tests_run={tests_run} errors={errors} skipped={skipped} " + f"unexpected_successes={unexpected_successes} " + f"expected_failures={expected_failures}" + ) diff --git a/onnx_array_api/npx/npx_numpy_tensors.py b/onnx_array_api/npx/npx_numpy_tensors.py index 80f530a..ba10d79 100644 --- a/onnx_array_api/npx/npx_numpy_tensors.py +++ b/onnx_array_api/npx/npx_numpy_tensors.py @@ -1,7 +1,7 @@ from typing import Any, Callable, List, Optional, Tuple import numpy as np from onnx import ModelProto, TensorProto -from onnx.reference import ReferenceEvaluator +from ..reference import ExtendedReferenceEvaluator from .._helpers import np_dtype_to_tensor_dtype from .npx_numpy_tensors_ops import ConstantOfShape from .npx_tensors import EagerTensor, JitTensor @@ -11,7 +11,7 @@ class NumpyTensor: """ Default backend based on - :func:`onnx.reference.ReferenceEvaluator`. + :func:`onnx_array_api.reference.ExtendedReferenceEvaluator`. :param input_names: input names :param onx: onnx model @@ -19,7 +19,7 @@ class NumpyTensor: class Evaluator: """ - Wraps class :class:`onnx.reference.ReferenceEvaluator` + Wraps class :class:`onnx_array_api.reference.ExtendedReferenceEvaluator` to have a signature closer to python function. :param tensor_class: class tensor such as :class:`NumpyTensor` @@ -35,7 +35,7 @@ def __init__( onx: ModelProto, f: Callable, ): - self.ref = ReferenceEvaluator(onx, new_ops=[ConstantOfShape]) + self.ref = ExtendedReferenceEvaluator(onx, new_ops=[ConstantOfShape]) self.input_names = input_names self.tensor_class = tensor_class self._f = f diff --git a/onnx_array_api/reference/__init__.py b/onnx_array_api/reference/__init__.py new file mode 100644 index 0000000..e4db27c --- /dev/null +++ b/onnx_array_api/reference/__init__.py @@ -0,0 +1 @@ +from .evaluator import ExtendedReferenceEvaluator diff --git a/onnx_array_api/reference/evaluator.py b/onnx_array_api/reference/evaluator.py new file mode 100644 index 0000000..737b15d --- /dev/null +++ b/onnx_array_api/reference/evaluator.py @@ -0,0 +1,90 @@ +from typing import Any, Dict, List, Optional, Union +from onnx import FunctionProto, ModelProto +from onnx.defs import get_schema +from onnx.reference import ReferenceEvaluator +from onnx.reference.op_run import OpRun +from .ops.op_cast_like import CastLike_15, CastLike_19 + + +class ExtendedReferenceEvaluator(ReferenceEvaluator): + """ + This class replaces the python implementation by custom implementation. + The Array API extends many operator to all types not supported + by the onnx specifications. The evaluator allows to test + scenarios outside what an onnx backend bound to the official onnx + operators definition could do. 
+
+    ::
+
+        from onnx.reference import ReferenceEvaluator
+        from onnx.reference.ops.op_conv import Conv
+        ref = ReferenceEvaluator(..., new_ops=[Conv])
+    """
+
+    default_ops = [
+        CastLike_15,
+        CastLike_19,
+    ]
+
+    @staticmethod
+    def filter_ops(proto, new_ops, opsets):
+        if opsets is None and isinstance(proto, (ModelProto, FunctionProto)):
+            opsets = {d.domain: d.version for d in proto.opset_import}
+        best = {}
+        renamed = {}
+        for cl in new_ops:
+            if "_" not in cl.__name__:
+                continue
+            vers = cl.__name__.split("_")
+            try:
+                v = int(vers[-1])
+            except ValueError:
+                # not a version
+                continue
+            if opsets is not None and v > opsets.get(cl.op_domain, 1):
+                continue
+            renamed[cl.__name__] = cl
+            key = cl.op_domain, "_".join(vers[:-1])
+            if key not in best or best[key][0] < v:
+                best[key] = (v, cl)
+
+        modified = []
+        for cl in new_ops:
+            if cl.__name__ not in renamed:
+                modified.append(cl)
+        for k, v in best.items():
+            atts = {"domain": k[0]}
+            bases = (v[1],)
+            if not hasattr(v[1], "op_schema"):
+                atts["op_schema"] = get_schema(k[1], v[0], domain=v[1].op_domain)
+            new_cl = type(k[1], bases, atts)
+            modified.append(new_cl)
+
+        new_ops = modified
+        return new_ops
+
+    def __init__(
+        self,
+        proto: Any,
+        opsets: Optional[Dict[str, int]] = None,
+        functions: Optional[List[Union[ReferenceEvaluator, FunctionProto]]] = None,
+        verbose: int = 0,
+        new_ops: Optional[List[OpRun]] = None,
+        **kwargs,
+    ):
+        if new_ops is None:
+            new_ops = ExtendedReferenceEvaluator.default_ops
+        else:
+            new_ops = new_ops.copy()
+            new_ops.extend(ExtendedReferenceEvaluator.default_ops)
+        new_ops = ExtendedReferenceEvaluator.filter_ops(proto, new_ops, opsets)
+
+        ReferenceEvaluator.__init__(
+            self,
+            proto,
+            opsets=opsets,
+            functions=functions,
+            verbose=verbose,
+            new_ops=new_ops,
+            **kwargs,
+        )
diff --git a/onnx_array_api/reference/ops/__init__.py b/onnx_array_api/reference/ops/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/onnx_array_api/reference/ops/__init__.py
@@ -0,0 +1 @@
+
diff --git a/onnx_array_api/reference/ops/op_cast_like.py b/onnx_array_api/reference/ops/op_cast_like.py
new file mode 100644
index 0000000..97cc798
--- /dev/null
+++ b/onnx_array_api/reference/ops/op_cast_like.py
@@ -0,0 +1,38 @@
+from onnx.helper import np_dtype_to_tensor_dtype
+from onnx.onnx_pb import TensorProto
+from onnx.reference.op_run import OpRun
+from onnx.reference.ops.op_cast import (
+    bfloat16,
+    cast_to,
+    float8e4m3fn,
+    float8e4m3fnuz,
+    float8e5m2,
+    float8e5m2fnuz,
+)
+
+
+def _cast_like(x, y, saturate):
+    if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16":
+        # a plain np.uint16 dtype also compares equal to bfloat16, hence the check on descr
+        to = TensorProto.BFLOAT16
+    elif y.dtype == float8e4m3fn and y.dtype.descr[0][0] == "e4m3fn":
+        to = TensorProto.FLOAT8E4M3FN
+    elif y.dtype == float8e4m3fnuz and y.dtype.descr[0][0] == "e4m3fnuz":
+        to = TensorProto.FLOAT8E4M3FNUZ
+    elif y.dtype == float8e5m2 and y.dtype.descr[0][0] == "e5m2":
+        to = TensorProto.FLOAT8E5M2
+    elif y.dtype == float8e5m2fnuz and y.dtype.descr[0][0] == "e5m2fnuz":
+        to = TensorProto.FLOAT8E5M2FNUZ
+    else:
+        to = np_dtype_to_tensor_dtype(y.dtype)  # type: ignore
+    return (cast_to(x, to, saturate),)
+
+
+class CastLike_15(OpRun):
+    def _run(self, x, y):  # type: ignore
+        return _cast_like(x, y, True)
+
+
+class CastLike_19(OpRun):
+    def _run(self, x, y, saturate=None):  # type: ignore
+        return _cast_like(x, y, saturate)
diff --git a/pyproject.toml b/pyproject.toml
index 60043b5..7e15de0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,5 +37,6 @@ max-complexity = 10
 "onnx_array_api/npx/npx_tensors.py" = ["F821"]
 "onnx_array_api/npx/npx_var.py" = ["F821"]
 "onnx_array_api/profiling.py" = ["E731"]
+"onnx_array_api/reference/__init__.py" = ["F401"]
 "_unittests/ut_npx/test_npx.py" = ["F821"]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 07fd7c3..4cc0562 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -3,7 +3,7 @@ black
 coverage
 flake8
 furo
-hypothesis
+hypothesis<6.80.0
 isort
 joblib
 lightgbm
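
A minimal usage sketch of the ExtendedReferenceEvaluator added by this change; the model
path "model.onnx" and the input name "X" below are placeholders, and the constructor and
run method simply keep the signature of onnx.reference.ReferenceEvaluator, which the class
extends:

    import numpy as np
    from onnx_array_api.reference import ExtendedReferenceEvaluator

    # CastLike nodes are evaluated with the custom implementations registered
    # in ExtendedReferenceEvaluator.default_ops (CastLike_15, CastLike_19).
    ref = ExtendedReferenceEvaluator("model.onnx", verbose=0)
    got = ref.run(None, {"X": np.array([0, 1], dtype=np.uint16)})
    print(got)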