Skip to content

Supports subgraph in the light API #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.1.3
+++++

* :pr:`48`: support for subgraph in light API
* :pr:`47`: extends export onnx to code to support inner API
* :pr:`46`: adds an export to convert an onnx graph into light API code
* :pr:`45`: fixes light API for operators with two outputs
Expand Down
6 changes: 6 additions & 0 deletions _doc/api/light_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ translate
Classes for the Light API
=========================

ProtoType
+++++++++

.. autoclass:: onnx_array_api.light_api.model.ProtoType
:members:

OnnxGraph
+++++++++

Expand Down
9 changes: 2 additions & 7 deletions _unittests/ut_array_api/test_onnx_numpy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import sys
import unittest
import numpy as np
from onnx import TensorProto
Expand Down Expand Up @@ -91,19 +90,15 @@ def test_arange_int00a(self):
mat = xp.arange(a, b)
matnp = mat.numpy()
self.assertEqual(matnp.shape, (0,))
expected = np.arange(0, 0)
if sys.platform == "win32":
expected = expected.astype(np.int64)
expected = np.arange(0, 0).astype(np.int64)
self.assertEqualArray(matnp, expected)

@ignore_warnings(DeprecationWarning)
def test_arange_int00(self):
mat = xp.arange(0, 0)
matnp = mat.numpy()
self.assertEqual(matnp.shape, (0,))
expected = np.arange(0, 0)
if sys.platform == "win32":
expected = expected.astype(np.int64)
expected = np.arange(0, 0).astype(np.int64)
self.assertEqualArray(matnp, expected)

def test_ones_like_uint16(self):
Expand Down
42 changes: 36 additions & 6 deletions _unittests/ut_light_api/test_light_api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import unittest
import sys
from typing import Callable, Optional
import numpy as np
from onnx import ModelProto
from onnx import GraphProto, ModelProto
from onnx.defs import (
get_all_schemas_with_history,
onnx_opset_version,
Expand All @@ -11,8 +10,8 @@
SchemaError,
)
from onnx.reference import ReferenceEvaluator
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.light_api import start, OnnxGraph, Var
from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows
from onnx_array_api.light_api import start, OnnxGraph, Var, g
from onnx_array_api.light_api._op_var import OpsVar
from onnx_array_api.light_api._op_vars import OpsVars

Expand Down Expand Up @@ -145,7 +144,7 @@ def list_ops_missing(self, n_inputs):
f"{new_missing}\n{text}"
)

@unittest.skipIf(sys.platform == "win32", reason="unstable test on Windows")
@skipif_ci_windows("Unstable on Windows.")
def test_list_ops_missing(self):
self.list_ops_missing(1)
self.list_ops_missing(2)
Expand Down Expand Up @@ -442,7 +441,38 @@ def test_topk_reverse(self):
self.assertEqualArray(np.array([[0, 1], [6, 7]], dtype=np.float32), got[0])
self.assertEqualArray(np.array([[0, 1], [3, 2]], dtype=np.int64), got[1])

def test_if(self):
gg = g().cst(np.array([0], dtype=np.int64)).rename("Z").vout()
onx = gg.to_onnx()
self.assertIsInstance(onx, GraphProto)
self.assertEqual(len(onx.input), 0)
self.assertEqual(len(onx.output), 1)
self.assertEqual([o.name for o in onx.output], ["Z"])
onx = (
start(opset=19)
.vin("X", np.float32)
.ReduceSum()
.rename("Xs")
.cst(np.array([0], dtype=np.float32))
.left_bring("Xs")
.Greater()
.If(
then_branch=g().cst(np.array([1], dtype=np.int64)).rename("Z").vout(),
else_branch=g().cst(np.array([0], dtype=np.int64)).rename("Z").vout(),
)
.rename("W")
.vout()
.to_onnx()
)
self.assertIsInstance(onx, ModelProto)
ref = ReferenceEvaluator(onx)
x = np.array([0, 1, 2, 3, 9, 8, 7, 6], dtype=np.float32)
got = ref.run(None, {"X": x})
self.assertEqualArray(np.array([1], dtype=np.int64), got[0])
got = ref.run(None, {"X": -x})
self.assertEqualArray(np.array([0], dtype=np.int64), got[0])


if __name__ == "__main__":
# TestLightApi().test_topk()
TestLightApi().test_if()
unittest.main(verbosity=2)
56 changes: 54 additions & 2 deletions _unittests/ut_light_api/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from onnx.defs import onnx_opset_version
from onnx.reference import ReferenceEvaluator
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.light_api import start, translate
from onnx_array_api.light_api import start, translate, g
from onnx_array_api.light_api.emitter import EventType

OPSET_API = min(19, onnx_opset_version() - 1)
Expand Down Expand Up @@ -133,7 +133,59 @@ def test_topk_reverse(self):
).strip("\n")
self.assertEqual(expected, code)

