From 1b49feed8261370c6a14e3a919e1036337b619f6 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 31 Mar 2024 12:16:42 +0200 Subject: [PATCH 1/7] export to builder --- _unittests/ut_translate_api/test_translate.py | 1 - .../test_translate_builder.py | 70 ++++++++++ onnx_array_api/translate_api/__init__.py | 30 +++- onnx_array_api/translate_api/base_emitter.py | 28 ++++ .../translate_api/builder_emitter.py | 132 ++++++++++++++++++ onnx_array_api/translate_api/translate.py | 29 ++-- 6 files changed, 278 insertions(+), 12 deletions(-) create mode 100644 _unittests/ut_translate_api/test_translate_builder.py create mode 100644 onnx_array_api/translate_api/builder_emitter.py diff --git a/_unittests/ut_translate_api/test_translate.py b/_unittests/ut_translate_api/test_translate.py index d505135..0212d0b 100644 --- a/_unittests/ut_translate_api/test_translate.py +++ b/_unittests/ut_translate_api/test_translate.py @@ -221,5 +221,4 @@ def test_aionnxml(self): if __name__ == "__main__": - TestTranslate().test_export_if() unittest.main(verbosity=2) diff --git a/_unittests/ut_translate_api/test_translate_builder.py b/_unittests/ut_translate_api/test_translate_builder.py new file mode 100644 index 0000000..124013d --- /dev/null +++ b/_unittests/ut_translate_api/test_translate_builder.py @@ -0,0 +1,70 @@ +import unittest +from textwrap import dedent +import numpy as np +from onnx import ModelProto, TensorProto +from onnx.defs import onnx_opset_version +from onnx.reference import ReferenceEvaluator +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.light_api import start +from onnx_array_api.graph_api import GraphBuilder +from onnx_array_api.translate_api import translate + + +OPSET_API = min(19, onnx_opset_version() - 1) + + +class TestTranslateBuilder(ExtTestCase): + def setUp(self): + self.maxDiff = None + + def test_exp(self): + onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx() + self.assertIsInstance(onx, ModelProto) + self.assertIn("Exp", str(onx)) + ref = ReferenceEvaluator(onx) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(np.exp(a), got) + + code = translate(onx, api="builder") + expected = dedent( + """ + def light_api( + op: "GraphBuilder", + X: "FLOAT[]", + ): + Y = op.Exp(X) + op.Identity(Y, outputs=["Y"]) + return Y + + g = GraphBuilder({'': 19}) + g.make_tensor_input("X", TensorProto.FLOAT, ()) + light_api(g.op, X) + g.make_tensor_output("Y", TensorProto.FLOAT, ()) + model = g.to_onnx() + """ + ).strip("\n") + self.assertEqual(expected, code.strip("\n")) + + def light_api( + op: "GraphBuilder", + X: "FLOAT[]", # noqa: F722 + ): + Y = op.Exp(X) + op.Identity(Y, outputs=["Y"]) + return Y + + g2 = GraphBuilder({"": 19}) + g2.make_tensor_input("X", TensorProto.FLOAT, ("A",)) + light_api(g2.op, "X") + g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",)) + onx2 = g2.to_onnx() + + ref = ReferenceEvaluator(onx2) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(np.exp(a), got) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_array_api/translate_api/__init__.py b/onnx_array_api/translate_api/__init__.py index 25daef6..12b4a77 100644 --- a/onnx_array_api/translate_api/__init__.py +++ b/onnx_array_api/translate_api/__init__.py @@ -1,6 +1,7 @@ from onnx import ModelProto from .translate import Translater from .inner_emitter import InnerEmitter +from .builder_emitter import BuilderEmitter def translate(proto: ModelProto, single_line: bool = False, api: str = "light") -> str: @@ -14,7 +15,8 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light") default is `"light"` and this is handle by class :class:`onnx_array_api.translate_api.light_emitter.LightEmitter`, another value is `"onnx"` which is the inner API implemented - in onnx package. + in onnx package, `"builder"` follows the syntax for the + class :class:`onnx_array_api.graph_api.GraphBuilder` :return: code .. runpython:: @@ -35,7 +37,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light") code = translate(onx) print(code) - The inner API from onnx packahe is also available. + The inner API from onnx package is also available. .. runpython:: :showcode: @@ -54,6 +56,27 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light") ) code = translate(onx, api="onnx") print(code) + + The :class:`GraphBuilder + ` API returns this: + + .. runpython:: + :showcode: + + from onnx_array_api.light_api import start + from onnx_array_api.translate_api import translate + + onx = ( + start() + .vin("X") + .reshape((-1, 1)) + .Transpose(perm=[1, 0]) + .rename("Y") + .vout() + .to_onnx() + ) + code = translate(onx, api="builder") + print(code) """ if api == "light": tr = Translater(proto) @@ -61,4 +84,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light") if api == "onnx": tr = Translater(proto, emitter=InnerEmitter()) return tr.export(as_str=True) + if api == "builder": + tr = Translater(proto, emitter=BuilderEmitter()) + return tr.export(as_str=True) raise ValueError(f"Unexpected value {api!r} for api.") diff --git a/onnx_array_api/translate_api/base_emitter.py b/onnx_array_api/translate_api/base_emitter.py index 3a0dfb6..62fb318 100644 --- a/onnx_array_api/translate_api/base_emitter.py +++ b/onnx_array_api/translate_api/base_emitter.py @@ -21,6 +21,10 @@ class EventType(IntEnum): FUNCTION_OUTPUT = 12 FUNCTION_ATTRIBUTES = 13 TO_ONNX_FUNCTION = 14 + BEGIN_SIGNATURE = 15 + END_SIGNATURE = 16 + BEGIN_RETURN = 17 + END_RETURN = 18 @classmethod def to_str(cls, self) -> str: @@ -84,6 +88,18 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]: if event == EventType.FUNCTION_ATTRIBUTES: return self._emit_function_attributes(**kwargs) + if event == EventType.BEGIN_SIGNATURE: + return self._emit_begin_signature(**kwargs) + + if event == EventType.END_SIGNATURE: + return self._emit_end_signature(**kwargs) + + if event == EventType.BEGIN_RETURN: + return self._emit_begin_return(**kwargs) + + if event == EventType.END_RETURN: + return self._emit_end_return(**kwargs) + raise ValueError(f"Unexpected event {EventType.to_str(event)}.") def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: @@ -222,3 +238,15 @@ def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]: raise NotImplementedError( f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." ) + + def _emit_begin_signature(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] diff --git a/onnx_array_api/translate_api/builder_emitter.py b/onnx_array_api/translate_api/builder_emitter.py new file mode 100644 index 0000000..963286b --- /dev/null +++ b/onnx_array_api/translate_api/builder_emitter.py @@ -0,0 +1,132 @@ +from typing import Any, Dict, List +from onnx import TensorProto +from .base_emitter import BaseEmitter + +_types = { + TensorProto.FLOAT: "FLOAT", + TensorProto.FLOAT16: "FLOAT16", + TensorProto.INT64: "INT64", + TensorProto.INT32: "INT32", +} + + +def _itype_to_string(itype: int) -> str: + return _types[itype] + + +class BuilderEmitter(BaseEmitter): + """ + Converts event into proper code. + """ + + def join(self, rows: List[str], single_line: bool = False) -> str: + "Join the rows" + assert ( + not single_line + ), f"The emitter {type(self)} does not work with single_line=True." + return "\n".join(rows) + + def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: + self.opsets = kwargs.get("opsets", {}) + return [] + + def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: + inps = ", ".join(["g.op", *self.inputs]) + inputs = [] + for inp, stype, shape in self.inputs_full_: + inputs.append(f'g.make_tensor_input("{inp}", TensorProto.{stype}, {shape})') + outputs = [] + for inp, stype, shape in self.outputs_full_: + outputs.append( + f'g.make_tensor_output("{inp}", TensorProto.{stype}, {shape})' + ) + rows = [ + "", + f"g = GraphBuilder({self.opsets})", + *inputs, + f"{self.name}({inps})", + *outputs, + "model = g.to_onnx()", + ] + return rows + + def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: + self.inputs = [] + self.inputs_full = [] + self.outputs = [] + self.inits = [] + self.inputs_full_ = [] + self.outputs_full_ = [] + self.name = kwargs.get("name", "make_graph") + return [] + + def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: + assert False, f"not implemented yet with {kwargs}" + + def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: + name = kwargs["name"] + itype = kwargs.get("elem_type", 0) + shape = kwargs.get("shape", None) + if itype == 0: + inp = "X" + else: + if shape is None: + inp = f'X: "{_itype_to_string(itype)}"' + else: + inp = f'X: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"' + self.inputs_full.append(inp) + self.inputs.append(name) + self.inputs_full_.append((name, _itype_to_string(itype), shape)) + return [] + + def _emit_begin_signature(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]: + rows = ["", f"def {self.name}(", ' op: "GraphBuilder",'] + for i in self.inputs_full: + rows.append(f" {i},") + rows.append("):") + return rows + + def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]: + outs = ", ".join(self.outputs) + return [f" return {outs}"] + + def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]: + name = kwargs["name"] + itype = kwargs.get("elem_type", 0) + shape = kwargs.get("shape", None) + self.outputs.append(name) + self.outputs_full_.append((name, _itype_to_string(itype), shape)) + return [f' op.Identity({name}, outputs=["{name}"])'] + + def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: + op_type = kwargs["op_type"] + inputs = kwargs["inputs"] + outputs = kwargs["outputs"] + if kwargs.get("domain", "") != "": + domain = kwargs["domain"] + op_type = f"{domain}.{op_type}" + atts = kwargs.get("atts", {}) + args = [] + for k, v in atts.items(): + before, vatt = self.render_attribute_value(v) + if before: + raise NotImplementedError("Graph attribute not supported yet.") + args.append(f"{k}={vatt}") + + outs = ", ".join(outputs) + inps = ", ".join(inputs) + if args: + sargs = ", ".join(args) + row = f" {outs} = op.{op_type}({inps}, {sargs})" + else: + row = f" {outs} = op.{op_type}({inps})" + return [row] diff --git a/onnx_array_api/translate_api/translate.py b/onnx_array_api/translate_api/translate.py index 31c1bce..fe81cbf 100644 --- a/onnx_array_api/translate_api/translate.py +++ b/onnx_array_api/translate_api/translate.py @@ -76,18 +76,12 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) ) else: - rows.extend(self.emitter(EventType.BEGIN_GRAPH)) - - for i in initializers: rows.extend( - self.emitter( - EventType.INITIALIZER, - name=i.name, - init=i, - value=to_array_extended(i), - ) + self.emitter(EventType.BEGIN_GRAPH, name=self.proto_.graph.name) ) + rows.extend(self.emitter(EventType.BEGIN_SIGNATURE)) + for i in inputs: if is_function: rows.extend(self.emitter(EventType.FUNCTION_INPUT, name=i)) @@ -109,6 +103,18 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: self.emitter(EventType.FUNCTION_ATTRIBUTES, attributes=list(attributes)) ) + rows.extend(self.emitter(EventType.END_SIGNATURE)) + + for i in initializers: + rows.extend( + self.emitter( + EventType.INITIALIZER, + name=i.name, + init=i, + value=to_array_extended(i), + ) + ) + for node in nodes: atts = self.extract_attributes(node) rows.extend( @@ -122,6 +128,8 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) ) + rows.extend(self.emitter(EventType.BEGIN_RETURN)) + for o in outputs: if is_function: rows.extend(self.emitter(EventType.FUNCTION_OUTPUT, name=o)) @@ -137,6 +145,9 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ), ) ) + + rows.extend(self.emitter(EventType.END_RETURN)) + if isinstance(self.proto_, (GraphProto, FunctionProto)): name = self.proto_.name else: From af88e8d52cdc8a5d91a2793f1e8a251903db864a Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 31 Mar 2024 12:19:19 +0200 Subject: [PATCH 2/7] doc --- CHANGELOGS.rst | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index f6feee7..ac4ac15 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -1,10 +1,15 @@ Change Logs =========== -0.2.0 +0.3.0 +++++ +* :pr:`79`: first draft to export to GraphBuilder * :pr:`77`: supports ConcatOfShape and Slice with the light API + +0.2.0 ++++++ + * :pr:`76`, :pr:`79`: add a mode to compare models without execution * :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator * :pr:`71`: adds tools to compare two onnx graphs From 0c2a92d5ef65b70bb54455f8de1d626133234176 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 1 Apr 2024 10:57:35 +0200 Subject: [PATCH 3/7] fix unit test --- _unittests/ut_translate_api/test_translate.py | 6 +++--- _unittests/ut_translate_api/test_translate_classic.py | 4 ++-- onnx_array_api/translate_api/translate.py | 2 ++ 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/_unittests/ut_translate_api/test_translate.py b/_unittests/ut_translate_api/test_translate.py index 0212d0b..924b232 100644 --- a/_unittests/ut_translate_api/test_translate.py +++ b/_unittests/ut_translate_api/test_translate.py @@ -80,9 +80,9 @@ def test_transpose(self): """ ( start(opset=19) + .vin('X', elem_type=TensorProto.FLOAT) .cst(np.array([-1, 1], dtype=np.int64)) .rename('r') - .vin('X', elem_type=TensorProto.FLOAT) .bring('X', 'r') .Reshape() .rename('r0_0') @@ -166,9 +166,9 @@ def test_export_if(self): f""" ( start(opset=19) + .vin('X', elem_type=TensorProto.FLOAT) .cst(np.array([0.0], dtype=np.float32)) .rename('r') - .vin('X', elem_type=TensorProto.FLOAT) .bring('X') .ReduceSum(keepdims=1, noop_with_empty_axes=0) .rename('Xs') @@ -202,9 +202,9 @@ def test_aionnxml(self): """ ( start(opset=19, opsets={'ai.onnx.ml': 3}) + .vin('X', elem_type=TensorProto.FLOAT) .cst(np.array([-1, 1], dtype=np.int64)) .rename('r') - .vin('X', elem_type=TensorProto.FLOAT) .bring('X', 'r') .Reshape() .rename('USE') diff --git a/_unittests/ut_translate_api/test_translate_classic.py b/_unittests/ut_translate_api/test_translate_classic.py index c6cb412..4f1e26b 100644 --- a/_unittests/ut_translate_api/test_translate_classic.py +++ b/_unittests/ut_translate_api/test_translate_classic.py @@ -138,13 +138,13 @@ def test_transpose(self): initializers = [] sparse_initializers = [] functions = [] + inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) initializers.append( from_array( np.array([-1, 1], dtype=np.int64), name='r' ) ) - inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) nodes.append( make_node_extended( 'Reshape', @@ -278,13 +278,13 @@ def test_aionnxml(self): initializers = [] sparse_initializers = [] functions = [] + inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) initializers.append( from_array( np.array([-1, 1], dtype=np.int64), name='r' ) ) - inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) nodes.append( make_node_extended( 'Reshape', diff --git a/onnx_array_api/translate_api/translate.py b/onnx_array_api/translate_api/translate.py index fe81cbf..459d1c2 100644 --- a/onnx_array_api/translate_api/translate.py +++ b/onnx_array_api/translate_api/translate.py @@ -75,6 +75,8 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: domain=self.proto_.domain, ) ) + elif isinstance(self.proto_, GraphProto): + rows.extend(self.emitter(EventType.BEGIN_GRAPH, name=self.proto_.name)) else: rows.extend( self.emitter(EventType.BEGIN_GRAPH, name=self.proto_.graph.name) From 092dfa23498e249f84fad2cd7d2b3eb6b40274a2 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 1 Apr 2024 11:28:20 +0200 Subject: [PATCH 4/7] fix order --- _unittests/ut_translate_api/test_translate.py | 6 ++--- .../test_translate_builder.py | 26 +++++++++++++++++++ onnx_array_api/translate_api/translate.py | 20 +++++++------- 3 files changed, 39 insertions(+), 13 deletions(-) diff --git a/_unittests/ut_translate_api/test_translate.py b/_unittests/ut_translate_api/test_translate.py index 924b232..0212d0b 100644 --- a/_unittests/ut_translate_api/test_translate.py +++ b/_unittests/ut_translate_api/test_translate.py @@ -80,9 +80,9 @@ def test_transpose(self): """ ( start(opset=19) - .vin('X', elem_type=TensorProto.FLOAT) .cst(np.array([-1, 1], dtype=np.int64)) .rename('r') + .vin('X', elem_type=TensorProto.FLOAT) .bring('X', 'r') .Reshape() .rename('r0_0') @@ -166,9 +166,9 @@ def test_export_if(self): f""" ( start(opset=19) - .vin('X', elem_type=TensorProto.FLOAT) .cst(np.array([0.0], dtype=np.float32)) .rename('r') + .vin('X', elem_type=TensorProto.FLOAT) .bring('X') .ReduceSum(keepdims=1, noop_with_empty_axes=0) .rename('Xs') @@ -202,9 +202,9 @@ def test_aionnxml(self): """ ( start(opset=19, opsets={'ai.onnx.ml': 3}) - .vin('X', elem_type=TensorProto.FLOAT) .cst(np.array([-1, 1], dtype=np.int64)) .rename('r') + .vin('X', elem_type=TensorProto.FLOAT) .bring('X', 'r') .Reshape() .rename('USE') diff --git a/_unittests/ut_translate_api/test_translate_builder.py b/_unittests/ut_translate_api/test_translate_builder.py index 124013d..c7d813f 100644 --- a/_unittests/ut_translate_api/test_translate_builder.py +++ b/_unittests/ut_translate_api/test_translate_builder.py @@ -65,6 +65,32 @@ def light_api( got = ref.run(None, {"X": a})[0] self.assertEqualArray(np.exp(a), got) + def test_zdoc(self): + onx = ( + start() + .vin("X") + .reshape((-1, 1)) + .Transpose(perm=[1, 0]) + .rename("Y") + .vout() + .to_onnx() + ) + code = translate(onx, api="builder") + expected = dedent( + """ + ( + start() + .vin("X") + .reshape((-1, 1)) + .Transpose(perm=[1, 0]) + .rename("Y") + .vout() + .to_onnx() + )""" + ).strip("\n") + self.maxDiff = None + self.assertEqual(expected, code) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_array_api/translate_api/translate.py b/onnx_array_api/translate_api/translate.py index 459d1c2..7b7480b 100644 --- a/onnx_array_api/translate_api/translate.py +++ b/onnx_array_api/translate_api/translate.py @@ -82,6 +82,16 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: self.emitter(EventType.BEGIN_GRAPH, name=self.proto_.graph.name) ) + for i in initializers: + rows.extend( + self.emitter( + EventType.INITIALIZER, + name=i.name, + init=i, + value=to_array_extended(i), + ) + ) + rows.extend(self.emitter(EventType.BEGIN_SIGNATURE)) for i in inputs: @@ -107,16 +117,6 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: rows.extend(self.emitter(EventType.END_SIGNATURE)) - for i in initializers: - rows.extend( - self.emitter( - EventType.INITIALIZER, - name=i.name, - init=i, - value=to_array_extended(i), - ) - ) - for node in nodes: atts = self.extract_attributes(node) rows.extend( From baa25d85ed6078a29e0856b5bec05b69013ae400 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 2 Apr 2024 10:06:32 +0200 Subject: [PATCH 5/7] fix initializer --- .../test_translate_builder.py | 48 ++++++++++++++----- onnx_array_api/graph_api/graph_builder.py | 12 +++++ .../translate_api/builder_emitter.py | 16 ++++++- 3 files changed, 63 insertions(+), 13 deletions(-) diff --git a/_unittests/ut_translate_api/test_translate_builder.py b/_unittests/ut_translate_api/test_translate_builder.py index c7d813f..448512b 100644 --- a/_unittests/ut_translate_api/test_translate_builder.py +++ b/_unittests/ut_translate_api/test_translate_builder.py @@ -2,6 +2,7 @@ from textwrap import dedent import numpy as np from onnx import ModelProto, TensorProto +from onnx.checker import check_model from onnx.defs import onnx_opset_version from onnx.reference import ReferenceEvaluator from onnx_array_api.ext_test_case import ExtTestCase @@ -39,7 +40,7 @@ def light_api( g = GraphBuilder({'': 19}) g.make_tensor_input("X", TensorProto.FLOAT, ()) - light_api(g.op, X) + light_api(g.op, "X") g.make_tensor_output("Y", TensorProto.FLOAT, ()) model = g.to_onnx() """ @@ -78,18 +79,43 @@ def test_zdoc(self): code = translate(onx, api="builder") expected = dedent( """ - ( - start() - .vin("X") - .reshape((-1, 1)) - .Transpose(perm=[1, 0]) - .rename("Y") - .vout() - .to_onnx() - )""" + def light_api( + op: "GraphBuilder", + X: "FLOAT[]", + ): + r = np.array([-1, 1], dtype=np.int64) + r0_0 = op.Reshape(X, r) + Y = op.Transpose(r0_0, perm=[1, 0]) + op.Identity(Y, outputs=["Y"]) + return Y + + g = GraphBuilder({'': 21}) + g.make_tensor_input("X", TensorProto.FLOAT, ()) + light_api(g.op, "X") + g.make_tensor_output("Y", TensorProto.FLOAT, ()) + model = g.to_onnx() + """ ).strip("\n") self.maxDiff = None - self.assertEqual(expected, code) + self.assertEqual(expected, code.strip("\n")) + + def light_api( + op: "GraphBuilder", + X: "FLOAT[]", # noqa: F722 + ): + r = np.array([-1, 1], dtype=np.int64) + r0_0 = op.Reshape(X, r) + Y = op.Transpose(r0_0, perm=[1, 0]) + op.Identity(Y, outputs=["Y"]) + return Y + + g = GraphBuilder({"": 21}) + X = g.make_tensor_input("X", TensorProto.FLOAT, ()) + light_api(g.op, X) + g.make_tensor_output("Y", TensorProto.FLOAT, ()) + model = g.to_onnx() + self.assertNotEmpty(model) + check_model(model) if __name__ == "__main__": diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py index 800c578..4f5c601 100644 --- a/onnx_array_api/graph_api/graph_builder.py +++ b/onnx_array_api/graph_api/graph_builder.py @@ -119,6 +119,18 @@ def __getattr__(self, name): except AttributeError as e: raise AttributeError(f"Unable to access attribute {name!r}.") from e + def Initializer( + self, init: Union[TensorProto, np.ndarray], name: Optional[str] = None + ) -> str: + """ + Creates an initializer. + + :param init: value + :param name: name if value is not a TensorProto + :return: its name + """ + return self.builder.make_initializer(init, name=name, exists=True) + def make_node( self, op_type: str, diff --git a/onnx_array_api/translate_api/builder_emitter.py b/onnx_array_api/translate_api/builder_emitter.py index 963286b..a3b38d6 100644 --- a/onnx_array_api/translate_api/builder_emitter.py +++ b/onnx_array_api/translate_api/builder_emitter.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List from onnx import TensorProto +from onnx.numpy_helper import to_array from .base_emitter import BaseEmitter _types = { @@ -31,7 +32,7 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: return [] def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: - inps = ", ".join(["g.op", *self.inputs]) + inps = ", ".join(["g.op", *[f'"{i}"' for i in self.inputs]]) inputs = [] for inp, stype, shape in self.inputs_full_: inputs.append(f'g.make_tensor_input("{inp}", TensorProto.{stype}, {shape})') @@ -64,7 +65,14 @@ def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]: return [] def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: - assert False, f"not implemented yet with {kwargs}" + init = kwargs["init"] + if isinstance(init, TensorProto): + assert ( + kwargs["name"] == init.name + ), f"Name mismatch init.name={init.name!r}, name={kwargs['name']!r}" + self.inits.append(init) + return [] + raise AssertionError(f"Unsupported type for an initializer {type(init)}") def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: name = kwargs["name"] @@ -90,6 +98,10 @@ def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]: for i in self.inputs_full: rows.append(f" {i},") rows.append("):") + for init in self.inits: + val = to_array(init) + stype = str(val.dtype).split(".")[-1] + rows.append(f" {init.name} = np.array({val.tolist()}, dtype=np.{stype})") return rows def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]: From 906ab71c33377d3f1a616b607be111433f20a239 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 2 Apr 2024 10:13:44 +0200 Subject: [PATCH 6/7] fix ut --- _unittests/ut_translate_api/test_translate_classic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/_unittests/ut_translate_api/test_translate_classic.py b/_unittests/ut_translate_api/test_translate_classic.py index 4f1e26b..c6cb412 100644 --- a/_unittests/ut_translate_api/test_translate_classic.py +++ b/_unittests/ut_translate_api/test_translate_classic.py @@ -138,13 +138,13 @@ def test_transpose(self): initializers = [] sparse_initializers = [] functions = [] - inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) initializers.append( from_array( np.array([-1, 1], dtype=np.int64), name='r' ) ) + inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) nodes.append( make_node_extended( 'Reshape', @@ -278,13 +278,13 @@ def test_aionnxml(self): initializers = [] sparse_initializers = [] functions = [] - inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) initializers.append( from_array( np.array([-1, 1], dtype=np.int64), name='r' ) ) + inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) nodes.append( make_node_extended( 'Reshape', From 604440dc3e99464ef59d3b662680454b5d88d0f1 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 2 Apr 2024 10:15:41 +0200 Subject: [PATCH 7/7] fix opset --- _unittests/ut_translate_api/test_translate_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/_unittests/ut_translate_api/test_translate_builder.py b/_unittests/ut_translate_api/test_translate_builder.py index 448512b..7af0134 100644 --- a/_unittests/ut_translate_api/test_translate_builder.py +++ b/_unittests/ut_translate_api/test_translate_builder.py @@ -68,7 +68,7 @@ def light_api( def test_zdoc(self): onx = ( - start() + start(opset=19) .vin("X") .reshape((-1, 1)) .Transpose(perm=[1, 0]) @@ -89,7 +89,7 @@ def light_api( op.Identity(Y, outputs=["Y"]) return Y - g = GraphBuilder({'': 21}) + g = GraphBuilder({'': 19}) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") g.make_tensor_output("Y", TensorProto.FLOAT, ())