diff --git a/.github/workflows/check-urls.yml b/.github/workflows/check-urls.yml index 67d7731..d56adba 100644 --- a/.github/workflows/check-urls.yml +++ b/.github/workflows/check-urls.yml @@ -42,6 +42,6 @@ jobs: print_all: false timeout: 2 retry_count# : 2 - exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document - exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/ + exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document,https://github.com/onnx/tensorflow-onnx + exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://github.com/onnx/tensorflow-onnx # force_pass : true diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 3aa613d..746c264 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -1,6 +1,11 @@ Change Logs =========== +0.3.1 ++++++ + +* :pr:`95`: improves translation to GraphBuilder + 0.3.0 +++++ diff --git a/_unittests/ut_translate_api/test_translate_builder.py b/_unittests/ut_translate_api/test_translate_builder.py index 7af0134..6f67dff 100644 --- a/_unittests/ut_translate_api/test_translate_builder.py +++ b/_unittests/ut_translate_api/test_translate_builder.py @@ -8,7 +8,8 @@ from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.light_api import start from onnx_array_api.graph_api import GraphBuilder -from onnx_array_api.translate_api import translate +from onnx_array_api.translate_api import translate, Translater +from onnx_array_api.translate_api.builder_emitter import BuilderEmitter OPSET_API = min(19, onnx_opset_version() - 1) @@ -19,7 +20,7 @@ def setUp(self): self.maxDiff = None def test_exp(self): - onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx() + onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx() self.assertIsInstance(onx, ModelProto) self.assertIn("Exp", str(onx)) ref = ReferenceEvaluator(onx) @@ -38,7 +39,7 @@ def light_api( op.Identity(Y, outputs=["Y"]) return Y - g = GraphBuilder({'': 19}) + g = GraphBuilder({'': 19}, ir_version=10) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") g.make_tensor_output("Y", TensorProto.FLOAT, ()) @@ -68,7 +69,7 @@ def light_api( def test_zdoc(self): onx = ( - start(opset=19) + start(opset=19, ir_version=10) .vin("X") .reshape((-1, 1)) .Transpose(perm=[1, 0]) @@ -89,7 +90,7 @@ def light_api( op.Identity(Y, outputs=["Y"]) return Y - g = GraphBuilder({'': 19}) + g = GraphBuilder({'': 19}, ir_version=10) g.make_tensor_input("X", TensorProto.FLOAT, ()) light_api(g.op, "X") g.make_tensor_output("Y", TensorProto.FLOAT, ()) @@ -117,6 +118,62 @@ def light_api( self.assertNotEmpty(model) check_model(model) + def test_exp_f(self): + onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx() + self.assertIsInstance(onx, ModelProto) + self.assertIn("Exp", str(onx)) + ref = ReferenceEvaluator(onx) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(np.exp(a), got) + + tr = Translater(onx, emitter=BuilderEmitter("mm")) + code = tr.export(as_str=True) + + expected = dedent( + """ + def light_api( + op: "GraphBuilder", + X: "FLOAT[]", + ): + Y = op.Exp(X) + op.Identity(Y, outputs=["Y"]) + return Y + + + def mm() -> "ModelProto": + g = GraphBuilder({'': 19}, ir_version=10) + g.make_tensor_input("X", TensorProto.FLOAT, ()) + light_api(g.op, "X") + g.make_tensor_output("Y", TensorProto.FLOAT, ()) + model = g.to_onnx() + return model + + + model = mm() + """ + ).strip("\n") + self.assertEqual(expected, code.strip("\n")) + + def light_api( + op: "GraphBuilder", + X: "FLOAT[]", # noqa: F722 + ): + Y = op.Exp(X) + op.Identity(Y, outputs=["Y"]) + return Y + + g2 = GraphBuilder({"": 19}) + g2.make_tensor_input("X", TensorProto.FLOAT, ("A",)) + light_api(g2.op, "X") + g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",)) + onx2 = g2.to_onnx() + + ref = ReferenceEvaluator(onx2) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(np.exp(a), got) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_array_api/__init__.py b/onnx_array_api/__init__.py index 837bc52..98371ac 100644 --- a/onnx_array_api/__init__.py +++ b/onnx_array_api/__init__.py @@ -2,5 +2,5 @@ APIs to create ONNX Graphs. """ -__version__ = "0.3.0" +__version__ = "0.3.1" __author__ = "Xavier Dupré" diff --git a/onnx_array_api/translate_api/builder_emitter.py b/onnx_array_api/translate_api/builder_emitter.py index a3b38d6..1c893e2 100644 --- a/onnx_array_api/translate_api/builder_emitter.py +++ b/onnx_array_api/translate_api/builder_emitter.py @@ -4,10 +4,17 @@ from .base_emitter import BaseEmitter _types = { + TensorProto.DOUBLE: "DOUBLE", TensorProto.FLOAT: "FLOAT", TensorProto.FLOAT16: "FLOAT16", TensorProto.INT64: "INT64", TensorProto.INT32: "INT32", + TensorProto.INT16: "INT16", + TensorProto.UINT64: "UINT64", + TensorProto.UINT32: "UINT32", + TensorProto.UINT16: "UINT16", + TensorProto.STRING: "STRING", + TensorProto.BOOL: "BOOL", } @@ -20,6 +27,10 @@ class BuilderEmitter(BaseEmitter): Converts event into proper code. """ + def __init__(self, make_model_function: str = ""): + super().__init__() + self.make_model_function = make_model_function + def join(self, rows: List[str], single_line: bool = False) -> str: "Join the rows" assert ( @@ -29,6 +40,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str: def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: self.opsets = kwargs.get("opsets", {}) + self.ir_version = kwargs.get("ir_version", None) return [] def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: @@ -43,12 +55,27 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: ) rows = [ "", - f"g = GraphBuilder({self.opsets})", + ( + f"g = GraphBuilder({self.opsets}, ir_version={self.ir_version})" + if self.ir_version + else f"GraphBuilder({self.opsets})" + ), *inputs, f"{self.name}({inps})", *outputs, "model = g.to_onnx()", ] + if self.make_model_function: + rows = [ + "", + "", + f'def {self.make_model_function}() -> "ModelProto":', + *[" " + _ for _ in rows[1:]], + " return model", + "", + "", + f"model = {self.make_model_function}()", + ] return rows def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: @@ -78,13 +105,16 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: name = kwargs["name"] itype = kwargs.get("elem_type", 0) shape = kwargs.get("shape", None) + name = self._clean_result_name(name) if itype == 0: - inp = "X" + inp = name or "X" else: if shape is None: - inp = f'X: "{_itype_to_string(itype)}"' + inp = f'{name}: "{_itype_to_string(itype)}"' else: - inp = f'X: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"' + inp = ( + f'{name}: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"' + ) self.inputs_full.append(inp) self.inputs.append(name) self.inputs_full_.append((name, _itype_to_string(itype), shape)) @@ -113,6 +143,7 @@ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]: def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]: name = kwargs["name"] + name = self._clean_result_name(name) itype = kwargs.get("elem_type", 0) shape = kwargs.get("shape", None) self.outputs.append(name) @@ -126,6 +157,8 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: if kwargs.get("domain", "") != "": domain = kwargs["domain"] op_type = f"{domain}.{op_type}" + else: + domain = "" atts = kwargs.get("atts", {}) args = [] for k, v in atts.items(): @@ -134,11 +167,22 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: raise NotImplementedError("Graph attribute not supported yet.") args.append(f"{k}={vatt}") - outs = ", ".join(outputs) - inps = ", ".join(inputs) + outs = ", ".join(map(self._clean_result_name, outputs)) + inps = ", ".join(map(self._clean_result_name, inputs)) + op_type = self._emit_node_type(op_type, domain) + sdomain = "" if not domain else f", domain={domain!r}" if args: sargs = ", ".join(args) - row = f" {outs} = op.{op_type}({inps}, {sargs})" + if inps: + row = f" {outs} = op.{op_type}({inps}, {sargs}{sdomain})" + else: + row = f" {outs} = op.{op_type}({sargs}{sdomain})" else: - row = f" {outs} = op.{op_type}({inps})" + row = f" {outs} = op.{op_type}({inps}{sdomain})" return [row] + + def _clean_result_name(self, name): + return name + + def _emit_node_type(self, op_type, domain): + return op_type diff --git a/onnx_array_api/translate_api/translate.py b/onnx_array_api/translate_api/translate.py index 7b7480b..aa78103 100644 --- a/onnx_array_api/translate_api/translate.py +++ b/onnx_array_api/translate_api/translate.py @@ -35,7 +35,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: last_event = None if isinstance(self.proto_, ModelProto): opsets = {d.domain: d.version for d in self.proto_.opset_import} - rows.extend(self.emitter(EventType.START, opsets=opsets)) + rows.extend( + self.emitter( + EventType.START, opsets=opsets, ir_version=self.proto_.ir_version + ) + ) inputs = self.proto_.graph.input outputs = self.proto_.graph.output nodes = self.proto_.graph.node