def test_export_if(self):
onx = (
start(opset=19)
.vin("X", np.float32)
.ReduceSum()
.rename("Xs")
.cst(np.array([0], dtype=np.float32))
.left_bring("Xs")
.Greater()
.If(
then_branch=g().cst(np.array([1], dtype=np.int64)).rename("Z").vout(),
else_branch=g().cst(np.array([0], dtype=np.int64)).rename("Z").vout(),
)
.rename("W")
.vout()
.to_onnx()
)

self.assertIsInstance(onx, ModelProto)
ref = ReferenceEvaluator(onx)
x = np.array([[0, 1, 2, 3], [9, 8, 7, 6]], dtype=np.float32)
k = np.array([2], dtype=np.int64)
got = ref.run(None, {"X": x, "K": k})
self.assertEqualArray(np.array([1], dtype=np.int64), got[0])

code = translate(onx)
selse = "g().cst(np.array([0], dtype=np.int64)).rename('Z').bring('Z').vout(elem_type=TensorProto.FLOAT)"
sthen = "g().cst(np.array([1], dtype=np.int64)).rename('Z').bring('Z').vout(elem_type=TensorProto.FLOAT)"
expected = dedent(
f"""
(
start(opset=19)
.cst(np.array([0.0], dtype=np.float32))
.rename('r')
.vin('X', elem_type=TensorProto.FLOAT)
.bring('X')
.ReduceSum(keepdims=1, noop_with_empty_axes=0)
.rename('Xs')
.bring('Xs', 'r')
.Greater()
.rename('r1_0')
.bring('r1_0')
.If(else_branch={selse}, then_branch={sthen})
.rename('W')
.bring('W')
.vout(elem_type=TensorProto.FLOAT)
.to_onnx()
)"""
).strip("\n")
self.maxDiff = None
self.assertEqual(expected, code)


