Skip to content

Extends export onnx to code to support inner API #47

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.1.3
+++++

* :pr:`47`: extends export onnx to code to support inner API
* :pr:`46`: adds an export to convert an onnx graph into light API code
* :pr:`45`: fixes light API for operators with two outputs

Expand Down
14 changes: 13 additions & 1 deletion _doc/api/light_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,16 @@ Vars
Classes for the Translater
==========================

BaseEmitter
+++++++++++

.. autoclass:: onnx_array_api.light_api.emitter.BaseEmitter
:members:

Emitter
+++++++

.. autoclass:: onnx_array_api.light_api.translate.Emitter
.. autoclass:: onnx_array_api.light_api.emitter.Emitter
:members:

EventType
Expand All @@ -60,6 +66,12 @@ EventType
.. autoclass:: onnx_array_api.light_api.translate.EventType
:members:

InnerEmitter
++++++++++++

.. autoclass:: onnx_array_api.light_api.inner_emitter.InnerEmitter
:members:

Translater
++++++++++

Expand Down
Binary file not shown.
290 changes: 290 additions & 0 deletions _unittests/ut_light_api/test_backend_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
import unittest
from typing import Any, Dict, List, Optional
from difflib import unified_diff
import packaging.version as pv
import numpy
from numpy.testing import assert_allclose
import onnx.backend.base
import onnx.backend.test
import onnx.shape_inference
import onnx.version_converter
from onnx import ModelProto, TensorProto, __version__ as onnx_version
from onnx.helper import (
make_function,
make_graph,
make_model,
make_node,
make_opsetid,
make_tensor_value_info,
)
from onnx.numpy_helper import from_array, to_array
from onnx.backend.base import Device, DeviceType
from onnx_array_api.reference import ExtendedReferenceEvaluator
from onnx_array_api.light_api import translate
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot


class ReferenceImplementationError(RuntimeError):
"Fails, export cannot be compared."
pass


class ExportWrapper:
apis = ["onnx", "light"]

def __init__(self, model):
self.model = model
self.expected_sess = ExtendedReferenceEvaluator(self.model)

@property
def input_names(self):
return self.expected_sess.input_names

@property
def input_types(self):
return self.expected_sess.input_types

@property
def output_names(self):
return self.expected_sess.output_names

@property
def output_types(self):
return self.expected_sess.output_types

def run(
self, names: Optional[List[str]], feeds: Optional[Dict[str, Any]] = None
) -> List[Any]:
try:
expected = self.expected_sess.run(names, feeds)
except (RuntimeError, AssertionError, TypeError, KeyError) as e:
raise ReferenceImplementationError(
f"ReferenceImplementation fails with {onnx_simple_text_plot(self.model)}"
f"\n--RAW--\n{self.model}"
) from e

for api in self.apis:
try:
code = translate(self.model, api=api)
except NotImplementedError:
continue
except ValueError as e:
raise AssertionError(
f"Unable to translate model for api {api!r}, "
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
f"\n--EXPECTED--\n{expected}"
) from e
try:
code_compiled = compile(code, "<string>", mode="exec")
except Exception as e:
new_code = "\n".join(
[f"{i+1:04} {line}" for i, line in enumerate(code.split("\n"))]
)
raise AssertionError(f"ERROR {e}\n{new_code}")

locs = {
"np": numpy,
"to_array": to_array,
"from_array": from_array,
"TensorProto": TensorProto,
"make_function": make_function,
"make_opsetid": make_opsetid,
"make_model": make_model,
"make_graph": make_graph,
"make_node": make_node,
"make_tensor_value_info": make_tensor_value_info,
}
globs = locs.copy()
try:
exec(code_compiled, globs, locs)
except (TypeError, NameError, ValueError) as e:
new_code = "\n".join(
[f"{i+1:04} {line}" for i, line in enumerate(code.split("\n"))]
)
raise AssertionError(
f"Unable to executed code for api {api!r}\n{new_code}"
) from e
export_model = locs["model"]
ref = ExtendedReferenceEvaluator(export_model)
try:
got = ref.run(names, feeds)
except (TypeError, AttributeError) as e:
diff = "\n".join(
unified_diff(
str(self.model).split("\n"),
str(export_model).split("\n"),
fromfile="before",
tofile="after",
)
)
raise AssertionError(
f"Unable to run the exported model for api {api!r}, "
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
f"\n--CODE--\n{code}"
f"\n--FEEDS--\n{feeds}"
f"\n--EXPECTED--\n{expected}"
f"\n--DIFF--\n{diff}"
) from e
if len(expected) != len(got):
raise AssertionError(
f"Unexpected number of outputs for api {api!r}, "
f"{len(expected)} != {len(got)}."
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
)
for a, b in zip(expected, got):
if not isinstance(a, numpy.ndarray):
continue
if a.shape != b.shape or a.dtype != b.dtype:
raise AssertionError(
f"Shape or type discrepancies for api {api!r}."
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
)
if a.dtype in (numpy.str_, object, numpy.object_) or isinstance(
a.dtype, getattr(getattr(numpy, "dtypes", None), "StrDType", type)
):
if a.tolist() != b.tolist():
raise AssertionError(
f"Text discrepancies for api {api!r} with a.dtype={a.dtype} "
f"and b.dtype={b.dtype}"
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
)
continue
try:
assert_allclose(a, b, atol=1e-3)
except (AssertionError, TypeError) as e:
raise AssertionError(
f"Discrepancies for api {api!r} with a.dtype={a.dtype} "
f"and b.dtype={b.dtype} (type-dtype={type(a.dtype)})"
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
) from e

