Skip to content

Implements ArrayAPI #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Jun 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ Change Logs
0.2.0
+++++

* :pr:`3`: fixes Array API with onnxruntime
* :pr:`17`: implements ArrayAPI
* :pr:`3`: fixes Array API with onnxruntime and scikit-learn
7 changes: 7 additions & 0 deletions _doc/api/array_api.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
onnx_array_api.array_api
========================

.. toctree::

array_api_onnx_numpy
array_api_onnx_ort
5 changes: 5 additions & 0 deletions _doc/api/array_api_numpy.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
onnx_array_api.array_api.onnx_numpy
=============================================

.. automodule:: onnx_array_api.array_api.onnx_numpy
:members:
5 changes: 5 additions & 0 deletions _doc/api/array_api_ort.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
onnx_array_api.array_api.onnx_ort
=================================

.. automodule:: onnx_array_api.array_api.onnx_ort
:members:
1 change: 1 addition & 0 deletions _doc/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ API
.. toctree::
:maxdepth: 1

array_api
npx_functions
npx_var
npx_jit
Expand Down
29 changes: 27 additions & 2 deletions _doc/api/npx_annot.rst
Original file line number Diff line number Diff line change
@@ -1,29 +1,54 @@
=============
npx.npx_types
=============

DType
=====

.. autoclass:: onnx_array_api.npx.npx_types.DType
:members:

Annotations
+++++++++++
===========

ElemType
++++++++

.. autoclass:: onnx_array_api.npx.npx_types.ElemType
:members:

ParType
+++++++

.. autoclass:: onnx_array_api.npx.npx_types.ParType
:members:

OptParType
++++++++++

.. autoclass:: onnx_array_api.npx.npx_types.OptParType
:members:

TensorType
++++++++++

.. autoclass:: onnx_array_api.npx.npx_types.TensorType
:members:

SequenceType
++++++++++++

.. autoclass:: onnx_array_api.npx.npx_types.SequenceType
:members:

TupleType
+++++++++

.. autoclass:: onnx_array_api.npx.npx_types.TupleType
:members:

Shortcuts
+++++++++
=========

.. autoclass:: onnx_array_api.npx.npx_types.Bool

Expand Down
2 changes: 2 additions & 0 deletions _unittests/test_array_api.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_zeros
20 changes: 20 additions & 0 deletions _unittests/ut_array_api/test_onnx_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import unittest
import numpy as np
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.array_api import onnx_numpy as xp
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor


class TestOnnxNumpy(ExtTestCase):
def test_abs(self):
c = EagerNumpyTensor(np.array([4, 5], dtype=np.int64))
mat = xp.zeros(c, dtype=xp.int64)
matnp = mat.numpy()
self.assertEqual(matnp.shape, (4, 5))
self.assertNotEmpty(matnp[0, 0])
a = xp.absolute(mat)
self.assertEqualArray(np.absolute(mat.numpy()), a.numpy())