if __name__ == "__main__":
# TestLightApi().test_topk()
TestTranslate().test_export_if()
unittest.main(verbosity=2)
8 changes: 4 additions & 4 deletions _unittests/ut_light_api/test_translate_classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_check_code(self):
outputs.append(make_tensor_value_info("Y", TensorProto.FLOAT, shape=[]))
graph = make_graph(
nodes,
"noname",
"onename",
inputs,
outputs,
initializers,
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_exp(self):
outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
graph = make_graph(
nodes,
'noname',
'light_api',
inputs,
outputs,
initializers,
Expand Down Expand Up @@ -161,7 +161,7 @@ def test_transpose(self):
outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
graph = make_graph(
nodes,
'noname',
'light_api',
inputs,
outputs,
initializers,
Expand Down Expand Up @@ -223,7 +223,7 @@ def test_topk_reverse(self):
outputs.append(make_tensor_value_info('Indices', TensorProto.FLOAT, shape=[]))
graph = make_graph(
nodes,
'noname',
'light_api',
inputs,
outputs,
initializers,
Expand Down
4 changes: 3 additions & 1 deletion _unittests/ut_npx/test_npx.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from onnx.reference import ReferenceEvaluator
from onnx.shape_inference import infer_shapes

from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings
from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings, skipif_ci_windows
from onnx_array_api.reference import ExtendedReferenceEvaluator
from onnx_array_api.npx import ElemType, eager_onnx, jit_onnx
from onnx_array_api.npx.npx_core_api import (
Expand Down Expand Up @@ -1355,6 +1355,7 @@ def test_clip_none(self):
got = ref.run(None, {"A": x})
self.assertEqualArray(y, got[0])

@skipif_ci_windows("Unstable on Windows.")
def test_arange_inline(self):
# arange(5)
f = arange_inline(Input("A"))
Expand Down Expand Up @@ -1391,6 +1392,7 @@ def test_arange_inline(self):
got = ref.run(None, {"A": x1, "B": x2, "C": x3})
self.assertEqualArray(y, got[0])

@skipif_ci_windows("Unstable on Windows.")
def test_arange_inline_dtype(self):
# arange(1, 5, 2), dtype
f = arange_inline(Input("A"), Input("B"), Input("C"), dtype=np.float64)
Expand Down
11 changes: 8 additions & 3 deletions _unittests/ut_ort/test_ort_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from onnx.defs import onnx_opset_version
from onnx.reference import ReferenceEvaluator
from onnxruntime import InferenceSession
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows
from onnx_array_api.npx import eager_onnx, jit_onnx
from onnx_array_api.npx.npx_functions import absolute as absolute_inline
from onnx_array_api.npx.npx_functions import cdist as cdist_inline
Expand All @@ -20,6 +20,7 @@


class TestOrtTensor(ExtTestCase):
@skipif_ci_windows("Unstable on Windows")
def test_eager_numpy_type_ort(self):
def impl(A):
self.assertIsInstance(A, EagerOrtTensor)
Expand All @@ -45,6 +46,7 @@ def impl(A):
self.assertEqualArray(z, res.numpy())
self.assertEqual(res.numpy().dtype, np.float64)

@skipif_ci_windows("Unstable on Windows")
def test_eager_numpy_type_ort_op(self):
def impl(A):
self.assertIsInstance(A, EagerOrtTensor)
Expand All @@ -68,6 +70,7 @@ def impl(A):
self.assertEqualArray(z, res.numpy())
self.assertEqual(res.numpy().dtype, np.float64)

@skipif_ci_windows("Unstable on Windows")
def test_eager_ort(self):
def impl(A):
print("A")
Expand Down Expand Up @@ -141,6 +144,7 @@ def impl(A):
self.assertEqual(tuple(res.shape()), z.shape)
self.assertStartsWith("A\nB\nC\n", text)

@skipif_ci_windows("Unstable on Windows")
def test_cdist_com_microsoft(self):
from scipy.spatial.distance import cdist as scipy_cdist

Expand Down Expand Up @@ -193,7 +197,7 @@ def impl(xa, xb):
if len(pieces) > 2:
raise AssertionError(f"Function is not using argument:\n{onx}")

def test_astype(self):
def test_astype_w2(self):
f = absolute_inline(copy_inline(Input("A")).astype(DType(TensorProto.FLOAT)))
onx = f.to_onnx(constraints={"A": Float64[None]})
x = np.array([[-5, 6]], dtype=np.float64)
Expand All @@ -204,7 +208,7 @@ def test_astype(self):
got = ref.run(None, {"A": x})
self.assertEqualArray(z, got[0])

def test_astype0(self):
def test_astype0_w2(self):
f = absolute_inline(copy_inline(Input("A")).astype(DType(TensorProto.FLOAT)))
onx = f.to_onnx(constraints={"A": Float64[None]})
x = np.array(-5, dtype=np.float64)
Expand All @@ -215,6 +219,7 @@ def test_astype0(self):
got = ref.run(None, {"A": x})
self.assertEqualArray(z, got[0])

@skipif_ci_windows("Unstable on Windows")
def test_eager_ort_cast(self):
def impl(A):
return A.astype(DType("FLOAT"))
Expand Down
8 changes: 5 additions & 3 deletions _unittests/ut_ort/test_sklearn_array_api_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from onnx.defs import onnx_opset_version
from sklearn import config_context, __version__ as sklearn_version
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows
from onnx_array_api.ort.ort_tensors import EagerOrtTensor, OrtTensor


Expand All @@ -16,7 +16,8 @@ class TestSklearnArrayAPIOrt(ExtTestCase):
Version(sklearn_version) <= Version("1.2.2"),
reason="reshape ArrayAPI not followed",
)
def test_sklearn_array_api_linear_discriminant(self):
@skipif_ci_windows("Unstable on Windows.")
def test_sklearn_array_api_linear_discriminant_ort(self):
X = np.array(
[[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], dtype=np.float64
)
Expand All @@ -38,7 +39,8 @@ def test_sklearn_array_api_linear_discriminant(self):
Version(sklearn_version) <= Version("1.2.2"),
reason="reshape ArrayAPI not followed",
)
def test_sklearn_array_api_linear_discriminant_float32(self):
@skipif_ci_windows("Unstable on Windows.")
def test_sklearn_array_api_linear_discriminant_ort_float32(self):
X = np.array(
[[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], dtype=np.float32
)
Expand Down
5 changes: 2 additions & 3 deletions _unittests/ut_validation/test_docs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import unittest
import sys
import numpy as np
from onnx.reference import ReferenceEvaluator
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows
from onnx_array_api.validation.docs import make_euclidean, make_euclidean_skl2onnx


Expand All @@ -27,7 +26,7 @@ def test_make_euclidean_skl2onnx(self):
got = ref.run(None, {"X": X, "Y": Y})[0]
self.assertEqualArray(expected, got)

@unittest.skipIf(sys.platform == "win32", reason="unstable on Windows")
@skipif_ci_windows("Unstable on Windows.")
def test_make_euclidean_np(self):
from onnx_array_api.npx import jit_onnx

Expand Down
4 changes: 2 additions & 2 deletions _unittests/ut_xrun_doc/test_documentation_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import subprocess
import time
from onnx_array_api import __file__ as onnx_array_api_file
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.ext_test_case import ExtTestCase, is_windows

VERBOSE = 0
ROOT = os.path.realpath(os.path.abspath(os.path.join(onnx_array_api_file, "..", "..")))
Expand All @@ -29,7 +29,7 @@ def run_test(self, fold: str, name: str, verbose=0) -> int:
if len(ppath) == 0:
os.environ["PYTHONPATH"] = ROOT
elif ROOT not in ppath:
sep = ";" if sys.platform == "win32" else ":"
sep = ";" if is_windows() else ":"
os.environ["PYTHONPATH"] = ppath + sep + ROOT
perf = time.perf_counter()
try:
Expand Down
Loading