return expected


class ExportBackendRep(onnx.backend.base.BackendRep):
def __init__(self, session):
self._session = session

def run(self, inputs, **kwargs):
if isinstance(inputs, numpy.ndarray):
inputs = [inputs]
if isinstance(inputs, list):
if len(inputs) == len(self._session.input_names):
feeds = dict(zip(self._session.input_names, inputs))
else:
feeds = {}
pos_inputs = 0
for inp, tshape in zip(
self._session.input_names, self._session.input_types
):
shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim)
if shape == inputs[pos_inputs].shape:
feeds[inp] = inputs[pos_inputs]
pos_inputs += 1
if pos_inputs >= len(inputs):
break
elif isinstance(inputs, dict):
feeds = inputs
else:
raise TypeError(f"Unexpected input type {type(inputs)!r}.")
outs = self._session.run(None, feeds)
return outs


class ExportBackend(onnx.backend.base.Backend):
@classmethod
def is_opset_supported(cls, model): # pylint: disable=unused-argument
return True, ""

@classmethod
def supports_device(cls, device: str) -> bool:
d = Device(device)
return d.type == DeviceType.CPU # type: ignore[no-any-return]

@classmethod
def create_inference_session(cls, model):
return ExportWrapper(model)

@classmethod
def prepare(
cls, model: Any, device: str = "CPU", **kwargs: Any
) -> ExportBackendRep:
if isinstance(model, ExportWrapper):
return ExportBackendRep(model)
if isinstance(model, (str, bytes, ModelProto)):
inf = cls.create_inference_session(model)
return cls.prepare(inf, device, **kwargs)
raise TypeError(f"Unexpected type {type(model)} for model.")

@classmethod
def run_model(cls, model, inputs, device=None, **kwargs):
rep = cls.prepare(model, device, **kwargs)
return rep.run(inputs, **kwargs)

@classmethod
def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
raise NotImplementedError("Unable to run the model node by node.")


backend_test = onnx.backend.test.BackendTest(ExportBackend, __name__)

# The following tests are too slow with the reference implementation (Conv).
backend_test.exclude(
"(FLOAT8|BFLOAT16|_opt_|_3d_|_momentum_|_4d_"
"|test_adagrad"
"|test_adam"
"|test_ai_onnx_ml_"
"|test_cast_FLOAT16"
"|test_cast_FLOAT_to_STRING"
"|test_castlike_FLOAT16"
"|test_castlike_FLOAT_to_STRING"
"|test_bernoulli"
"|test_bvlc_alexnet"
"|test_conv" # too long
"|test_gradient_"
"|test_densenet121"
"|test_inception_v1"
"|test_inception_v2"
"|test_loop11_"
"|test_loop16_seq_none"
"|test_MaxPool2d"
"|test_quantizelinear_e"
"|test_resnet50"
"|test_sequence_model"
"|test_scan_sum"
"|test_scatter_with_axis"
"|test_scatter_without_axis"
"|test_shufflenet"
"|test_squeezenet"
"|test_vgg19"
"|test_zfnet512"
")"
)

if pv.Version(onnx_version) < pv.Version("1.16.0"):
backend_test.exclude("(test_strnorm|test_range_)")

# The following tests cannot pass because they consists in generating random number.
backend_test.exclude("(test_bernoulli)")

# import all test cases at global scope to make them visible to python.unittest
globals().update(backend_test.test_cases)

if __name__ == "__main__":
res = unittest.main(verbosity=2, exit=False)
tests_run = res.result.testsRun
errors = len(res.result.errors)
skipped = len(res.result.skipped)
unexpected_successes = len(res.result.unexpectedSuccesses)
expected_failures = len(res.result.expectedFailures)
print("---------------------------------")
print(
f"tests_run={tests_run} errors={errors} skipped={skipped} "
f"unexpected_successes={unexpected_successes} "
f"expected_failures={expected_failures}"
)
2 changes: 2 additions & 0 deletions _unittests/ut_light_api/test_light_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
import sys
from typing import Callable, Optional
import numpy as np
from onnx import ModelProto
Expand Down Expand Up @@ -144,6 +145,7 @@ def list_ops_missing(self, n_inputs):
f"{new_missing}\n{text}"
)

@unittest.skipIf(sys.platform == "win32", reason="unstable test on Windows")
def test_list_ops_missing(self):
self.list_ops_missing(1)
self.list_ops_missing(2)
Expand Down
8 changes: 8 additions & 0 deletions _unittests/ut_light_api/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@
from onnx.reference import ReferenceEvaluator
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.light_api import start, translate
from onnx_array_api.light_api.emitter import EventType

OPSET_API = min(19, onnx_opset_version() - 1)


class TestTranslate(ExtTestCase):
def test_event_type(self):
self.assertEqual(
EventType.to_str(EventType.INITIALIZER), "EventType.INITIALIZER"
)

def test_exp(self):
onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
self.assertIsInstance(onx, ModelProto)
Expand Down Expand Up @@ -73,6 +79,8 @@ def test_transpose(self):
"""
(
start(opset=19)
.cst(np.array([-1, 1], dtype=np.int64))
.rename('r')
.vin('X', elem_type=TensorProto.FLOAT)
.bring('X', 'r')
.Reshape()
Expand Down
Loading