From e29df50c42555c0fe71cef1aef1579f84e2ca866 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 18 Feb 2025 12:10:00 +0100 Subject: [PATCH 1/6] Improves translation to GraphBuilder --- CHANGELOGS.rst | 5 ++++ onnx_array_api/__init__.py | 2 +- .../translate_api/builder_emitter.py | 30 ++++++++++++++++--- onnx_array_api/translate_api/translate.py | 6 +++- 4 files changed, 37 insertions(+), 6 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 3aa613d..5051dc9 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -1,6 +1,11 @@ Change Logs =========== +0.3.1 ++++++ + +* :pr:`94`: improves translation to GraphBuilder + 0.3.0 +++++ diff --git a/onnx_array_api/__init__.py b/onnx_array_api/__init__.py index 837bc52..98371ac 100644 --- a/onnx_array_api/__init__.py +++ b/onnx_array_api/__init__.py @@ -2,5 +2,5 @@ APIs to create ONNX Graphs. """ -__version__ = "0.3.0" +__version__ = "0.3.1" __author__ = "Xavier Dupré" diff --git a/onnx_array_api/translate_api/builder_emitter.py b/onnx_array_api/translate_api/builder_emitter.py index a3b38d6..74279ca 100644 --- a/onnx_array_api/translate_api/builder_emitter.py +++ b/onnx_array_api/translate_api/builder_emitter.py @@ -20,6 +20,10 @@ class BuilderEmitter(BaseEmitter): Converts event into proper code. """ + def __init__(self, make_model_function: str = ""): + super().__init__() + self.make_model_function = make_model_function + def join(self, rows: List[str], single_line: bool = False) -> str: "Join the rows" assert ( @@ -29,6 +33,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str: def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: self.opsets = kwargs.get("opsets", {}) + self.ir_version = kwargs.get("ir_version", None) return [] def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: @@ -43,12 +48,27 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: ) rows = [ "", - f"g = GraphBuilder({self.opsets})", + ( + f"g = GraphBuilder({self.opsets}, ir_version={self.ir_version})" + if self.ir_version + else f"GraphBuilder({self.opsets})" + ), *inputs, f"{self.name}({inps})", *outputs, "model = g.to_onnx()", ] + if self.make_model_function: + rows = [ + "", + "", + f'def {self.make_model_function}() -> "ModelProto":', + *[" " + _ for _ in rows[1:]], + " return model", + "", + "", + f"model = {self.make_model_function}()", + ] return rows def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: @@ -79,12 +99,14 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: itype = kwargs.get("elem_type", 0) shape = kwargs.get("shape", None) if itype == 0: - inp = "X" + inp = name or "X" else: if shape is None: - inp = f'X: "{_itype_to_string(itype)}"' + inp = f'{name}: "{_itype_to_string(itype)}"' else: - inp = f'X: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"' + inp = ( + f'{name}: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"' + ) self.inputs_full.append(inp) self.inputs.append(name) self.inputs_full_.append((name, _itype_to_string(itype), shape)) diff --git a/onnx_array_api/translate_api/translate.py b/onnx_array_api/translate_api/translate.py index 7b7480b..aa78103 100644 --- a/onnx_array_api/translate_api/translate.py +++ b/onnx_array_api/translate_api/translate.py @@ -35,7 +35,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: last_event = None if isinstance(self.proto_, ModelProto): opsets = {d.domain: d.version for d in self.proto_.opset_import} - rows.extend(self.emitter(EventType.START, opsets=opsets)) + rows.extend( + self.emitter( + EventType.START, opsets=opsets, ir_version=self.proto_.ir_version + ) + ) inputs = self.proto_.graph.input outputs = self.proto_.graph.output nodes = self.proto_.graph.node From 9351aca031473e2c61b8ef9329bd954d02e7b07d Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 18 Feb 2025 12:16:37 +0100 Subject: [PATCH 2/6] ch --- CHANGELOGS.rst | 2 +- .../test_translate_builder.py | 63 ++++++++++++++++++- 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 5051dc9..746c264 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,7 +4,7 @@ Change Logs 0.3.1 +++++ -* :pr:`94`: improves translation to GraphBuilder +* :pr:`95`: improves translation to GraphBuilder 0.3.0 +++++ diff --git a/_unittests/ut_translate_api/test_translate_builder.py b/_unittests/ut_translate_api/test_translate_builder.py index 7af0134..7926a3c 100644 --- a/_unittests/ut_translate_api/test_translate_builder.py +++ b/_unittests/ut_translate_api/test_translate_builder.py @@ -8,7 +8,8 @@ from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.light_api import start from onnx_array_api.graph_api import GraphBuilder -from onnx_array_api.translate_api import translate +from onnx_array_api.translate_api import translate, Translater +from onnx_array_api.translate_api.builder_emitter import BuilderEmitter OPSET_API = min(19, onnx_opset_version() - 1) @@ -38,7 +39,7 @@ def light_api( op.Identity(Y, outputs=["Y"]) return Y - g = GraphBuilder({'': 19}) + g = GraphBuilder({'': 19}, ir_version=11) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") g.make_tensor_output("Y", TensorProto.FLOAT, ()) @@ -89,7 +90,7 @@ def light_api( op.Identity(Y, outputs=["Y"]) return Y - g = GraphBuilder({'': 19}) + g = GraphBuilder({'': 19}, ir_version=11) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") g.make_tensor_output("Y", TensorProto.FLOAT, ()) @@ -117,6 +118,62 @@ def light_api( self.assertNotEmpty(model) check_model(model) + def test_exp_f(self): + onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx() + self.assertIsInstance(onx, ModelProto) + self.assertIn("Exp", str(onx)) + ref = ReferenceEvaluator(onx) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(np.exp(a), got) + + tr = Translater(onx, emitter=BuilderEmitter("mm")) + code = tr.export(as_str=True) + + expected = dedent( + """ + def light_api( + op: "GraphBuilder", + X: "FLOAT[]", + ): + Y = op.Exp(X) + op.Identity(Y, outputs=["Y"]) + return Y + + + def mm() -> "ModelProto": + g = GraphBuilder({'': 19}, ir_version=11) + g.make_tensor_input("X", TensorProto.FLOAT, ()) + light_api(g.op, "X") + g.make_tensor_output("Y", TensorProto.FLOAT, ()) + model = g.to_onnx() + return model + + + model = mm() + """ + ).strip("\n") + self.assertEqual(expected, code.strip("\n")) + + def light_api( + op: "GraphBuilder", + X: "FLOAT[]", # noqa: F722 + ): + Y = op.Exp(X) + op.Identity(Y, outputs=["Y"]) + return Y + + g2 = GraphBuilder({"": 19}) + g2.make_tensor_input("X", TensorProto.FLOAT, ("A",)) + light_api(g2.op, "X") + g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",)) + onx2 = g2.to_onnx() + + ref = ReferenceEvaluator(onx2) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(np.exp(a), got) + if __name__ == "__main__": unittest.main(verbosity=2) From 722fd2a51edce5061dbd4230952924c14afcad42 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 18 Feb 2025 12:22:40 +0100 Subject: [PATCH 3/6] fix issue --- _unittests/ut_translate_api/test_translate_builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/_unittests/ut_translate_api/test_translate_builder.py b/_unittests/ut_translate_api/test_translate_builder.py index 7926a3c..0c6452c 100644 --- a/_unittests/ut_translate_api/test_translate_builder.py +++ b/_unittests/ut_translate_api/test_translate_builder.py @@ -20,7 +20,7 @@ def setUp(self): self.maxDiff = None def test_exp(self): - onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx() + onx = start(opset=19, ir_version=11).vin("X").Exp().rename("Y").vout().to_onnx() self.assertIsInstance(onx, ModelProto) self.assertIn("Exp", str(onx)) ref = ReferenceEvaluator(onx) @@ -69,7 +69,7 @@ def light_api( def test_zdoc(self): onx = ( - start(opset=19) + start(opset=19, ir_version=11) .vin("X") .reshape((-1, 1)) .Transpose(perm=[1, 0]) @@ -119,7 +119,7 @@ def light_api( check_model(model) def test_exp_f(self): - onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx() + onx = start(opset=19, ir_version=11).vin("X").Exp().rename("Y").vout().to_onnx() self.assertIsInstance(onx, ModelProto) self.assertIn("Exp", str(onx)) ref = ReferenceEvaluator(onx) From 432fa69a84145151bb393b9f8f67f4180b94c00e Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 18 Feb 2025 12:34:47 +0100 Subject: [PATCH 4/6] ir --- .../ut_translate_api/test_translate_builder.py | 12 ++++++------ onnx_array_api/translate_api/builder_emitter.py | 11 +++++++++-- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/_unittests/ut_translate_api/test_translate_builder.py b/_unittests/ut_translate_api/test_translate_builder.py index 0c6452c..6f67dff 100644 --- a/_unittests/ut_translate_api/test_translate_builder.py +++ b/_unittests/ut_translate_api/test_translate_builder.py @@ -20,7 +20,7 @@ def setUp(self): self.maxDiff = None def test_exp(self): - onx = start(opset=19, ir_version=11).vin("X").Exp().rename("Y").vout().to_onnx() + onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx() self.assertIsInstance(onx, ModelProto) self.assertIn("Exp", str(onx)) ref = ReferenceEvaluator(onx) @@ -39,7 +39,7 @@ def light_api( op.Identity(Y, outputs=["Y"]) return Y - g = GraphBuilder({'': 19}, ir_version=11) + g = GraphBuilder({'': 19}, ir_version=10) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") g.make_tensor_output("Y", TensorProto.FLOAT, ()) @@ -69,7 +69,7 @@ def light_api( def test_zdoc(self): onx = ( - start(opset=19, ir_version=11) + start(opset=19, ir_version=10) .vin("X") .reshape((-1, 1)) .Transpose(perm=[1, 0]) @@ -90,7 +90,7 @@ def light_api( op.Identity(Y, outputs=["Y"]) return Y - g = GraphBuilder({'': 19}, ir_version=11) + g = GraphBuilder({'': 19}, ir_version=10) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") g.make_tensor_output("Y", TensorProto.FLOAT, ()) @@ -119,7 +119,7 @@ def light_api( check_model(model) def test_exp_f(self): - onx = start(opset=19, ir_version=11).vin("X").Exp().rename("Y").vout().to_onnx() + onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx() self.assertIsInstance(onx, ModelProto) self.assertIn("Exp", str(onx)) ref = ReferenceEvaluator(onx) @@ -142,7 +142,7 @@ def light_api( def mm() -> "ModelProto": - g = GraphBuilder({'': 19}, ir_version=11) + g = GraphBuilder({'': 19}, ir_version=10) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") g.make_tensor_output("Y", TensorProto.FLOAT, ()) diff --git a/onnx_array_api/translate_api/builder_emitter.py b/onnx_array_api/translate_api/builder_emitter.py index 74279ca..3c92206 100644 --- a/onnx_array_api/translate_api/builder_emitter.py +++ b/onnx_array_api/translate_api/builder_emitter.py @@ -148,6 +148,8 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: if kwargs.get("domain", "") != "": domain = kwargs["domain"] op_type = f"{domain}.{op_type}" + else: + domain = "" atts = kwargs.get("atts", {}) args = [] for k, v in atts.items(): @@ -158,9 +160,14 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: outs = ", ".join(outputs) inps = ", ".join(inputs) + op_type = self._emit_node_type(op_type, domain) + sdomain = "" if not domain else f", domain={domain!r}" if args: sargs = ", ".join(args) - row = f" {outs} = op.{op_type}({inps}, {sargs})" + row = f" {outs} = op.{op_type}({inps}, {sargs}{sdomain})" else: - row = f" {outs} = op.{op_type}({inps})" + row = f" {outs} = op.{op_type}({inps}{sdomain})" return [row] + + def _emit_node_type(self, op_type, domain): + return op_type From 7260f0e971d10d810caa88db55d89a764699a7ac Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 18 Feb 2025 12:43:03 +0100 Subject: [PATCH 5/6] urls --- .github/workflows/check-urls.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/check-urls.yml b/.github/workflows/check-urls.yml index 67d7731..d56adba 100644 --- a/.github/workflows/check-urls.yml +++ b/.github/workflows/check-urls.yml @@ -42,6 +42,6 @@ jobs: print_all: false timeout: 2 retry_count# : 2 - exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document - exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/ + exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document,https://github.com/onnx/tensorflow-onnx + exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://github.com/onnx/tensorflow-onnx # force_pass : true From d6423ed3aa553ceb689f6ba545ad3aea0723d48d Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 18 Feb 2025 13:09:36 +0100 Subject: [PATCH 6/6] check --- .../translate_api/builder_emitter.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/onnx_array_api/translate_api/builder_emitter.py b/onnx_array_api/translate_api/builder_emitter.py index 3c92206..1c893e2 100644 --- a/onnx_array_api/translate_api/builder_emitter.py +++ b/onnx_array_api/translate_api/builder_emitter.py @@ -4,10 +4,17 @@ from .base_emitter import BaseEmitter _types = { + TensorProto.DOUBLE: "DOUBLE", TensorProto.FLOAT: "FLOAT", TensorProto.FLOAT16: "FLOAT16", TensorProto.INT64: "INT64", TensorProto.INT32: "INT32", + TensorProto.INT16: "INT16", + TensorProto.UINT64: "UINT64", + TensorProto.UINT32: "UINT32", + TensorProto.UINT16: "UINT16", + TensorProto.STRING: "STRING", + TensorProto.BOOL: "BOOL", } @@ -98,6 +105,7 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: name = kwargs["name"] itype = kwargs.get("elem_type", 0) shape = kwargs.get("shape", None) + name = self._clean_result_name(name) if itype == 0: inp = name or "X" else: @@ -135,6 +143,7 @@ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]: def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]: name = kwargs["name"] + name = self._clean_result_name(name) itype = kwargs.get("elem_type", 0) shape = kwargs.get("shape", None) self.outputs.append(name) @@ -158,16 +167,22 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: raise NotImplementedError("Graph attribute not supported yet.") args.append(f"{k}={vatt}") - outs = ", ".join(outputs) - inps = ", ".join(inputs) + outs = ", ".join(map(self._clean_result_name, outputs)) + inps = ", ".join(map(self._clean_result_name, inputs)) op_type = self._emit_node_type(op_type, domain) sdomain = "" if not domain else f", domain={domain!r}" if args: sargs = ", ".join(args) - row = f" {outs} = op.{op_type}({inps}, {sargs}{sdomain})" + if inps: + row = f" {outs} = op.{op_type}({inps}, {sargs}{sdomain})" + else: + row = f" {outs} = op.{op_type}({sargs}{sdomain})" else: row = f" {outs} = op.{op_type}({inps}{sdomain})" return [row] + def _clean_result_name(self, name): + return name + def _emit_node_type(self, op_type, domain): return op_type