diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst new file mode 100644 index 0000000..4e5aeb5 --- /dev/null +++ b/CHANGELOGS.rst @@ -0,0 +1,7 @@ +Change Logs +=========== + +0.2.0 ++++++ + +* :pr:`3`: fixes Array API with onnxruntime diff --git a/_doc/conf.py b/_doc/conf.py index 0e51806..52f54b6 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -57,6 +57,11 @@ } epkg_dictionary = { + "Array API": "https://data-apis.org/array-api/", + "ArrayAPI": ( + "https://data-apis.org/array-api/", + ("2022.12/API_specification/generated/array_api.{0}.html", 1), + ), "DOT": "https://graphviz.org/doc/info/lang.html", "JIT": "https://en.wikipedia.org/wiki/Just-in-time_compilation", "onnx": "https://onnx.ai/onnx/", @@ -65,7 +70,7 @@ "numpy": "https://numpy.org/", "numba": "https://numba.pydata.org/", "onnx-array-api": ( - "http://www.xavierdupre.fr/app/" "onnx-array-api/helpsphinx/index.html" + "http://www.xavierdupre.fr/app/onnx-array-api/helpsphinx/index.html" ), "pyinstrument": "https://github.com/joerick/pyinstrument", "python": "https://www.python.org/", diff --git a/_doc/index.rst b/_doc/index.rst index bcd0c89..bd87c3b 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -34,6 +34,7 @@ well as to execute it. 
tutorial/index api/index auto_examples/index + ../CHANGELOGS Sources available on `github/onnx-array-api <https://github.com/sdpython/onnx-array-api>`_, diff --git a/_unittests/ut_npx/test_sklearn_array_api.py b/_unittests/ut_npx/test_sklearn_array_api.py index 70f434a..016a170 100644 --- a/_unittests/ut_npx/test_sklearn_array_api.py +++ b/_unittests/ut_npx/test_sklearn_array_api.py @@ -1,7 +1,8 @@ import unittest import numpy as np +from packaging.version import Version from onnx.defs import onnx_opset_version -from sklearn import config_context +from sklearn import config_context, __version__ as sklearn_version from sklearn.discriminant_analysis import LinearDiscriminantAnalysis from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor @@ -10,23 +11,15 @@ DEFAULT_OPSET = onnx_opset_version() -def take(self, X, indices, *, axis): - # Overwritting method take as it is using iterators. - # When array_api supports `take` we can use this directly - # https://github.com/data-apis/array-api/issues/177 - X_np = self._namespace.take(X, indices, axis=axis) - return self._namespace.asarray(X_np) - - class TestSklearnArrayAPI(ExtTestCase): + @unittest.skipIf( + Version(sklearn_version) <= Version("1.2.2"), + reason="reshape ArrayAPI not followed", + ) def test_sklearn_array_api_linear_discriminant(self): - from sklearn.utils._array_api import _ArrayAPIWrapper - - _ArrayAPIWrapper.take = take X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]) y = np.array([1, 1, 1, 2, 2, 2]) ana = LinearDiscriminantAnalysis() - ana = LinearDiscriminantAnalysis() ana.fit(X, y) expected = ana.predict(X) diff --git a/_unittests/ut_ort/test_sklearn_array_api_ort.py b/_unittests/ut_ort/test_sklearn_array_api_ort.py new file mode 100644 index 0000000..68e6725 --- /dev/null +++ b/_unittests/ut_ort/test_sklearn_array_api_ort.py @@ -0,0 +1,39 @@ +import unittest +import numpy as np +from packaging.version import Version +from onnx.defs import onnx_opset_version +from 
sklearn import config_context, __version__ as sklearn_version +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.ort.ort_tensors import EagerOrtTensor, OrtTensor + + +DEFAULT_OPSET = onnx_opset_version() + + +class TestSklearnArrayAPIOrt(ExtTestCase): + @unittest.skipIf( + Version(sklearn_version) <= Version("1.2.2"), + reason="reshape ArrayAPI not followed", + ) + def test_sklearn_array_api_linear_discriminant(self): + X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]) + y = np.array([1, 1, 1, 2, 2, 2]) + ana = LinearDiscriminantAnalysis() + ana.fit(X, y) + expected = ana.predict(X) + + new_x = EagerOrtTensor(OrtTensor.from_array(X)) + self.assertEqual(new_x.device_name, "Cpu") + self.assertStartsWith( + "EagerOrtTensor(OrtTensor.from_array(array([[", repr(new_x) + ) + with config_context(array_api_dispatch=True): + got = ana.predict(new_x) + self.assertEqualArray(expected, got.numpy()) + + +if __name__ == "__main__": + # import logging + # logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 50b1795..aa1a59b 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -43,6 +43,53 @@ jobs: artifactName: 'wheel-linux-wheel-$(python.version)' targetPath: 'dist' +- job: 'TestLinuxNightly' + pool: + vmImage: 'ubuntu-latest' + strategy: + matrix: + Python310-Linux: + python.version: '3.11' + maxParallel: 3 + + steps: + - task: UsePythonVersion@0 + inputs: + versionSpec: '$(python.version)' + architecture: 'x64' + - script: sudo apt-get update + displayName: 'AptGet Update' + - script: sudo apt-get install -y pandoc + displayName: 'Install Pandoc' + - script: sudo apt-get install -y inkscape + displayName: 'Install Inkscape' + - script: sudo apt-get install -y graphviz + displayName: 'Install Graphviz' + - script: python -m pip install --upgrade pip setuptools wheel + displayName: 
'Install tools' + - script: pip install -r requirements.txt + displayName: 'Install Requirements' + - script: pip install -r requirements-dev.txt + displayName: 'Install Requirements dev' + - script: pip uninstall -y scikit-learn + displayName: 'Uninstall scikit-learn' + - script: pip install --pre --extra-index https://pypi.anaconda.org/scipy-wheels-nightly/simple scikit-learn + displayName: 'Install scikit-learn nightly' + - script: pip install onnxmltools --no-deps + displayName: 'Install onnxmltools' + - script: | + ruff . + displayName: 'Ruff' + - script: | + rstcheck -r ./_doc ./onnx_array_api + displayName: 'rstcheck' + - script: | + black --diff . + displayName: 'Black' + - script: | + python -m pytest -v + displayName: 'Runs Unit Tests' + - job: 'TestLinux' pool: vmImage: 'ubuntu-latest' diff --git a/onnx_array_api/npx/npx_array_api.py b/onnx_array_api/npx/npx_array_api.py index e614ca7..05bfe14 100644 --- a/onnx_array_api/npx/npx_array_api.py +++ b/onnx_array_api/npx/npx_array_api.py @@ -1,22 +1,34 @@ -from typing import Any +from typing import Any, Optional import numpy as np from .npx_types import OptParType, ParType, TupleType +class ArrayApiError(RuntimeError): + """ + Raised when a function is not supported by the :epkg:`Array API`. + """ + + pass + + class ArrayApi: """ List of supported method by a tensor. """ - def __array_namespace__(self): + def __array_namespace__(self, api_version: Optional[str] = None): """ Returns the module holding all the available functions. """ - from onnx_array_api.npx import npx_functions + if api_version is None or api_version == "2022.12": + from onnx_array_api.npx import npx_functions - return npx_functions + return npx_functions + raise ValueError( + f"Unable to return an implementation for api_version={api_version!r}." 
+ ) def generic_method(self, method_name, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError( diff --git a/onnx_array_api/npx/npx_core_api.py b/onnx_array_api/npx/npx_core_api.py index 27b0dd1..cc3802a 100644 --- a/onnx_array_api/npx/npx_core_api.py +++ b/onnx_array_api/npx/npx_core_api.py @@ -252,3 +252,10 @@ def npxapi_inline(fn): to call. """ return _xapi(fn, inline=True) + + +def npxapi_no_inline(fn): + """ + Functions decorated with this decorator are not converted into ONNX. + """ + return fn diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index e10169b..fb455cc 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -1,14 +1,16 @@ from typing import Any, Optional, Tuple, Union +import array_api_compat.numpy as np_array_api import numpy as np from onnx import FunctionProto, ModelProto, NodeProto, TensorProto from onnx.helper import np_dtype_to_tensor_dtype from onnx.numpy_helper import from_array from .npx_constants import FUNCTION_DOMAIN -from .npx_core_api import cst, make_tuple, npxapi_inline, var +from .npx_core_api import cst, make_tuple, npxapi_inline, npxapi_no_inline, var from .npx_tensors import ArrayApi from .npx_types import ( + DType, ElemType, OptParType, ParType, @@ -397,6 +399,17 @@ def identity(n: ParType[int], dtype=None) -> TensorType[ElemType.numerics, "T"]: return v +@npxapi_no_inline +def isdtype( + dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]] +) -> bool: + """ + See :epkg:`ArrayAPI:isdtype`. + This function is not converted into an onnx graph. + """ + return np_array_api.isdtype(dtype, kind) + + @npxapi_inline def isnan(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.bool_, "T"]: "See :func:`numpy.isnan`." 
@@ -460,9 +473,23 @@ def relu(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, @npxapi_inline def reshape( - x: TensorType[ElemType.numerics, "T"], shape: TensorType[ElemType.int64, "I"] + x: TensorType[ElemType.numerics, "T"], + shape: TensorType[ElemType.int64, "I", (None,)], ) -> TensorType[ElemType.numerics, "T"]: - "See :func:`numpy.reshape`." + """ + See :func:`numpy.reshape`. + + .. warning:: + + The numpy definition is tricky because onnxruntime does not handle + shapes with an undefined number of dimensions well. + However, the Array API defines a stricter signature for + `reshape <https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.reshape.html>`_. + :epkg:`scikit-learn` updated its code to follow the Array API in + `PR 26030 ENH Forces shape to be tuple when using Array API's reshape + <https://github.com/scikit-learn/scikit-learn/pull/26030>`_. + """ if isinstance(shape, int): shape = cst(np.array([shape], dtype=np.int64)) shape_reshaped = var(shape, cst(np.array([-1], dtype=np.int64)), op="Reshape") diff --git a/onnx_array_api/npx/npx_graph_builder.py b/onnx_array_api/npx/npx_graph_builder.py index d61fdf2..92c1412 100644 --- a/onnx_array_api/npx/npx_graph_builder.py +++ b/onnx_array_api/npx/npx_graph_builder.py @@ -798,6 +798,23 @@ def to_onnx( node_inputs.append(input_name) continue + if isinstance(i, tuple) and all(map(lambda x: isinstance(x, int), i)): + ai = np.array(list(i), dtype=np.int64) + c = Cst(ai) + input_name = self._unique(var._prefix) + self._id_vars[id(i), index] = input_name + self._id_vars[id(c), index] = input_name + self.make_node( + "Constant", + [], + [input_name], + value=from_array(ai), + opset=self.target_opsets[""], + ) + self.onnx_names_[input_name] = c + node_inputs.append(input_name) + continue + raise NotImplementedError( f"Unexpected type {type(i)} for node={domop}." 
) diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index 9028d2f..6b6bfca 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -367,12 +367,18 @@ def jit_call(self, *values, **kwargs): try: res = fct.run(*values) except Exception as e: + from ..plotting.text_plot import onnx_simple_text_plot + + text = onnx_simple_text_plot(self.onxs[key]) raise RuntimeError( f"Unable to run function for key={key!r}, " f"types={[type(x) for x in values]}, " + f"dtypes={[x.dtype for x in values]}, " + f"shapes={[x.shape for x in values]}, " f"kwargs={kwargs}, " f"self.input_to_kwargs_={self.input_to_kwargs_}, " - f"onnx={self.onxs[key]}." + f"f={self.f} from module {self.f.__module__!r} " + f"onnx=\n---\n{text}\n---\n{self.onxs[key]}" ) from e self.info("-", "jit_call") return res @@ -492,14 +498,22 @@ def _preprocess_constants(self, *args): elif isinstance(n, Cst): new_args.append(self.tensor_class(n.inputs[0])) modified = True - # elif isinstance(n, tuple): - # if any(map(lambda t: isinstance(t, Var), n)): - # raise TypeError( - # f"Unexpected types in tuple " - # f"({[type(t) for t in n]}) for input {i}, " - # f"function {self.f} from module {self.f.__module__!r}." - # ) - # new_args.append(n) + elif isinstance(n, tuple): + if all(map(lambda x: isinstance(x, int), n)): + new_args.append( + self.tensor_class(np.array(list(n), dtype=np.int64)) + ) + elif any(map(lambda t: isinstance(t, Var), n)): + raise TypeError( + f"Unexpected types in tuple " + f"({[type(t) for t in n]}) for input {i}, " + f"function {self.f} from module {self.f.__module__!r}." + ) + else: + raise TypeError( + f"Unsupported tuple {n!r} for input {i}, " + f"function {self.f} from module {self.f.__module__!r}." 
+ ) elif isinstance(n, np.ndarray): new_args.append(self.tensor_class(n)) modified = True diff --git a/onnx_array_api/npx/npx_numpy_tensors.py b/onnx_array_api/npx/npx_numpy_tensors.py index ad27391..3197f60 100644 --- a/onnx_array_api/npx/npx_numpy_tensors.py +++ b/onnx_array_api/npx/npx_numpy_tensors.py @@ -162,12 +162,6 @@ def get_ir_version(cls, ir_version): """ return ir_version - def const_cast(self, to: Any = None) -> "EagerTensor": - """ - Casts a constant without any ONNX conversion. - """ - return self.__class__(self._tensor.astype(to)) - # The class should support whatever Var supports. # This part is not yet complete. diff --git a/onnx_array_api/npx/npx_tensors.py b/onnx_array_api/npx/npx_tensors.py index d741d95..5863dfe 100644 --- a/onnx_array_api/npx/npx_tensors.py +++ b/onnx_array_api/npx/npx_tensors.py @@ -3,7 +3,7 @@ import numpy as np from onnx.helper import np_dtype_to_tensor_dtype -from .npx_array_api import ArrayApi +from .npx_array_api import ArrayApi, ArrayApiError class JitTensor: @@ -21,24 +21,14 @@ class EagerTensor(ArrayApi): :class:`ArrayApi`. """ - def const_cast(self, to: Any = None) -> "EagerTensor": - """ - Casts a constant without any ONNX conversion. - """ - raise NotImplementedError( - f"Method 'const_cast' must be overwritten in class " - f"{self.__class__.__name__!r}." - ) - def __iter__(self): """ - This is not implementation in the generic case. + The :epkg:`Array API` does not define this function (2022/12). This method raises an exception with a better error message. """ - raise RuntimeError( + raise ArrayApiError( "Iterators are not implemented in the generic case. " - "It may be enabled for the eager mode but it might fail " - "when a whole function is converted into ONNX." + "Every function using them cannot be converted into ONNX." 
) @staticmethod @@ -141,7 +131,9 @@ def _generic_method_operator(self, method_name, *args: Any, **kwargs: Any) -> An new_args = [] for a in args: if isinstance(a, np.ndarray): - new_args.append(self.__class__(a).const_cast(self.dtype)) + new_args.append(self.__class__(a.astype(self.dtype))) + elif isinstance(a, (int, float)): + new_args.append(self.__class__(np.array([a]).astype(self.dtype))) else: new_args.append(a) diff --git a/onnx_array_api/npx/npx_types.py b/onnx_array_api/npx/npx_types.py index bc76c69..a38d53f 100644 --- a/onnx_array_api/npx/npx_types.py +++ b/onnx_array_api/npx/npx_types.py @@ -12,7 +12,15 @@ class WrapperType: pass -class ElemTypeCstInner: +class DType(WrapperType): + """ + Annotated type for dtype. + """ + + pass + + +class ElemTypeCstInner(WrapperType): """ Defines all possible types and tensor element type. """ @@ -194,7 +202,7 @@ def get_set_name(cls, dtypes): return None -class ParType: +class ParType(WrapperType): """ Defines a parameter type. @@ -293,7 +301,7 @@ def __str__(self) -> str: return f"{self.__class__.__name__}[{self.shape}]" -class TensorType: +class TensorType(WrapperType): """ Used to annotate functions. @@ -440,7 +448,7 @@ def issuperset(cls, tensor_type: type) -> bool: return True -class SequenceType: +class SequenceType(WrapperType): """ Defines a sequence of tensors. """ @@ -480,7 +488,7 @@ def type_name(cls) -> str: return newt -class TupleType: +class TupleType(WrapperType): """ Defines a sequence of tensors. 
""" diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py index ba561f3..3b60e01 100644 --- a/onnx_array_api/npx/npx_var.py +++ b/onnx_array_api/npx/npx_var.py @@ -4,7 +4,7 @@ from onnx import FunctionProto, ModelProto, NodeProto, TensorProto from onnx.helper import np_dtype_to_tensor_dtype -from .npx_array_api import ArrayApi +from .npx_array_api import ArrayApi, ArrayApiError from .npx_constants import DEFAULT_OPSETS, ONNX_DOMAIN from .npx_types import ElemType, OptParType, ParType, TensorType, TupleType @@ -57,27 +57,27 @@ def onnx_type(self): def __eq__(self, x): "Should not be used." - raise NotImplementedError() + raise NotImplementedError("__eq__ should not be used.") def __neq__(self, x): "Should not be used." - raise NotImplementedError() + raise NotImplementedError("__neq__ should not be used.") def __lt__(self, x): "Should not be used." - raise NotImplementedError() + raise NotImplementedError("__lt__ should not be used.") def __gt__(self, x): "Should not be used." - raise NotImplementedError() + raise NotImplementedError("__gt__ should not be used.") def __le__(self, x): "Should not be used." - raise NotImplementedError() + raise NotImplementedError("__le__ should not be used.") def __ge__(self, x): "Should not be used." - raise NotImplementedError() + raise NotImplementedError("__ge__ should not be used.") class ManyIdentity: @@ -443,6 +443,21 @@ def _get_vars(self): cst = Var.get_cst_var()[0] replacement_cst[id(i)] = cst(np.array(i)) continue + if isinstance(i, tuple): + if all(map(lambda x: isinstance(x, int), i)): + cst = Var.get_cst_var()[0] + replacement_cst[id(i)] = cst(np.array(list(i), dtype=np.int64)) + continue + if any(map(lambda t: isinstance(t, Var), i)): + raise TypeError( + f"Unexpected types in tuple " + f"({[type(t) for t in i]}), " + f"function {self.f} from module {self.f.__module__!r}." + ) + raise TypeError( + f"Unsupported tuple {i!r}, " + f"function {self.f} from module {self.f.__module__!r}." 
+ ) if i is None: continue raise TypeError( @@ -563,10 +578,10 @@ def to_onnx( def __iter__(self): """ - This is not implementation in the generic case. + The :epkg:`Array API` does not define this function (2022/12). This method raises an exception with a better error message. """ - raise RuntimeError( + raise ArrayApiError( "Iterators are not implemented in the generic case. " "Every function using them cannot be converted into ONNX." ) @@ -850,10 +865,12 @@ def shape(self) -> "Var": def reshape(self, shape: "Var") -> "Var": "Reshape" - var = Var.get_cst_var()[1] + cst, var = Var.get_cst_var() if isinstance(shape, (tuple, list)): shape = np.array(shape, dtype=np.int64) + else: + shape = var(shape, cst(np.array([-1], dtype=np.int64)), op="Reshape") return var(self.self_var, shape, op="Reshape") def reduce_function( diff --git a/onnx_array_api/ort/ort_tensors.py b/onnx_array_api/ort/ort_tensors.py index 0249d29..4f317a5 100644 --- a/onnx_array_api/ort/ort_tensors.py +++ b/onnx_array_api/ort/ort_tensors.py @@ -3,6 +3,7 @@ import numpy as np from onnx import ModelProto, TensorProto from onnx.defs import onnx_opset_version +from onnx.helper import tensor_dtype_to_np_dtype from onnxruntime import InferenceSession, RunOptions, get_available_providers from onnxruntime.capi._pybind_state import OrtDevice as C_OrtDevice from onnxruntime.capi._pybind_state import OrtMemType @@ -123,14 +124,28 @@ def run(self, *inputs: List["OrtTensor"]) -> List["OrtTensor"]: ) return list(map(inputs[0].__class__, res)) - def __init__(self, tensor: Union[C_OrtValue, "OrtTensor"]): + def __init__(self, tensor: Union[C_OrtValue, "OrtTensor", np.ndarray]): if isinstance(tensor, C_OrtValue): self._tensor = tensor elif isinstance(tensor, OrtTensor): self._tensor = tensor._tensor + elif isinstance(tensor, np.ndarray): + self._tensor = C_OrtValue.ortvalue_from_numpy(tensor, OrtTensor.CPU) else: raise ValueError(f"An OrtValue is expected not {type(tensor)}.") + def __repr__(self) -> str: + "usual" + 
return f"{self.__class__.__name__}(OrtTensor.from_array({self.numpy()!r}))" + + @property + def device_name(self): + return self._tensor.device_name() + + @property + def ndim(self): + return len(self.shape) + @property def shape(self) -> Tuple[int, ...]: "Returns the shape of the tensor." @@ -139,7 +154,7 @@ def shape(self) -> Tuple[int, ...]: @property def dtype(self) -> Any: "Returns the element type of this tensor." - return self._tensor.element_type() + return tensor_dtype_to_np_dtype(self._tensor.element_type()) @property def key(self) -> Any: diff --git a/requirements-dev.txt b/requirements-dev.txt index 6178c1b..cc2105e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -11,6 +11,7 @@ ml-dtypes onnxmltools onnxruntime openpyxl +packaging pandas psutil pyquickhelper diff --git a/requirements.txt b/requirements.txt index f017d97..73ee5ed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +array_api_compat numpy onnx scipy