diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 1c385ca..055a05e 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.1.3
 +++++
 
+* :pr:`47`: extends the export of an onnx graph to code to support the inner API
 * :pr:`46`: adds an export to convert an onnx graph into light API code
 * :pr:`45`: fixes light API for operators with two outputs
 
diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst
index a50f050..28dc70d 100644
--- a/_doc/api/light_api.rst
+++ b/_doc/api/light_api.rst
@@ -48,10 +48,16 @@ Vars
 Classes for the Translater
 ==========================
 
+BaseEmitter
++++++++++++
+
+.. autoclass:: onnx_array_api.light_api.emitter.BaseEmitter
+    :members:
+
 Emitter
 +++++++
 
-.. autoclass:: onnx_array_api.light_api.translate.Emitter
+.. autoclass:: onnx_array_api.light_api.emitter.Emitter
     :members:
 
 EventType
@@ -60,6 +66,12 @@ EventType
 .. autoclass:: onnx_array_api.light_api.translate.EventType
     :members:
 
+InnerEmitter
+++++++++++++
+
+.. autoclass:: onnx_array_api.light_api.inner_emitter.InnerEmitter
+    :members:
+
 Translater
 ++++++++++
 
diff --git a/_unittests/ut_light_api/_data/stft_inlined_batch_1.onnx b/_unittests/ut_light_api/_data/stft_inlined_batch_1.onnx
new file mode 100644
index 0000000..172de97
Binary files /dev/null and b/_unittests/ut_light_api/_data/stft_inlined_batch_1.onnx differ
diff --git a/_unittests/ut_light_api/test_backend_export.py b/_unittests/ut_light_api/test_backend_export.py
new file mode 100644
index 0000000..b0c1cbc
--- /dev/null
+++ b/_unittests/ut_light_api/test_backend_export.py
@@ -0,0 +1,290 @@
+import unittest
+from typing import Any, Dict, List, Optional
+from difflib import unified_diff
+import packaging.version as pv
+import numpy
+from numpy.testing import assert_allclose
+import onnx.backend.base
+import onnx.backend.test
+import onnx.shape_inference
+import onnx.version_converter
+from onnx import ModelProto, TensorProto, __version__ as onnx_version
+from onnx.helper import (
+    make_function,
+    make_graph,
+    make_model,
+    make_node,
+    make_opsetid,
+    make_tensor_value_info,
+)
+from onnx.numpy_helper import from_array, to_array
+from onnx.backend.base import Device, DeviceType
+from onnx_array_api.reference import ExtendedReferenceEvaluator
+from onnx_array_api.light_api import translate
+from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+
+class ReferenceImplementationError(RuntimeError):
+    "Raised when the reference implementation fails; the export cannot be compared."
+    pass
+
+
+class ExportWrapper:
+    apis = ["onnx", "light"]
+
+    def __init__(self, model):
+        self.model = model
+        self.expected_sess = ExtendedReferenceEvaluator(self.model)
+
+    @property
+    def input_names(self):
+        return self.expected_sess.input_names
+
+    @property
+    def input_types(self):
+        return self.expected_sess.input_types
+
+    @property
+    def output_names(self):
+        return self.expected_sess.output_names
+
+    @property
+    def output_types(self):
+        return self.expected_sess.output_types
+
+    def run(
+        self, names: Optional[List[str]], feeds: Optional[Dict[str, Any]] = None
+    ) -> List[Any]:
+        try:
+            expected = self.expected_sess.run(names, feeds)
+        except (RuntimeError, AssertionError, TypeError, KeyError) as e:
+            raise ReferenceImplementationError(
+                f"ReferenceImplementation fails with {onnx_simple_text_plot(self.model)}"
+                f"\n--RAW--\n{self.model}"
+            ) from e
+
+        for api in self.apis:
+            try:
+                code = translate(self.model, api=api)
+            except NotImplementedError:
+                continue
+            except ValueError as e:
+                raise AssertionError(
+                    f"Unable to translate model for api {api!r}, "
+                    f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
+                    f"\n--EXPECTED--\n{expected}"
+                ) from e
+            try:
+                code_compiled = compile(code, "", mode="exec")
+            except Exception as e:
+                new_code = "\n".join(
+                    [f"{i+1:04} {line}" for i, line in enumerate(code.split("\n"))]
+                )
+                raise AssertionError(f"ERROR {e}\n{new_code}")
+
+            locs = {
+                "np": numpy,
+                "to_array": to_array,
+                "from_array": from_array,
+                "TensorProto": TensorProto,
+                "make_function": make_function,
+                "make_opsetid": make_opsetid,
+                "make_model": make_model,
+                "make_graph": make_graph,
+                "make_node": make_node,
+                "make_tensor_value_info": make_tensor_value_info,
+            }
+            globs = locs.copy()
+            try:
+                exec(code_compiled, globs, locs)
+            except (TypeError, NameError, ValueError) as e:
+                new_code = "\n".join(
+                    [f"{i+1:04} {line}" for i, line in enumerate(code.split("\n"))]
+                )
+                raise AssertionError(
+                    f"Unable to execute the code for api {api!r}\n{new_code}"
+                ) from e
+            export_model = locs["model"]
+            ref = ExtendedReferenceEvaluator(export_model)
+            try:
+                got = ref.run(names, feeds)
+            except (TypeError, AttributeError) as e:
+                diff = "\n".join(
+                    unified_diff(
+                        str(self.model).split("\n"),
+                        str(export_model).split("\n"),
+                        fromfile="before",
+                        tofile="after",
+                    )
+                )
+                raise AssertionError(
+                    f"Unable to run the exported model for api {api!r}, "
+                    f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
+                    f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
+                    f"\n--CODE--\n{code}"
+                    f"\n--FEEDS--\n{feeds}"
+                    f"\n--EXPECTED--\n{expected}"
+                    f"\n--DIFF--\n{diff}"
+                ) from e
+            if len(expected) != len(got):
+                raise AssertionError(
+                    f"Unexpected number of outputs for api {api!r}, "
+                    f"{len(expected)} != {len(got)}."
+                    f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
+                    f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
+                )
+            for a, b in zip(expected, got):
+                if not isinstance(a, numpy.ndarray):
+                    continue
+                if a.shape != b.shape or a.dtype != b.dtype:
+                    raise AssertionError(
+                        f"Shape or type discrepancies for api {api!r}."
+ f"\n--BASE--\n{onnx_simple_text_plot(self.model)}" + f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}" + ) + if a.dtype in (numpy.str_, object, numpy.object_) or isinstance( + a.dtype, getattr(getattr(numpy, "dtypes", None), "StrDType", type) + ): + if a.tolist() != b.tolist(): + raise AssertionError( + f"Text discrepancies for api {api!r} with a.dtype={a.dtype} " + f"and b.dtype={b.dtype}" + f"\n--BASE--\n{onnx_simple_text_plot(self.model)}" + f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}" + ) + continue + try: + assert_allclose(a, b, atol=1e-3) + except (AssertionError, TypeError) as e: + raise AssertionError( + f"Discrepancies for api {api!r} with a.dtype={a.dtype} " + f"and b.dtype={b.dtype} (type-dtype={type(a.dtype)})" + f"\n--BASE--\n{onnx_simple_text_plot(self.model)}" + f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}" + ) from e + + return expected + + +class ExportBackendRep(onnx.backend.base.BackendRep): + def __init__(self, session): + self._session = session + + def run(self, inputs, **kwargs): + if isinstance(inputs, numpy.ndarray): + inputs = [inputs] + if isinstance(inputs, list): + if len(inputs) == len(self._session.input_names): + feeds = dict(zip(self._session.input_names, inputs)) + else: + feeds = {} + pos_inputs = 0 + for inp, tshape in zip( + self._session.input_names, self._session.input_types + ): + shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim) + if shape == inputs[pos_inputs].shape: + feeds[inp] = inputs[pos_inputs] + pos_inputs += 1 + if pos_inputs >= len(inputs): + break + elif isinstance(inputs, dict): + feeds = inputs + else: + raise TypeError(f"Unexpected input type {type(inputs)!r}.") + outs = self._session.run(None, feeds) + return outs + + +class ExportBackend(onnx.backend.base.Backend): + @classmethod + def is_opset_supported(cls, model): # pylint: disable=unused-argument + return True, "" + + @classmethod + def supports_device(cls, device: str) -> bool: + d = Device(device) + return d.type == DeviceType.CPU # type: ignore[no-any-return] + + @classmethod + def create_inference_session(cls, model): + return ExportWrapper(model) + + @classmethod + def prepare( + cls, model: Any, device: str = "CPU", **kwargs: Any + ) -> ExportBackendRep: + if isinstance(model, ExportWrapper): + return ExportBackendRep(model) + if isinstance(model, (str, bytes, ModelProto)): + inf = cls.create_inference_session(model) + return cls.prepare(inf, device, **kwargs) + raise TypeError(f"Unexpected type {type(model)} for model.") + + @classmethod + def run_model(cls, model, inputs, device=None, **kwargs): + rep = cls.prepare(model, device, **kwargs) + return rep.run(inputs, **kwargs) + + @classmethod + def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): + raise NotImplementedError("Unable to run the model node by node.") + + +backend_test = onnx.backend.test.BackendTest(ExportBackend, __name__) + +# The following tests are too slow with the reference implementation (Conv). 
+backend_test.exclude(
+    "(FLOAT8|BFLOAT16|_opt_|_3d_|_momentum_|_4d_"
+    "|test_adagrad"
+    "|test_adam"
+    "|test_ai_onnx_ml_"
+    "|test_cast_FLOAT16"
+    "|test_cast_FLOAT_to_STRING"
+    "|test_castlike_FLOAT16"
+    "|test_castlike_FLOAT_to_STRING"
+    "|test_bernoulli"
+    "|test_bvlc_alexnet"
+    "|test_conv"  # too long
+    "|test_gradient_"
+    "|test_densenet121"
+    "|test_inception_v1"
+    "|test_inception_v2"
+    "|test_loop11_"
+    "|test_loop16_seq_none"
+    "|test_MaxPool2d"
+    "|test_quantizelinear_e"
+    "|test_resnet50"
+    "|test_sequence_model"
+    "|test_scan_sum"
+    "|test_scatter_with_axis"
+    "|test_scatter_without_axis"
+    "|test_shufflenet"
+    "|test_squeezenet"
+    "|test_vgg19"
+    "|test_zfnet512"
+    ")"
+)
+
+if pv.Version(onnx_version) < pv.Version("1.16.0"):
+    backend_test.exclude("(test_strnorm|test_range_)")
+
+# The following tests cannot pass because they generate random numbers.
+backend_test.exclude("(test_bernoulli)")
+
+# import all test cases at global scope to make them visible to python.unittest
+globals().update(backend_test.test_cases)
+
+if __name__ == "__main__":
+    res = unittest.main(verbosity=2, exit=False)
+    tests_run = res.result.testsRun
+    errors = len(res.result.errors)
+    skipped = len(res.result.skipped)
+    unexpected_successes = len(res.result.unexpectedSuccesses)
+    expected_failures = len(res.result.expectedFailures)
+    print("---------------------------------")
+    print(
+        f"tests_run={tests_run} errors={errors} skipped={skipped} "
+        f"unexpected_successes={unexpected_successes} "
+        f"expected_failures={expected_failures}"
+    )
diff --git a/_unittests/ut_light_api/test_light_api.py b/_unittests/ut_light_api/test_light_api.py
index f99a4b5..88c54f8 100644
--- a/_unittests/ut_light_api/test_light_api.py
+++ b/_unittests/ut_light_api/test_light_api.py
@@ -1,4 +1,5 @@
 import unittest
+import sys
 from typing import Callable, Optional
 import numpy as np
 from onnx import ModelProto
@@ -144,6 +145,7 @@ def list_ops_missing(self, n_inputs):
             f"{new_missing}\n{text}"
         )
 
+    @unittest.skipIf(sys.platform == "win32", reason="unstable test on Windows")
     def test_list_ops_missing(self):
         self.list_ops_missing(1)
         self.list_ops_missing(2)
diff --git a/_unittests/ut_light_api/test_translate.py b/_unittests/ut_light_api/test_translate.py
index c1f63f9..8af161c 100644
--- a/_unittests/ut_light_api/test_translate.py
+++ b/_unittests/ut_light_api/test_translate.py
@@ -6,11 +6,17 @@ from onnx.reference import ReferenceEvaluator
 
 from onnx_array_api.ext_test_case import ExtTestCase
 from onnx_array_api.light_api import start, translate
+from onnx_array_api.light_api.emitter import EventType
 
 OPSET_API = min(19, onnx_opset_version() - 1)
 
 
 class TestTranslate(ExtTestCase):
+    def test_event_type(self):
+        self.assertEqual(
+            EventType.to_str(EventType.INITIALIZER), "EventType.INITIALIZER"
+        )
+
     def test_exp(self):
         onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
         self.assertIsInstance(onx, ModelProto)
@@ -73,6 +79,8 @@ def test_transpose(self):
             """
             (
                 start(opset=19)
+                .cst(np.array([-1, 1], dtype=np.int64))
+                .rename('r')
                 .vin('X', elem_type=TensorProto.FLOAT)
                 .bring('X', 'r')
                 .Reshape()
diff --git a/_unittests/ut_light_api/test_translate_classic.py b/_unittests/ut_light_api/test_translate_classic.py
new file mode 100644
index 0000000..ed51ce3
--- /dev/null
+++ b/_unittests/ut_light_api/test_translate_classic.py
@@ -0,0 +1,258 @@
+import unittest
+import os
+from textwrap import dedent
+import numpy as np
+from onnx import ModelProto, TensorProto, load
+from onnx.defs import 
onnx_opset_version +from onnx.reference import ReferenceEvaluator +from onnx.helper import ( + make_tensor_value_info, + make_node, + make_graph, + make_model, + make_opsetid, +) +from onnx.checker import check_model +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.light_api import start, translate + +OPSET_API = min(19, onnx_opset_version() - 1) + + +class TestTranslateClassic(ExtTestCase): + def test_check_code(self): + opset_imports = [ + make_opsetid("", 19), + ] + inputs = [] + outputs = [] + nodes = [] + initializers = [] + sparse_initializers = [] + functions = [] + inputs.append(make_tensor_value_info("X", TensorProto.FLOAT, shape=[])) + nodes.append(make_node("Exp", ["X"], ["Y"])) + outputs.append(make_tensor_value_info("Y", TensorProto.FLOAT, shape=[])) + graph = make_graph( + nodes, + "noname", + inputs, + outputs, + initializers, + sparse_initializer=sparse_initializers, + ) + model = make_model(graph, functions=functions, opset_imports=opset_imports) + check_model(model) + + def test_exp(self): + onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx() + self.assertIsInstance(onx, ModelProto) + self.assertIn("Exp", str(onx)) + ref = ReferenceEvaluator(onx) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(np.exp(a), got) + + code = translate(onx, api="onnx") + + expected = dedent( + """ + opset_imports = [ + make_opsetid('', 19), + ] + inputs = [] + outputs = [] + nodes = [] + initializers = [] + sparse_initializers = [] + functions = [] + inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) + nodes.append( + make_node( + 'Exp', + ['X'], + ['Y'] + ) + ) + outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[])) + graph = make_graph( + nodes, + 'noname', + inputs, + outputs, + initializers, + sparse_initializer=sparse_initializers, + ) + model = make_model( + graph, + functions=functions, + opset_imports=opset_imports + )""" + ).strip("\n") + self.maxDiff = None + self.assertEqual(expected, code) + + onx2 = ( + start(opset=19) + .vin("X", elem_type=TensorProto.FLOAT) + .bring("X") + .Exp() + .rename("Y") + .bring("Y") + .vout(elem_type=TensorProto.FLOAT) + .to_onnx() + ) + ref = ReferenceEvaluator(onx2) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(np.exp(a), got) + + def test_transpose(self): + onx = ( + start(opset=19) + .vin("X") + .reshape((-1, 1)) + .Transpose(perm=[1, 0]) + .rename("Y") + .vout() + .to_onnx() + ) + self.assertIsInstance(onx, ModelProto) + self.assertIn("Transpose", str(onx)) + ref = ReferenceEvaluator(onx) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(a.reshape((-1, 1)).T, got) + + code = translate(onx, api="onnx") + expected = dedent( + """ + opset_imports = [ + make_opsetid('', 19), + ] + inputs = [] + outputs = [] + nodes = [] + initializers = [] + sparse_initializers = [] + functions = [] + initializers.append( + from_array( + np.array([-1, 1], dtype=np.int64), + name='r' + ) + ) + inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) + nodes.append( + make_node( + 'Reshape', + ['X', 'r'], + ['r0_0'] + ) + ) + nodes.append( + make_node( + 'Transpose', + ['r0_0'], + ['Y'], + perm=[1, 0] + ) + ) + outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[])) + graph = make_graph( + nodes, + 'noname', + inputs, + outputs, + initializers, + sparse_initializer=sparse_initializers, + ) + model = make_model( + 
graph, + functions=functions, + opset_imports=opset_imports + )""" + ).strip("\n") + self.maxDiff = None + self.assertEqual(expected, code) + + def test_topk_reverse(self): + onx = ( + start(opset=19) + .vin("X", np.float32) + .vin("K", np.int64) + .bring("X", "K") + .TopK(largest=0) + .rename("Values", "Indices") + .vout() + .to_onnx() + ) + self.assertIsInstance(onx, ModelProto) + ref = ReferenceEvaluator(onx) + x = np.array([[0, 1, 2, 3], [9, 8, 7, 6]], dtype=np.float32) + k = np.array([2], dtype=np.int64) + got = ref.run(None, {"X": x, "K": k}) + self.assertEqualArray(np.array([[0, 1], [6, 7]], dtype=np.float32), got[0]) + self.assertEqualArray(np.array([[0, 1], [3, 2]], dtype=np.int64), got[1]) + + code = translate(onx, api="onnx") + expected = dedent( + """ + opset_imports = [ + make_opsetid('', 19), + ] + inputs = [] + outputs = [] + nodes = [] + initializers = [] + sparse_initializers = [] + functions = [] + inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) + inputs.append(make_tensor_value_info('K', TensorProto.INT64, shape=[])) + nodes.append( + make_node( + 'TopK', + ['X', 'K'], + ['Values', 'Indices'], + axis=-1, + largest=0, + sorted=1 + ) + ) + outputs.append(make_tensor_value_info('Values', TensorProto.FLOAT, shape=[])) + outputs.append(make_tensor_value_info('Indices', TensorProto.FLOAT, shape=[])) + graph = make_graph( + nodes, + 'noname', + inputs, + outputs, + initializers, + sparse_initializer=sparse_initializers, + ) + model = make_model( + graph, + functions=functions, + opset_imports=opset_imports + )""" + ).strip("\n") + self.maxDiff = None + self.assertEqual(expected, code) + + def test_fft(self): + data = os.path.join( + os.path.dirname(__file__), "_data", "stft_inlined_batch_1.onnx" + ) + onx = load(data) + code = translate(onx, api="onnx") + try: + compile(code, "", mode="exec") + except Exception as e: + new_code = "\n".join( + [f"{i+1:04} {line}" for i, line in enumerate(code.split("\n"))] + ) + raise AssertionError(f"ERROR {e}\n{new_code}") + + +if __name__ == "__main__": + # TestLightApi().test_topk() + unittest.main(verbosity=2) diff --git a/_unittests/ut_reference/test_backend_extended_reference_evaluator.py b/_unittests/ut_reference/test_backend_extended_reference_evaluator.py index 4bc0927..b35fb3c 100644 --- a/_unittests/ut_reference/test_backend_extended_reference_evaluator.py +++ b/_unittests/ut_reference/test_backend_extended_reference_evaluator.py @@ -61,8 +61,6 @@ def create_inference_session(cls, model): def prepare( cls, model: Any, device: str = "CPU", **kwargs: Any ) -> ExtendedReferenceEvaluatorBackendRep: - # if isinstance(model, ExtendedReferenceEvaluatorBackendRep): - # return model if isinstance(model, ExtendedReferenceEvaluator): return ExtendedReferenceEvaluatorBackendRep(model) if isinstance(model, (str, bytes, ModelProto)): diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 89a4ed9..907bb9f 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -214,7 +214,7 @@ jobs: - script: pip install onnxmltools --no-deps displayName: 'Install onnxmltools' - script: | - python -m pytest + python -m pytest -v displayName: 'Runs Unit Tests' - script: | python -u setup.py bdist_wheel diff --git a/onnx_array_api/light_api/__init__.py b/onnx_array_api/light_api/__init__.py index 5e549f9..8969648 100644 --- a/onnx_array_api/light_api/__init__.py +++ b/onnx_array_api/light_api/__init__.py @@ -3,6 +3,7 @@ from .model import OnnxGraph from .translate import Translater from .var import Var, Vars +from .inner_emitter 
import InnerEmitter
 
 
 def start(
@@ -50,13 +51,18 @@ def start(
     return OnnxGraph(opset=opset, opsets=opsets, is_function=is_function)
 
 
-def translate(proto: ModelProto, single_line=False) -> str:
+def translate(proto: ModelProto, single_line: bool = False, api: str = "light") -> str:
     """
     Translates an ONNX proto into a code using :ref:`l-light-api`
     to describe the ONNX graph.
 
     :param proto: model to translate
     :param single_line: as a single line or not
+    :param api: API to export into,
+        default is `"light"`, which is handled by the class
+        :class:`onnx_array_api.light_api.emitter.Emitter`;
+        the other value is `"onnx"`, which relies on the inner API
+        implemented in the onnx package.
     :return: code
 
     .. runpython::
@@ -75,9 +81,30 @@ def translate(proto: ModelProto, single_line=False) -> str:
         )
         code = translate(onx)
         print(code)
+
+    The inner API from the onnx package is also available.
+
+    .. runpython::
+        :showcode:
+
+        from onnx_array_api.light_api import start, translate
+
+        onx = (
+            start()
+            .vin("X")
+            .reshape((-1, 1))
+            .Transpose(perm=[1, 0])
+            .rename("Y")
+            .vout()
+            .to_onnx()
+        )
+        code = translate(onx, api="onnx")
+        print(code)
     """
-    tr = Translater(proto)
-    rows = tr.export()
-    if single_line:
-        return ".".join(rows)
-    return "".join(["(\n    ", "\n    .".join(rows), "\n)"])
+    if api == "light":
+        tr = Translater(proto)
+        return tr.export(single_line=single_line, as_str=True)
+    if api == "onnx":
+        tr = Translater(proto, emitter=InnerEmitter())
+        return tr.export(as_str=True)
+    raise ValueError(f"Unexpected value {api!r} for api.")
diff --git a/onnx_array_api/light_api/emitter.py b/onnx_array_api/light_api/emitter.py
new file mode 100644
index 0000000..52d1033
--- /dev/null
+++ b/onnx_array_api/light_api/emitter.py
@@ -0,0 +1,251 @@
+import inspect
+from typing import Any, Dict, List, Tuple
+from enum import IntEnum
+import numpy as np
+from onnx import AttributeProto
+from .annotations import ELEMENT_TYPE_NAME
+
+
+class EventType(IntEnum):
+    START = 0
+    INPUT = 1
+    OUTPUT = 2
+    NODE = 3
+    TO_ONNX = 4
+    BEGIN_GRAPH = 5
+    END_GRAPH = 6
+    BEGIN_FUNCTION = 7
+    END_FUNCTION = 8
+    INITIALIZER = 9
+    SPARSE_INITIALIZER = 10
+
+    @classmethod
+    def to_str(cls, self) -> str:
+        for k, v in EventType.__dict__.items():
+            if self == v:
+                return f"{cls.__name__}.{k}"
+
+
+class BaseEmitter:
+    def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
+        """
+        Converts an event into an instruction.
+
+        :param event: event kind
+        :param kwargs: event parameters
+        :return: list of instructions
+        """
+
+        if event == EventType.NODE:
+            return self._emit_node(**kwargs)
+
+        if event == EventType.INITIALIZER:
+            return self._emit_initializer(**kwargs)
+
+        if event == EventType.SPARSE_INITIALIZER:
+            return self._emit_sparse_initializer(**kwargs)
+
+        if event == EventType.INPUT:
+            return self._emit_input(**kwargs)
+
+        if event == EventType.OUTPUT:
+            return self._emit_output(**kwargs)
+
+        if event == EventType.START:
+            return self._emit_start(**kwargs)
+
+        if event == EventType.TO_ONNX:
+            return self._emit_to_onnx(**kwargs)
+
+        if event == EventType.BEGIN_GRAPH:
+            return self._emit_begin_graph(**kwargs)
+
+        if event == EventType.END_GRAPH:
+            return self._emit_end_graph(**kwargs)
+
+        raise ValueError(f"Unexpected event {EventType.to_str(event)}.")
+
+    def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
+        """
+        Renders an attribute value into a string.
+
+        :param value: value to convert
+        :return: rows to append before, actual value
+        """
+        v = value[-1]
+        if value[0].type == AttributeProto.TENSOR:
+            repl = {"bool": "bool_", "object": "object_", "str": "str_"}
+            sdtype = repl.get(str(v.dtype), str(v.dtype))
+            return [], (
+                f"from_array(np.array({v.tolist()}, dtype=np.{sdtype}), "
+                f"name={value[0].name!r})"
+            )
+        if isinstance(v, (int, float, list)):
+            return [], str(v)
+        if isinstance(v, str):
+            return [], f"{v!r}"
+        if isinstance(v, np.ndarray):
+            if len(v.shape) == 0:
+                return [], str(v)
+            if len(v.shape) == 1:
+                if value[0].type in (
+                    AttributeProto.INTS,
+                    AttributeProto.FLOATS,
+                    AttributeProto.STRINGS,
+                ):
+                    return [], str(v.tolist())
+
+        raise ValueError(
+            f"Unable to render an attribute {type(v)}, "
+            f"attribute type={value[0].type}, "
+            f"dtype={getattr(v, 'dtype', '-')}, "
+            f"shape={getattr(v, 'shape', '-')}, {value}."
+        )
+
+    def join(self, rows: List[str], single_line: bool = False) -> str:
+        raise NotImplementedError(
+            f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+        )
+
+    def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
+        raise NotImplementedError(
+            f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+        )
+
+    def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]:
+        raise NotImplementedError(
+            f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+        )
+
+    def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+        raise NotImplementedError(
+            f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+        )
+
+    def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+        raise NotImplementedError(
+            f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+        )
+
+    def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
+        raise NotImplementedError(
+            f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+        )
+
+    def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+        raise NotImplementedError(
+            f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+        )
+
+    def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+        raise NotImplementedError(
+            f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+        )
+
+    def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
+        raise NotImplementedError(
+            f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+        )
+
+    def _emit_sparse_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
+        raise NotImplementedError(
+            f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+        )
+
+
+class Emitter(BaseEmitter):
+    """
+    Converts events into light API code.
+    """
+
+    def join(self, rows: List[str], single_line: bool = False) -> str:
+        "Joins the rows."
+        if single_line:
+            return ".".join(rows)
+        return "".join(["(\n    ", "\n    .".join(rows), "\n)"])
+
+    def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
+        opsets = kwargs.get("opsets", {})
+        opset = opsets.get("", None)
+        if opset is not None:
+            del opsets[""]
+        args = []
+        if opset:
+            args.append(f"opset={opset}")
+        if opsets:
+            args.append(f"opsets={opsets}")
+        return [f"start({', '.join(args)})"]
+
+    def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]:
+        return ["to_onnx()"]
+
+    def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+        return []
+
+    def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+        return []
+
+    def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
+        name = kwargs["name"]
+        value = kwargs["value"]
+        repl = {"bool": "bool_", "object": "object_", "str": "str_"}
+        sdtype = repl.get(str(value.dtype), str(value.dtype))
+        return [
+            f"cst(np.array({value.tolist()}, dtype=np.{sdtype}))",
+            f"rename({name!r})",
+        ]
+
+    def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+        name = kwargs["name"]
+        elem_type = kwargs.get("elem_type", None)
+        shape = kwargs.get("shape", None)
+        if elem_type and shape:
+            return [
+                f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})"
+            ]
+        if elem_type:
+            return [
+                f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})"
+            ]
+        return [f"vin({name!r})"]
+
+    def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+        inst = []
+        if "name" in kwargs:
+            name = kwargs["name"]
+            inst.append(f"bring({name!r})")
+        elem_type = kwargs.get("elem_type", None)
+        shape = kwargs.get("shape", None)
+        if elem_type and shape:
+            inst.append(
+                f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})"
+            )
+        elif elem_type:
+            inst.append(f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})")
+        else:
+            inst.append("vout()")
+        return inst
+
+    def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
+        op_type = kwargs["op_type"]
+        inputs = kwargs["inputs"]
+        outputs = kwargs["outputs"]
+        if kwargs.get("domain", "") != "":
+            domain = kwargs["domain"]
+            raise NotImplementedError(f"domain={domain!r} not supported yet.")
+        atts = kwargs.get("atts", {})
+        args = []
+        for k, v in atts.items():
+            before, vatt = self.render_attribute_value(v)
+            if before:
+                raise NotImplementedError("Graph attribute not supported yet.")
+            args.append(f"{k}={vatt}")
+
+        str_inputs = ", ".join([f"{i!r}" for i in inputs])
+        inst = [f"bring({str_inputs})", f"{op_type}({', '.join(args)})"]
+        if len(outputs) == 1:
+            inst.append(f"rename({outputs[0]!r})")
+        else:
+            str_outputs = ", ".join([f"{o!r}" for o in outputs])
+            inst.append(f"rename({str_outputs})")
+        return inst
diff --git a/onnx_array_api/light_api/inner_emitter.py b/onnx_array_api/light_api/inner_emitter.py
new file mode 100644
index 0000000..6b70246
--- /dev/null
+++ b/onnx_array_api/light_api/inner_emitter.py
@@ -0,0 +1,142 @@
+from typing import Any, Dict, List, Tuple
+from onnx import AttributeProto
+from .annotations import ELEMENT_TYPE_NAME
+from .emitter import BaseEmitter
+from .translate import Translater
+
+
+class InnerEmitter(BaseEmitter):
+    """
+    Converts events into code relying on the inner API of onnx (onnx.helper).
+    """
+
+    def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
+        """
+        Renders an attribute value into a string.
+
+        :param value: value to convert
+        :return: rows to append before, actual value
+        """
+        if value[0].type == AttributeProto.GRAPH:
+            tr = Translater(value[0].g, emitter=self)
+            rows = tr.export(as_str=False, single_line=False)
+            new_rows = [f"def _make_local_graph_{value[0].name}():"]
+            for line in rows:
+                if "make_model" in line:
+                    break
+                new_rows.append("    " + line)
+            new_rows.append("    return graph")
+            new_rows.append(f"{value[0].name} = _make_local_graph_{value[0].name}()")
+            return new_rows, value[0].name
+
+        return super().render_attribute_value(value)
+
+    def join(self, rows: List[str], single_line: bool = False) -> str:
+        "Joins the rows. `single_line` is unused."
+        return "\n".join(rows)
+
+    def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
+        lines = ["opset_imports = ["]
+        opsets = kwargs.get("opsets", {})
+        for k, v in opsets.items():
+            lines.append(f"    make_opsetid({k!r}, {v!r}),")
+        lines.append("]")
+        return lines
+
+    def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]:
+        lines = [
+            "model = make_model(",
+            "    graph,",
+            "    functions=functions,",
+            "    opset_imports=opset_imports",
+            ")",
+        ]
+        return lines
+
+    def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+        lines = [
+            "inputs = []",
+            "outputs = []",
+            "nodes = []",
+            "initializers = []",
+            "sparse_initializers = []",
+            "functions = []",
+        ]
+        return lines
+
+    def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+        lines = [
+            "graph = make_graph(",
+            "    nodes,",
+            "    'noname',",
+            "    inputs,",
+            "    outputs,",
+            "    initializers,",
+            "    sparse_initializer=sparse_initializers,",
+            ")",
+        ]
+        return lines
+
+    def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
+        name = kwargs["name"]
+        value = kwargs["value"]
+        repl = {"bool": "bool_", "object": "object_", "str": "str_"}
+        sdtype = repl.get(str(value.dtype), str(value.dtype))
+        return [
+            "initializers.append(",
+            "    from_array(",
+            f"        np.array({value.tolist()}, dtype=np.{sdtype}),",
+            f"        name={name!r}",
+            "    )",
+            ")",
+        ]
+
+    def _emit_io(self, container: str, **kwargs: Dict[str, Any]) -> List[str]:
+        name = kwargs["name"]
+        elem_type = kwargs.get("elem_type", None)
+        shape = kwargs.get("shape", None)
+        if elem_type and shape:
+            return [
+                f"{container}.append(make_tensor_value_info({name!r}, TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r}))"
+            ]
+        if elem_type:
+            return [
+                f"{container}.append(make_tensor_value_info({name!r}, TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape=[]))"
+            ]
+        return [
+            f"{container}.append(make_tensor_value_info({name!r}, TensorProto.UNDEFINED, []))"
+        ]
+
+    def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+        return self._emit_io("inputs", **kwargs)
+
+    def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+        return self._emit_io("outputs", **kwargs)
+
+    def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
+        op_type = kwargs["op_type"]
+        inputs = kwargs["inputs"]
+        outputs = kwargs["outputs"]
+        if kwargs.get("domain", "") != "":
+            domain = kwargs["domain"]
+            raise NotImplementedError(f"domain={domain!r} not supported yet.")
+
+        before_lines = []
+        lines = [
+            "nodes.append(",
+            "    make_node(",
+            f"        {op_type!r},",
+            f"        {inputs},",
+            f"        {outputs},",
+        ]
+        domain = kwargs.get("domain", "")
+        if domain:
+            lines.append(f"        domain={domain!r},")
+        atts = kwargs.get("atts", {})
+        for k, v in atts.items():
+            before, value = self.render_attribute_value(v)
+            before_lines.extend(before)
+            lines.append(f"        {k}={value},")
+        lines[-1] = 
lines[-1][:-1] + lines.extend([" )", ")"]) + return before_lines + lines diff --git a/onnx_array_api/light_api/translate.py b/onnx_array_api/light_api/translate.py index db574df..b42dfc5 100644 --- a/onnx_array_api/light_api/translate.py +++ b/onnx_array_api/light_api/translate.py @@ -1,116 +1,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union -from enum import IntEnum import numpy as np from onnx import AttributeProto, FunctionProto, GraphProto, ModelProto, NodeProto from onnx.numpy_helper import to_array -from .annotations import ELEMENT_TYPE_NAME - - -class EventType(IntEnum): - START = 0 - INPUT = 1 - OUTPUT = 2 - NODE = 3 - TO_ONNX = 4 - - -class Emitter: - """ - Converts event into proper code. - """ - - def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]: - """ - Converts an event into an instruction. - - :param event: event kind - :param kwargs: event parameters - :return: list of instructions - """ - if event == EventType.START: - opsets = kwargs.get("opsets", {}) - opset = opsets.get("", None) - if opset is not None: - del opsets[""] - args = [] - if opset: - args.append(f"opset={opset}") - if opsets: - args.append(f"opsets={opsets}") - return [f"start({', '.join(args)})"] - - if event == EventType.TO_ONNX: - return ["to_onnx()"] - - if event == EventType.INPUT: - name = kwargs["name"] - elem_type = kwargs.get("elem_type", None) - shape = kwargs.get("shape", None) - if elem_type and shape: - return [ - f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})" - ] - if elem_type: - return [ - f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})" - ] - return [f"vin({name!r})"] - - if event == EventType.OUTPUT: - inst = [] - if "name" in kwargs: - name = kwargs["name"] - inst.append(f"bring({name!r})") - elem_type = kwargs.get("elem_type", None) - shape = kwargs.get("shape", None) - if elem_type and shape: - inst.append( - f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})" - ) - elif elem_type: - inst.append( - f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})" - ) - else: - inst.append("vout()") - return inst - - if event == EventType.NODE: - op_type = kwargs["op_type"] - inputs = kwargs["inputs"] - outputs = kwargs["outputs"] - if kwargs.get("domain", "") != "": - domain = kwargs["domain"] - raise NotImplementedError(f"domain={domain!r} not supported yet.") - atts = kwargs.get("atts", {}) - args = [] - for k, v in atts.items(): - args.append(f"{k}={self.render_attribute_value(v)}") - - str_inputs = ", ".join([f"{i!r}" for i in inputs]) - inst = [f"bring({str_inputs})", f"{op_type}({', '.join(args)})"] - if len(outputs) == 1: - inst.append(f"rename({outputs[0]!r})") - else: - str_outputs = ", ".join([f"{o!r}" for o in outputs]) - inst.append(f"rename({str_outputs})") - return inst - - raise ValueError(f"Unexpected EventType {event}.") - - def render_attribute_value(self, value: Any) -> str: - """ - Renders an attribute value into a string. 
-        """
-        v = value[-1]
-        if isinstance(v, (int, float, list)):
-            return str(v)
-        if isinstance(v, np.ndarray):
-            if len(v.shape) == 0:
-                return str(v)
-            if len(v.shape) == 1:
-                return str(v.tolist())
-        raise ValueError(f"Unable to render an attribute {value}.")
+from .emitter import EventType, Emitter
 
 
 class Translater:
@@ -124,37 +16,65 @@ def __init__(
         emitter: Optional[Emitter] = None,
     ):
         self.proto_ = proto
-        self.emit = emitter or Emitter()
+        self.emitter = emitter or Emitter()
 
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(<{type(self.proto_)})"
 
-    def export(self) -> List[str]:
+    def export(self, as_str: bool, single_line: bool = False) -> Union[str, List[str]]:
         """
         Exports into a code.
 
+        :param as_str: returns the code as a single string or as a list of rows
+        :param single_line: tries to compress the output into a single line
         :return: list of instructions
         """
         rows = []
         if isinstance(self.proto_, ModelProto):
             opsets = {d.domain: d.version for d in self.proto_.opset_import}
-            rows.extend(self.emit(EventType.START, opsets=opsets))
+            rows.extend(self.emitter(EventType.START, opsets=opsets))
             inputs = self.proto_.graph.input
             outputs = self.proto_.graph.output
             nodes = self.proto_.graph.node
+            initializers = self.proto_.graph.initializer
+            sparse_initializers = self.proto_.graph.sparse_initializer
         elif isinstance(self.proto_, (FunctionProto, GraphProto)):
             inputs = self.proto_.input
             outputs = self.proto_.output
             nodes = self.proto_.node
+            if isinstance(self.proto_, GraphProto):
+                initializers = self.proto_.initializer
+                sparse_initializers = self.proto_.sparse_initializer
+            else:
+                initializers = []
+                sparse_initializers = []
         else:
             raise ValueError(f"Unexpected type {type(self.proto_)} for proto.")
 
+        if len(sparse_initializers) != 0:
+            raise NotImplementedError("Sparse initializer not supported yet.")
+
+        rows.extend(
+            self.emitter(
+                EventType.BEGIN_FUNCTION
+                if isinstance(self.proto_, FunctionProto)
+                else EventType.BEGIN_GRAPH
+            )
+        )
+
+        for i in initializers:
+            rows.extend(
+                self.emitter(
+                    EventType.INITIALIZER, name=i.name, init=i, value=to_array(i)
                )
            )
+
         for i in inputs:
             if isinstance(i, str):
-                rows.extend(self.emit(EventType.INPUT, name=i))
+                rows.extend(self.emitter(EventType.INPUT, name=i))
             else:
                 rows.extend(
-                    self.emit(
+                    self.emitter(
                         EventType.INPUT,
                         name=i.name,
                         elem_type=i.type.tensor_type.elem_type,
@@ -168,7 +88,7 @@ def export(self) -> List[str]:
         for node in nodes:
             atts = self.extract_attributes(node)
             rows.extend(
-                self.emit(
+                self.emitter(
                     EventType.NODE,
                     op_type=node.op_type,
                     inputs=node.input,
@@ -179,11 +99,11 @@ def export(self) -> List[str]:
             )
 
         for o in outputs:
-            if isinstance(i, str):
-                rows.extend(self.emit(EventType.INPUT, name=o))
+            if isinstance(o, str):
+                rows.extend(self.emitter(EventType.OUTPUT, name=o))
             else:
                 rows.extend(
-                    self.emit(
+                    self.emitter(
                         EventType.OUTPUT,
                         name=o.name,
                         elem_type=o.type.tensor_type.elem_type,
@@ -193,11 +113,20 @@ def export(self) -> List[str]:
                         ),
                     )
                 )
+        rows.extend(
+            self.emitter(
+                EventType.END_FUNCTION
+                if isinstance(self.proto_, FunctionProto)
+                else EventType.END_GRAPH
+            )
+        )
 
         if isinstance(self.proto_, ModelProto) and len(self.proto_.functions) > 0:
             raise NotImplementedError("Local functions are not yet implemented.")
 
-        rows.extend(self.emit(EventType.TO_ONNX))
+        rows.extend(self.emitter(EventType.TO_ONNX))
+        if as_str:
+            return self.emitter.join(rows, single_line=single_line)
         return rows
 
     def extract_attributes(
diff --git a/onnx_array_api/plotting/_helper.py b/onnx_array_api/plotting/_helper.py
index ddca631..21179ab 100644
--- 
a/onnx_array_api/plotting/_helper.py +++ b/onnx_array_api/plotting/_helper.py @@ -160,6 +160,8 @@ def _get_type(obj0): if hasattr(obj, "tensor_type"): obj = obj.tensor_type if hasattr(obj, "elem_type"): + if obj.elem_type == 0: + return "NOTENSOR" return tensor_dtype_to_np_dtype(obj.elem_type) raise RuntimeError(f"Unable to guess type from {obj0!r}.") # pragma: no cover
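
A quick way to exercise the new `api="onnx"` mode end to end, adapted from the docstring example added to `onnx_array_api/light_api/__init__.py` and the exec pattern used in `test_backend_export.py`. This is a minimal sketch: the `locs` namespace simply mirrors the helper names the generated code expects, and reading the result back from a top-level variable named `model` is the convention the backend test relies on.

import numpy as np
from onnx import TensorProto
from onnx.helper import (
    make_function,
    make_graph,
    make_model,
    make_node,
    make_opsetid,
    make_tensor_value_info,
)
from onnx.numpy_helper import from_array, to_array
from onnx_array_api.light_api import start, translate

# Build a small model with the light API, then export it as inner API code.
onx = (
    start(opset=19)
    .vin("X")
    .reshape((-1, 1))
    .Transpose(perm=[1, 0])
    .rename("Y")
    .vout()
    .to_onnx()
)
code = translate(onx, api="onnx")
print(code)

# The generated code rebuilds the model with onnx.helper calls and leaves
# a ModelProto in a top-level variable named `model`.
locs = {
    "np": np,
    "to_array": to_array,
    "from_array": from_array,
    "TensorProto": TensorProto,
    "make_function": make_function,
    "make_opsetid": make_opsetid,
    "make_model": make_model,
    "make_graph": make_graph,
    "make_node": make_node,
    "make_tensor_value_info": make_tensor_value_info,
}
exec(compile(code, "", mode="exec"), locs.copy(), locs)
model = locs["model"]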