Skip to content

Fixes Array API with onnxruntime #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Change Logs
===========

0.2.0
+++++

* :pr:`3`: fixes Array API with onnxruntime
7 changes: 6 additions & 1 deletion _doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@
}

epkg_dictionary = {
"Array API": "https://data-apis.org/array-api/",
"ArrayAPI": (
"https://data-apis.org/array-api/",
("2022.12/API_specification/generated/array_api.{0}.html", 1),
),
"DOT": "https://graphviz.org/doc/info/lang.html",
"JIT": "https://en.wikipedia.org/wiki/Just-in-time_compilation",
"onnx": "https://onnx.ai/onnx/",
Expand All @@ -65,7 +70,7 @@
"numpy": "https://numpy.org/",
"numba": "https://numba.pydata.org/",
"onnx-array-api": (
"http://www.xavierdupre.fr/app/" "onnx-array-api/helpsphinx/index.html"
"http://www.xavierdupre.fr/app/onnx-array-api/helpsphinx/index.html"
),
"pyinstrument": "https://github.com/joerick/pyinstrument",
"python": "https://www.python.org/",
Expand Down
1 change: 1 addition & 0 deletions _doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ well as to execute it.
tutorial/index
api/index
auto_examples/index
../CHANGELOGS

Sources available on
`github/onnx-array-api <https://github.com/sdpython/onnx-array-api>`_,
Expand Down
19 changes: 6 additions & 13 deletions _unittests/ut_npx/test_sklearn_array_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import unittest
import numpy as np
from packaging.version import Version
from onnx.defs import onnx_opset_version
from sklearn import config_context
from sklearn import config_context, __version__ as sklearn_version
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
Expand All @@ -10,23 +11,15 @@
DEFAULT_OPSET = onnx_opset_version()


def take(self, X, indices, *, axis):
# Overwriting method take as it is using iterators.
# When array_api supports `take` we can use this directly
# https://github.com/data-apis/array-api/issues/177
# The call is delegated to `take` of the wrapped namespace and the result
# is converted back with `asarray` of the same namespace.
X_np = self._namespace.take(X, indices, axis=axis)
return self._namespace.asarray(X_np)


class TestSklearnArrayAPI(ExtTestCase):
@unittest.skipIf(
Version(sklearn_version) <= Version("1.2.2"),
reason="reshape ArrayAPI not followed",
)
def test_sklearn_array_api_linear_discriminant(self):
from sklearn.utils._array_api import _ArrayAPIWrapper

_ArrayAPIWrapper.take = take
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
y = np.array([1, 1, 1, 2, 2, 2])
ana = LinearDiscriminantAnalysis()
ana = LinearDiscriminantAnalysis()
ana.fit(X, y)
expected = ana.predict(X)

Expand Down
39 changes: 39 additions & 0 deletions _unittests/ut_ort/test_sklearn_array_api_ort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import unittest
import numpy as np
from packaging.version import Version
from onnx.defs import onnx_opset_version
from sklearn import config_context, __version__ as sklearn_version
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.ort.ort_tensors import EagerOrtTensor, OrtTensor


# Opset version reported by the installed onnx package.
# NOTE(review): this constant is not referenced anywhere else in this
# test module; confirm whether it is still needed.
DEFAULT_OPSET = onnx_opset_version()


class TestSklearnArrayAPIOrt(ExtTestCase):
"""
Checks that a fitted :epkg:`scikit-learn` estimator predicts on an
``EagerOrtTensor`` when the array API dispatch is enabled.
"""

# Skipped for scikit-learn <= 1.2.2.
@unittest.skipIf(
Version(sklearn_version) <= Version("1.2.2"),
reason="reshape ArrayAPI not followed",
)
def test_sklearn_array_api_linear_discriminant(self):
"""
Fits a ``LinearDiscriminantAnalysis`` on numpy arrays, then compares
its numpy predictions with the predictions obtained from the same
data wrapped into an ``EagerOrtTensor``.
"""
# Reference predictions computed with plain numpy inputs.
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
y = np.array([1, 1, 1, 2, 2, 2])
ana = LinearDiscriminantAnalysis()
ana.fit(X, y)
expected = ana.predict(X)

# Same data wrapped into an onnxruntime eager tensor.
new_x = EagerOrtTensor(OrtTensor.from_array(X))
self.assertEqual(new_x.device_name, "Cpu")
self.assertStartsWith(
"EagerOrtTensor(OrtTensor.from_array(array([[", repr(new_x)
)
# The prediction goes through the array API code path of scikit-learn.
with config_context(array_api_dispatch=True):
got = ana.predict(new_x)
# The result is converted back to numpy before the comparison.
self.assertEqualArray(expected, got.numpy())


if __name__ == "__main__":
# Uncomment the two following lines to get debug logging
# while running the tests:
# import logging
# logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
47 changes: 47 additions & 0 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,53 @@ jobs:
artifactName: 'wheel-linux-wheel-$(python.version)'
targetPath: 'dist'

- job: 'TestLinuxNightly'
  # Runs the linters and the unit tests against the nightly build
  # of scikit-learn.
  pool:
    vmImage: 'ubuntu-latest'
  strategy:
    matrix:
      # The leg name mirrors the interpreter selected by python.version.
      Python311-Linux:
        python.version: '3.11'
    maxParallel: 3

  steps:
  - task: UsePythonVersion@0
    inputs:
      versionSpec: '$(python.version)'
      architecture: 'x64'
  - script: sudo apt-get update
    displayName: 'AptGet Update'
  - script: sudo apt-get install -y pandoc
    displayName: 'Install Pandoc'
  - script: sudo apt-get install -y inkscape
    displayName: 'Install Inkscape'
  - script: sudo apt-get install -y graphviz
    displayName: 'Install Graphviz'
  - script: python -m pip install --upgrade pip setuptools wheel
    displayName: 'Install tools'
  - script: pip install -r requirements.txt
    displayName: 'Install Requirements'
  - script: pip install -r requirements-dev.txt
    displayName: 'Install Requirements dev'
  # The released scikit-learn pulled by the requirements is replaced
  # by the nightly wheel.
  - script: pip uninstall -y scikit-learn
    displayName: 'Uninstall scikit-learn'
  # Full option name: does not rely on pip's option-prefix abbreviation.
  - script: pip install --pre --extra-index-url https://pypi.anaconda.org/scipy-wheels-nightly/simple scikit-learn
    displayName: 'Install scikit-learn nightly'
  - script: pip install onnxmltools --no-deps
    displayName: 'Install onnxmltools'
  - script: |
      ruff .
    displayName: 'Ruff'
  - script: |
      rstcheck -r ./_doc ./onnx_array_api
    displayName: 'rstcheck'
  - script: |
      black --diff .
    displayName: 'Black'
  - script: |
      python -m pytest -v
    displayName: 'Runs Unit Tests'

- job: 'TestLinux'
pool:
vmImage: 'ubuntu-latest'
Expand Down
20 changes: 16 additions & 4 deletions onnx_array_api/npx/npx_array_api.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,34 @@
from typing import Any
from typing import Any, Optional

import numpy as np

from .npx_types import OptParType, ParType, TupleType


class ArrayApiError(RuntimeError):
    """
    Raised when a function is not supported by the :epkg:`Array API`.

    The class adds nothing to :class:`RuntimeError`; it only gives callers
    a dedicated exception type to catch.
    """


class ArrayApi:
"""
List of methods supported by a tensor.
"""

def __array_namespace__(self):
def __array_namespace__(self, api_version: Optional[str] = None):
"""
Returns the module holding all the available functions.
"""
from onnx_array_api.npx import npx_functions
if api_version is None or api_version == "2022.12":
from onnx_array_api.npx import npx_functions

return npx_functions
return npx_functions
raise ValueError(
f"Unable to return an implementation for api_version={api_version!r}."
)

def generic_method(self, method_name, *args: Any, **kwargs: Any) -> Any:
raise NotImplementedError(
Expand Down
7 changes: 7 additions & 0 deletions onnx_array_api/npx/npx_core_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,10 @@ def npxapi_inline(fn):
to call.
"""
return _xapi(fn, inline=True)


def npxapi_no_inline(fn):
"""
Functions decorated with this decorator are not converted into ONNX.

:param fn: function to decorate
:return: *fn* itself, returned unchanged
"""
return fn
33 changes: 30 additions & 3 deletions onnx_array_api/npx/npx_functions.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import Any, Optional, Tuple, Union

import array_api_compat.numpy as np_array_api
import numpy as np
from onnx import FunctionProto, ModelProto, NodeProto, TensorProto
from onnx.helper import np_dtype_to_tensor_dtype
from onnx.numpy_helper import from_array

from .npx_constants import FUNCTION_DOMAIN
from .npx_core_api import cst, make_tuple, npxapi_inline, var
from .npx_core_api import cst, make_tuple, npxapi_inline, npxapi_no_inline, var
from .npx_tensors import ArrayApi
from .npx_types import (
DType,
ElemType,
OptParType,
ParType,
Expand Down Expand Up @@ -397,6 +399,17 @@ def identity(n: ParType[int], dtype=None) -> TensorType[ElemType.numerics, "T"]:
return v


@npxapi_no_inline
def isdtype(
dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]]
) -> bool:
"""
See :epkg:`ArrayAPI:isdtype`.
This function is not converted into an onnx graph.

The call is forwarded as is to ``isdtype`` from
``array_api_compat.numpy`` (imported as ``np_array_api``).

:param dtype: data type to check
:param kind: a data type, a kind name or a tuple of those,
passed through unchanged
:return: result of ``np_array_api.isdtype(dtype, kind)``
"""
return np_array_api.isdtype(dtype, kind)


@npxapi_inline
def isnan(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.bool_, "T"]:
"See :func:`numpy.isnan`."
Expand Down Expand Up @@ -460,9 +473,23 @@ def relu(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics,

@npxapi_inline
def reshape(
x: TensorType[ElemType.numerics, "T"], shape: TensorType[ElemType.int64, "I"]
x: TensorType[ElemType.numerics, "T"],
shape: TensorType[ElemType.int64, "I", (None,)],
) -> TensorType[ElemType.numerics, "T"]:
"See :func:`numpy.reshape`."
"""
See :func:`numpy.reshape`.

.. warning::

Numpy definition is tricky because onnxruntime does not handle well
shapes with an undefined number of dimensions.
However the array API defines a stricter signature for
`reshape <https://data-apis.org/array-api/2022.12/
API_specification/generated/array_api.reshape.html>`_.
:epkg:`scikit-learn` updated its code to follow the Array API in
`PR 26030 ENH Forces shape to be tuple when using Array API's reshape
<https://github.com/scikit-learn/scikit-learn/pull/26030>`_.
"""
if isinstance(shape, int):
shape = cst(np.array([shape], dtype=np.int64))
shape_reshaped = var(shape, cst(np.array([-1], dtype=np.int64)), op="Reshape")
Expand Down
17 changes: 17 additions & 0 deletions onnx_array_api/npx/npx_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,23 @@ def to_onnx(
node_inputs.append(input_name)
continue

if isinstance(i, tuple) and all(map(lambda x: isinstance(x, int), i)):
ai = np.array(list(i), dtype=np.int64)
c = Cst(ai)
input_name = self._unique(var._prefix)
self._id_vars[id(i), index] = input_name
self._id_vars[id(c), index] = input_name
self.make_node(
"Constant",
[],
[input_name],
value=from_array(ai),
opset=self.target_opsets[""],
)
self.onnx_names_[input_name] = c
node_inputs.append(input_name)
continue

raise NotImplementedError(
f"Unexpected type {type(i)} for node={domop}."
)
Expand Down
32 changes: 23 additions & 9 deletions onnx_array_api/npx/npx_jit_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,18 @@ def jit_call(self, *values, **kwargs):
try:
res = fct.run(*values)
except Exception as e:
from ..plotting.text_plot import onnx_simple_text_plot

text = onnx_simple_text_plot(self.onxs[key])
raise RuntimeError(
f"Unable to run function for key={key!r}, "
f"types={[type(x) for x in values]}, "
f"dtypes={[x.dtype for x in values]}, "
f"shapes={[x.shape for x in values]}, "
f"kwargs={kwargs}, "
f"self.input_to_kwargs_={self.input_to_kwargs_}, "
f"onnx={self.onxs[key]}."
f"f={self.f} from module {self.f.__module__!r} "
f"onnx=\n---\n{text}\n---\n{self.onxs[key]}"
) from e
self.info("-", "jit_call")
return res
Expand Down Expand Up @@ -492,14 +498,22 @@ def _preprocess_constants(self, *args):
elif isinstance(n, Cst):
new_args.append(self.tensor_class(n.inputs[0]))
modified = True
# elif isinstance(n, tuple):
# if any(map(lambda t: isinstance(t, Var), n)):
# raise TypeError(
# f"Unexpected types in tuple "
# f"({[type(t) for t in n]}) for input {i}, "
# f"function {self.f} from module {self.f.__module__!r}."
# )
# new_args.append(n)
elif isinstance(n, tuple):
if all(map(lambda x: isinstance(x, int), n)):
new_args.append(
self.tensor_class(np.array(list(n), dtype=np.int64))
)
elif any(map(lambda t: isinstance(t, Var), n)):
raise TypeError(
f"Unexpected types in tuple "
f"({[type(t) for t in n]}) for input {i}, "
f"function {self.f} from module {self.f.__module__!r}."
)
else:
raise TypeError(
f"Unsupported tuple {n!r} for input {i}, "
f"function {self.f} from module {self.f.__module__!r}."
)
elif isinstance(n, np.ndarray):
new_args.append(self.tensor_class(n))
modified = True
Expand Down
6 changes: 0 additions & 6 deletions onnx_array_api/npx/npx_numpy_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,6 @@ def get_ir_version(cls, ir_version):
"""
return ir_version

def const_cast(self, to: Any = None) -> "EagerTensor":
"""
Casts a constant without any ONNX conversion.

:param to: type forwarded to ``astype`` of the wrapped tensor
:return: new instance of the same class wrapping the cast tensor
"""
return self.__class__(self._tensor.astype(to))

# The class should support whatever Var supports.
# This part is not yet complete.

Expand Down
Loading