if __name__ == "__main__":
unittest.main(verbosity=2)
93 changes: 88 additions & 5 deletions _unittests/ut_npx/test_npx.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
npxapi_inline,
)
from onnx_array_api.npx.npx_functions import absolute as absolute_inline
from onnx_array_api.npx.npx_functions import all as all_inline
from onnx_array_api.npx.npx_functions import arange as arange_inline
from onnx_array_api.npx.npx_functions import arccos as arccos_inline
from onnx_array_api.npx.npx_functions import arccosh as arccosh_inline
Expand All @@ -50,6 +51,7 @@
from onnx_array_api.npx.npx_functions import det as det_inline
from onnx_array_api.npx.npx_functions import dot as dot_inline
from onnx_array_api.npx.npx_functions import einsum as einsum_inline
from onnx_array_api.npx.npx_functions import equal as equal_inline
from onnx_array_api.npx.npx_functions import erf as erf_inline
from onnx_array_api.npx.npx_functions import exp as exp_inline
from onnx_array_api.npx.npx_functions import expand_dims as expand_dims_inline
Expand Down Expand Up @@ -95,6 +97,7 @@
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
from onnx_array_api.npx.npx_types import (
Bool,
DType,
Float32,
Float64,
Int64,
Expand Down Expand Up @@ -127,18 +130,25 @@ def test_tensor(self):
self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
self.assertEmpty(dt.shape)
self.assertEqual(dt.type_name(), "TensorType['float32']")

dt = TensorType["float32"]
self.assertEqual(len(dt.dtypes), 1)
self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
self.assertEqual(dt.type_name(), "TensorType['float32']")

dt = TensorType[np.float32]
self.assertEqual(len(dt.dtypes), 1)
self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
self.assertEqual(dt.type_name(), "TensorType['float32']")
self.assertEmpty(dt.shape)

dt = TensorType[np.str_]
self.assertEqual(len(dt.dtypes), 1)
self.assertEqual(dt.dtypes[0].dtype, ElemType.str_)
self.assertEqual(dt.type_name(), "TensorType[strings]")
self.assertEmpty(dt.shape)

self.assertRaise(lambda: TensorType[None], TypeError)
self.assertRaise(lambda: TensorType[np.str_], TypeError)
self.assertRaise(lambda: TensorType[{np.float32, np.str_}], TypeError)

def test_superset(self):
Expand Down Expand Up @@ -1155,6 +1165,16 @@ def test_astype(self):
got = ref.run(None, {"A": x})
self.assertEqualArray(z, got[0])

def test_astype_dtype(self):
f = absolute_inline(copy_inline(Input("A")).astype(DType(7)))
self.assertIsInstance(f, Var)
onx = f.to_onnx(constraints={"A": Float64[None]})
x = np.array([[-5.4, 6.6]], dtype=np.float64)
z = np.abs(x.astype(np.int64))
ref = ReferenceEvaluator(onx)
got = ref.run(None, {"A": x})
self.assertEqualArray(z, got[0])

def test_astype_int(self):
f = absolute_inline(copy_inline(Input("A")).astype(1))
self.assertIsInstance(f, Var)
Expand Down Expand Up @@ -1413,6 +1433,9 @@ def test_einsum(self):
lambda x, y: np.einsum(equation, x, y),
)

def test_equal(self):
self.common_test_inline_bin(equal_inline, np.equal)

@unittest.skipIf(scipy is None, reason="scipy is not installed.")
def test_erf(self):
self.common_test_inline(erf_inline, scipy.special.erf)
Expand Down Expand Up @@ -1460,7 +1483,17 @@ def test_hstack(self):
def test_identity(self):
f = identity_inline(2, dtype=np.float64)
onx = f.to_onnx(constraints={(0, False): Float64[None]})
z = np.identity(2)
self.assertIn('name: "dtype"', str(onx))
z = np.identity(2).astype(np.float64)
ref = ReferenceEvaluator(onx)
got = ref.run(None, {})
self.assertEqualArray(z, got[0])

def test_identity_uint8(self):
f = identity_inline(2, dtype=np.uint8)
onx = f.to_onnx(constraints={(0, False): Float64[None]})
self.assertIn('name: "dtype"', str(onx))
z = np.identity(2).astype(np.uint8)
ref = ReferenceEvaluator(onx)
got = ref.run(None, {})
self.assertEqualArray(z, got[0])
Expand Down Expand Up @@ -2318,7 +2351,7 @@ def compute_labels(X, centers):
self.assertEqual(f.n_versions, 1)
self.assertEqual(len(f.available_versions), 1)
self.assertEqual(f.available_versions, [((np.float64, 2), (np.float64, 2))])
key = ((np.dtype("float64"), 2), (np.dtype("float64"), 2))
key = ((DType(TensorProto.DOUBLE), 2), (DType(TensorProto.DOUBLE), 2))
onx = f.get_onnx(key)
self.assertIsInstance(onx, ModelProto)
self.assertRaise(lambda: f.get_onnx(2), ValueError)
Expand Down Expand Up @@ -2379,7 +2412,12 @@ def compute_labels(X, centers, use_sqrt=False):
self.assertEqualArray(got[1], dist)
self.assertEqual(f.n_versions, 1)
self.assertEqual(len(f.available_versions), 1)
key = ((np.dtype("float64"), 2), (np.dtype("float64"), 2), "use_sqrt", True)
key = (
(DType(TensorProto.DOUBLE), 2),
(DType(TensorProto.DOUBLE), 2),
"use_sqrt",
True,
)
self.assertEqual(f.available_versions, [key])
onx = f.get_onnx(key)
self.assertIsInstance(onx, ModelProto)
Expand Down Expand Up @@ -2452,7 +2490,52 @@ def test_take(self):
got = ref.run(None, {"A": data, "B": indices})
self.assertEqualArray(y, got[0])

def test_numpy_all(self):
data = np.array([[1, 0], [1, 1]]).astype(np.bool_)
y = np.all(data, axis=1)

f = all_inline(Input("A"), axis=1)
self.assertIsInstance(f, Var)
onx = f.to_onnx(constraints={"A": Bool[None]})
ref = ReferenceEvaluator(onx)
got = ref.run(None, {"A": data})
self.assertEqualArray(y, got[0])

def test_numpy_all_empty(self):
data = np.zeros((0,), dtype=np.bool_)
y = np.all(data)

f = all_inline(Input("A"))
self.assertIsInstance(f, Var)
onx = f.to_onnx(constraints={"A": Bool[None]})
ref = ReferenceEvaluator(onx)
got = ref.run(None, {"A": data})
self.assertEqualArray(y, got[0])

@unittest.skipIf(True, reason="ReduceMin does not support shape[axis] == 0")
def test_numpy_all_empty_axis_0(self):
data = np.zeros((0, 1), dtype=np.bool_)
y = np.all(data, axis=0)

f = all_inline(Input("A"), axis=0)
self.assertIsInstance(f, Var)
onx = f.to_onnx(constraints={"A": Bool[None]})
ref = ReferenceEvaluator(onx)
got = ref.run(None, {"A": data})
self.assertEqualArray(y, got[0])

def test_numpy_all_empty_axis_1(self):
data = np.zeros((0, 1), dtype=np.bool_)
y = np.all(data, axis=1)

f = all_inline(Input("A"), axis=1)
self.assertIsInstance(f, Var)
onx = f.to_onnx(constraints={"A": Bool[None]})
ref = ReferenceEvaluator(onx)
got = ref.run(None, {"A": data})
self.assertEqualArray(y, got[0])


if __name__ == "__main__":
TestNpx().test_take()
# TestNpx().test_numpy_all_empty_axis_0()
unittest.main(verbosity=2)
5 changes: 4 additions & 1 deletion _unittests/ut_npx/test_sklearn_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from onnx.defs import onnx_opset_version
from sklearn import config_context, __version__ as sklearn_version
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor


Expand All @@ -16,6 +16,7 @@ class TestSklearnArrayAPI(ExtTestCase):
Version(sklearn_version) <= Version("1.2.2"),
reason="reshape ArrayAPI not followed",
)
@ignore_warnings(DeprecationWarning)
def test_sklearn_array_api_linear_discriminant(self):
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
y = np.array([1, 1, 1, 2, 2, 2])
Expand All @@ -26,6 +27,8 @@ def test_sklearn_array_api_linear_discriminant(self):
new_x = EagerNumpyTensor(X)
self.assertStartsWith("EagerNumpyTensor(array([[", repr(new_x))
with config_context(array_api_dispatch=True):
# It fails if scikit-learn <= 1.2.2 because the ArrayAPI
# is not strictly applied.
got = ana.predict(new_x)
self.assertEqualArray(expected, got.numpy())

Expand Down
Loading