diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 746c264..31056a9 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
 0.3.1
 +++++
 
+* :pr:`96`: supports local functions in translator
 * :pr:`95`: improves translation to GraphBuilder
 
 0.3.0
diff --git a/_unittests/ut_translate_api/test_translate_builder.py b/_unittests/ut_translate_api/test_translate_builder.py
index 6f67dff..b1ad394 100644
--- a/_unittests/ut_translate_api/test_translate_builder.py
+++ b/_unittests/ut_translate_api/test_translate_builder.py
@@ -1,6 +1,7 @@
 import unittest
 from textwrap import dedent
 import numpy as np
+import onnx.helper as oh
 from onnx import ModelProto, TensorProto
 from onnx.checker import check_model
 from onnx.defs import onnx_opset_version
@@ -29,37 +30,43 @@ def test_exp(self):
         self.assertEqualArray(np.exp(a), got)
 
         code = translate(onx, api="builder")
-        expected = dedent(
-            """
+        expected = (
+            dedent(
+                """
 def light_api(
     op: "GraphBuilder",
     X: "FLOAT[]",
 ):
-    Y = op.Exp(X)
+    Y = op.Exp(X, outputs=['Y'])
     op.Identity(Y, outputs=["Y"])
     return Y
 
 g = GraphBuilder({'': 19}, ir_version=10)
 g.make_tensor_input("X", TensorProto.FLOAT, ())
 light_api(g.op, "X")
-g.make_tensor_output("Y", TensorProto.FLOAT, ())
+g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
 model = g.to_onnx()
 """
-        ).strip("\n")
+            )
+            .strip("\n")
+            .replace("__SUFFIX__", ", is_dimension=False, indexed=False")
+        )
         self.assertEqual(expected, code.strip("\n"))
 
         def light_api(
             op: "GraphBuilder",
             X: "FLOAT[]",  # noqa: F722
         ):
-            Y = op.Exp(X)
+            Y = op.Exp(X, outputs=["Y"])
             op.Identity(Y, outputs=["Y"])
             return Y
 
         g2 = GraphBuilder({"": 19})
         g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
         light_api(g2.op, "X")
-        g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
+        g2.make_tensor_output(
+            "Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
+        )
         onx2 = g2.to_onnx()
 
         ref = ReferenceEvaluator(onx2)
@@ -78,25 +85,29 @@ def test_zdoc(self):
             .to_onnx()
         )
         code = translate(onx, api="builder")
-        expected = dedent(
-            """
+        expected = (
+            dedent(
+                """
 def light_api(
     op: "GraphBuilder",
     X: "FLOAT[]",
 ):
     r = np.array([-1, 1], dtype=np.int64)
-    r0_0 = op.Reshape(X, r)
-    Y = op.Transpose(r0_0, perm=[1, 0])
+    r0_0 = op.Reshape(X, r, outputs=['r0_0'])
+    Y = op.Transpose(r0_0, perm=[1, 0], outputs=['Y'])
     op.Identity(Y, outputs=["Y"])
     return Y
 
 g = GraphBuilder({'': 19}, ir_version=10)
 g.make_tensor_input("X", TensorProto.FLOAT, ())
 light_api(g.op, "X")
-g.make_tensor_output("Y", TensorProto.FLOAT, ())
+g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
 model = g.to_onnx()
 """
-        ).strip("\n")
+            )
+            .strip("\n")
+            .replace("__SUFFIX__", ", is_dimension=False, indexed=False")
+        )
         self.maxDiff = None
         self.assertEqual(expected, code.strip("\n"))
 
@@ -130,13 +141,14 @@ def test_exp_f(self):
 
         tr = Translater(onx, emitter=BuilderEmitter("mm"))
         code = tr.export(as_str=True)
-        expected = dedent(
-            """
+        expected = (
+            dedent(
+                """
 def light_api(
     op: "GraphBuilder",
     X: "FLOAT[]",
 ):
-    Y = op.Exp(X)
+    Y = op.Exp(X, outputs=['Y'])
     op.Identity(Y, outputs=["Y"])
     return Y
 
@@ -145,14 +157,17 @@ def mm() -> "ModelProto":
     g = GraphBuilder({'': 19}, ir_version=10)
     g.make_tensor_input("X", TensorProto.FLOAT, ())
     light_api(g.op, "X")
-    g.make_tensor_output("Y", TensorProto.FLOAT, ())
+    g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
     model = g.to_onnx()
     return model
 
 
 model = mm()
 """
-        ).strip("\n")
+            )
+            .strip("\n")
+            .replace("__SUFFIX__", ", is_dimension=False, indexed=False")
+        )
         self.assertEqual(expected, code.strip("\n"))
 
         def light_api(
@@ -166,7 +181,9 @@ def light_api(
         g2 = GraphBuilder({"": 19})
         g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
         light_api(g2.op, "X")
-        g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
+        g2.make_tensor_output(
+            "Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
+        )
         onx2 = g2.to_onnx()
 
         ref = ReferenceEvaluator(onx2)
@@ -174,6 +191,95 @@ def light_api(
         got = ref.run(None, {"X": a})[0]
         self.assertEqualArray(np.exp(a), got)
 
+    def test_local_function(self):
+        new_domain = "custom"
+
+        linear_regression = oh.make_function(
+            new_domain,
+            "LinearRegression",
+            ["x", "a", "b"],
+            ["y"],
+            [
+                oh.make_node("MatMul", ["x", "a"], ["xa"]),
+                oh.make_node("Add", ["xa", "b"], ["y"]),
+            ],
+            [oh.make_opsetid("", 14)],
+            [],
+        )
+
+        graph = oh.make_graph(
+            [
+                oh.make_node(
+                    "LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain
+                ),
+                oh.make_node("Abs", ["Y1"], ["Y"]),
+            ],
+            "example",
+            [
+                oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None]),
+                oh.make_tensor_value_info("A", TensorProto.FLOAT, [None, None]),
+                oh.make_tensor_value_info("B", TensorProto.FLOAT, [None, None]),
+            ],
+            [oh.make_tensor_value_info("Y", TensorProto.FLOAT, None)],
+        )
+
+        onnx_model = oh.make_model(
+            graph,
+            opset_imports=[oh.make_opsetid("", 14), oh.make_opsetid(new_domain, 1)],
+            functions=[linear_regression],
+            ir_version=10,
+        )
+        tr = Translater(onnx_model, emitter=BuilderEmitter("mm"))
+        code = tr.export(as_str=True)
+
+        expected = (
+            dedent(
+                """
+def example(
+    op: "GraphBuilder",
+    X: "FLOAT[, ]",
+    A: "FLOAT[, ]",
+    B: "FLOAT[, ]",
+):
+    Y1 = op.LinearRegression(X, A, B, domain='custom', outputs=['Y1'])
+    Y = op.Abs(Y1, outputs=['Y'])
+    op.Identity(Y, outputs=["Y"])
+    return Y
+
+
+def make_custom_LinearRegression(g: "GraphBuilder"):
+    gr = GraphBuilder({'': 14}, as_function=True)
+    x = gr.make_tensor_input('x')
+    a = gr.make_tensor_input('a')
+    b = gr.make_tensor_input('b')
+    op = gr.op
+    xa = op.MatMul(x, a, outputs=['xa'])
+    y = op.Add(xa, b, outputs=['y'])
+    gr.make_tensor_output(y)
+    g.add_function(builder=gr)
+    return gr
+
+
+def mm() -> "ModelProto":
+    g = GraphBuilder({'': 14, 'custom': 1}, ir_version=10)
+    g.make_tensor_input("X", TensorProto.FLOAT, ('', ''))
+    g.make_tensor_input("A", TensorProto.FLOAT, ('', ''))
+    g.make_tensor_input("B", TensorProto.FLOAT, ('', ''))
+    example(g.op, "X", "A", "B")
+    g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
+    make_custom_LinearRegression(g)
+    model = g.to_onnx()
+    return model
+
+
+model = mm()
+"""
+            )
+            .strip("\n")
+            .replace("__SUFFIX__", ", is_dimension=False, indexed=False")
+        )
+        self.assertEqual(expected, code.strip("\n"))
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
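
The new test only compares the generated source against a literal string. As a complementary check, here is an illustrative sketch (not part of the patch): the emitted program can be executed to rebuild the model, since it ends with a top-level `model = ...` assignment. The import paths below assume the package's usual public re-exports.

import numpy as np
from onnx import TensorProto
from onnx_array_api.graph_api import GraphBuilder
from onnx_array_api.translate_api.translate import Translater
from onnx_array_api.translate_api.builder_emitter import BuilderEmitter

# onnx_model: the model with the local LinearRegression function
# built in test_local_function above.
code = Translater(onnx_model, emitter=BuilderEmitter("mm")).export(as_str=True)

# The generated script references only np, TensorProto and GraphBuilder.
env = {"np": np, "TensorProto": TensorProto, "GraphBuilder": GraphBuilder}
exec(compile(code, "<translated>", "exec"), env)
rebuilt = env["model"]  # the script ends with `model = mm()`
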
code.strip("\n")) def light_api( @@ -166,7 +181,9 @@ def light_api( g2 = GraphBuilder({"": 19}) g2.make_tensor_input("X", TensorProto.FLOAT, ("A",)) light_api(g2.op, "X") - g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",)) + g2.make_tensor_output( + "Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False + ) onx2 = g2.to_onnx() ref = ReferenceEvaluator(onx2) @@ -174,6 +191,95 @@ def light_api( got = ref.run(None, {"X": a})[0] self.assertEqualArray(np.exp(a), got) + def test_local_function(self): + new_domain = "custom" + + linear_regression = oh.make_function( + new_domain, + "LinearRegression", + ["x", "a", "b"], + ["y"], + [ + oh.make_node("MatMul", ["x", "a"], ["xa"]), + oh.make_node("Add", ["xa", "b"], ["y"]), + ], + [oh.make_opsetid("", 14)], + [], + ) + + graph = oh.make_graph( + [ + oh.make_node( + "LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain + ), + oh.make_node("Abs", ["Y1"], ["Y"]), + ], + "example", + [ + oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None]), + oh.make_tensor_value_info("A", TensorProto.FLOAT, [None, None]), + oh.make_tensor_value_info("B", TensorProto.FLOAT, [None, None]), + ], + [oh.make_tensor_value_info("Y", TensorProto.FLOAT, None)], + ) + + onnx_model = oh.make_model( + graph, + opset_imports=[oh.make_opsetid("", 14), oh.make_opsetid(new_domain, 1)], + functions=[linear_regression], + ir_version=10, + ) + tr = Translater(onnx_model, emitter=BuilderEmitter("mm")) + code = tr.export(as_str=True) + + expected = ( + dedent( + """ + def example( + op: "GraphBuilder", + X: "FLOAT[, ]", + A: "FLOAT[, ]", + B: "FLOAT[, ]", + ): + Y1 = op.LinearRegression(X, A, B, domain='custom', outputs=['Y1']) + Y = op.Abs(Y1, outputs=['Y']) + op.Identity(Y, outputs=["Y"]) + return Y + + + def make_custom_LinearRegression(g: "GraphBuilder"): + gr = GraphBuilder({'': 14}, as_function=True) + x = gr.make_tensor_input('x') + a = gr.make_tensor_input('a') + b = gr.make_tensor_input('b') + op = gr.op + xa = op.MatMul(x, a, outputs=['xa']) + y = op.Add(xa, b, outputs=['y']) + gr.make_tensor_output(y) + g.add_function(builder=gr) + return gr + + + def mm() -> "ModelProto": + g = GraphBuilder({'': 14, 'custom': 1}, ir_version=10) + g.make_tensor_input("X", TensorProto.FLOAT, ('', '')) + g.make_tensor_input("A", TensorProto.FLOAT, ('', '')) + g.make_tensor_input("B", TensorProto.FLOAT, ('', '')) + example(g.op, "X", "A", "B") + g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__) + make_custom_LinearRegression(g) + model = g.to_onnx() + return model + + + model = mm() + """ + ) + .strip("\n") + .replace("__SUFFIX__", ", is_dimension=False, indexed=False") + ) + self.assertEqual(expected, code.strip("\n")) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py index 558c34a..5e414ed 100644 --- a/onnx_array_api/graph_api/graph_builder.py +++ b/onnx_array_api/graph_api/graph_builder.py @@ -194,6 +194,7 @@ def __init__( self._known_shapes = {} self._known_types = {} self.constants_ = {} + self.functions_ = {} elif isinstance(target_opset_or_existing_proto, ModelProto): assert ( not input_names @@ -223,6 +224,8 @@ def __init__( self.constants_[node.output[0]] = node self.set_shape(node.output[0], self._get_tensor_shape(node)) self.set_type(node.output[0], self._get_tensor_type(node)) + for f in proto.functions: + self.add_function(f) else: raise NotImplementedError( f"{type(target_opset_or_existing_proto)} is not supported." 
diff --git a/onnx_array_api/translate_api/base_emitter.py b/onnx_array_api/translate_api/base_emitter.py
index 62fb318..e8d3811 100644
--- a/onnx_array_api/translate_api/base_emitter.py
+++ b/onnx_array_api/translate_api/base_emitter.py
@@ -25,6 +25,10 @@ class EventType(IntEnum):
     END_SIGNATURE = 16
     BEGIN_RETURN = 17
     END_RETURN = 18
+    BEGIN_FUNCTION_SIGNATURE = 19
+    END_FUNCTION_SIGNATURE = 20
+    BEGIN_FUNCTION_RETURN = 21
+    END_FUNCTION_RETURN = 22
 
     @classmethod
     def to_str(cls, self) -> str:
@@ -76,6 +80,12 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
         if event == EventType.BEGIN_FUNCTION:
             return self._emit_begin_function(**kwargs)
 
+        if event == EventType.BEGIN_FUNCTION_SIGNATURE:
+            return self._emit_begin_function_signature(**kwargs)
+
+        if event == EventType.END_FUNCTION_SIGNATURE:
+            return self._emit_end_function_signature(**kwargs)
+
         if event == EventType.END_FUNCTION:
             return self._emit_end_function(**kwargs)
 
@@ -100,6 +110,12 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
         if event == EventType.END_RETURN:
             return self._emit_end_return(**kwargs)
 
+        if event == EventType.BEGIN_FUNCTION_RETURN:
+            return self._emit_begin_function_return(**kwargs)
+
+        if event == EventType.END_FUNCTION_RETURN:
+            return self._emit_end_function_return(**kwargs)
+
         raise ValueError(f"Unexpected event {EventType.to_str(event)}.")
 
     def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
@@ -224,6 +240,12 @@ def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]:
             f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
         )
 
+    def _emit_begin_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+        return []
+
+    def _emit_end_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+        return []
+
     def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]:
         raise NotImplementedError(
             f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
@@ -250,3 +272,9 @@ def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:
 
     def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
         return []
+
+    def _emit_begin_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+        return []
+
+    def _emit_end_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+        return []
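
The four new events deliberately default to `return []` in the base class, so emitters that predate this change keep working. A hypothetical subclass (class name invented, and assuming the module's base class is named `BaseEmitter`) only overrides what it needs:

from typing import Any, Dict, List
from onnx_array_api.translate_api.base_emitter import BaseEmitter

class LoggingEmitter(BaseEmitter):
    """Hypothetical emitter tracing function-signature events; every
    other new event keeps the silent default from the base class."""

    def _emit_begin_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
        print("entering a function signature", kwargs)
        return []
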
diff --git a/onnx_array_api/translate_api/builder_emitter.py b/onnx_array_api/translate_api/builder_emitter.py
index 1c893e2..19dd7f9 100644
--- a/onnx_array_api/translate_api/builder_emitter.py
+++ b/onnx_array_api/translate_api/builder_emitter.py
@@ -41,6 +41,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str:
     def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
         self.opsets = kwargs.get("opsets", {})
         self.ir_version = kwargs.get("ir_version", None)
+        self.function_calls = []
         return []
 
     def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
@@ -51,7 +52,8 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
         outputs = []
         for inp, stype, shape in self.outputs_full_:
             outputs.append(
-                f'g.make_tensor_output("{inp}", TensorProto.{stype}, {shape})'
+                f'g.make_tensor_output("{inp}", TensorProto.{stype}, '
+                f"{shape}, is_dimension=False, indexed=False)"
             )
         rows = [
             "",
@@ -63,6 +65,7 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
             *inputs,
             f"{self.name}({inps})",
             *outputs,
+            *self.function_calls,
             "model = g.to_onnx()",
         ]
         if self.make_model_function:
@@ -131,7 +134,8 @@ def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
         for init in self.inits:
             val = to_array(init)
             stype = str(val.dtype).split(".")[-1]
-            rows.append(f"    {init.name} = np.array({val.tolist()}, dtype=np.{stype})")
+            name = self._clean_result_name(init.name)
+            rows.append(f"    {name} = np.array({val.tolist()}, dtype=np.{stype})")
         return rows
 
     def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:
@@ -154,11 +158,7 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
         op_type = kwargs["op_type"]
         inputs = kwargs["inputs"]
         outputs = kwargs["outputs"]
-        if kwargs.get("domain", "") != "":
-            domain = kwargs["domain"]
-            op_type = f"{domain}.{op_type}"
-        else:
-            domain = ""
+        domain = kwargs.get("domain", "")
         atts = kwargs.get("atts", {})
         args = []
         for k, v in atts.items():
@@ -167,10 +167,13 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
                 raise NotImplementedError("Graph attribute not supported yet.")
             args.append(f"{k}={vatt}")
 
-        outs = ", ".join(map(self._clean_result_name, outputs))
+        cleaned_outputs = list(map(self._clean_result_name, outputs))
+        outs = ", ".join(cleaned_outputs)
         inps = ", ".join(map(self._clean_result_name, inputs))
         op_type = self._emit_node_type(op_type, domain)
-        sdomain = "" if not domain else f", domain={domain!r}"
+        # Let's add output names to make it easier to debug.
+        soutputs = f", outputs={cleaned_outputs}"
+        sdomain = soutputs if not domain else f", domain={domain!r}{soutputs}"
         if args:
             sargs = ", ".join(args)
             if inps:
@@ -186,3 +189,54 @@ def _clean_result_name(self, name):
 
     def _emit_node_type(self, op_type, domain):
         return op_type
+
+    def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+        self.f_inputs = []
+        self.f_outputs = []
+        self.f_inits = []
+        self.f_name = kwargs["name"]
+        self.f_domain = kwargs["domain"]
+        self.f_attributes = []
+        self.f_opsets = kwargs["opsets"]
+        return []
+
+    def _emit_begin_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+        return []
+
+    def _emit_end_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+        self.f_call_name = f"make_{self.f_domain}_{self.f_name}"
+        return [
+            "",
+            "",
+            f'def {self.f_call_name}(g: "GraphBuilder"):',
+            f"    gr = GraphBuilder({self.f_opsets}, as_function=True)",
+            *[f"    {name} = gr.make_tensor_input({name!r})" for name in self.f_inputs],
+            "    op = gr.op",
+        ]
+
+    def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+        return ["    return gr"]
+
+    def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+        self.f_inputs.append(kwargs["name"])
+        return []
+
+    def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+        self.f_outputs.append(kwargs["name"])
+        return []
+
+    def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]:
+        raise NotImplementedError("Function attributes are not implemented yet.")
+
+    def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+        self.function_calls.append(f"{self.f_call_name}(g)")
+        return [
+            *[f"    gr.make_tensor_output({name})" for name in self.f_outputs],
+            "    g.add_function(builder=gr)",
+        ]
+
+    def _emit_begin_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+        return []
+
+    def _emit_end_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+        return []
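
The effect of the `_emit_node` change is easiest to see in isolation. The following self-contained sketch (a hypothetical helper, not code from the module) mirrors the new string assembly: `outputs=[...]` is always appended, while `domain=...` appears only for non-default domains.

def render_call(op_type, inputs, outputs, domain="", args=()):
    # Mirrors _emit_node after the change: output names are always
    # emitted, the domain keyword only outside the default domain.
    soutputs = f", outputs={list(outputs)}"
    sdomain = soutputs if not domain else f", domain={domain!r}{soutputs}"
    inps = ", ".join(inputs)
    if args:
        sargs = ", ".join(args)
        inps = f"{inps}, {sargs}" if inps else sargs
    return f"{', '.join(outputs)} = op.{op_type}({inps}{sdomain})"

print(render_call("Exp", ["X"], ["Y"]))
# Y = op.Exp(X, outputs=['Y'])
print(render_call("Transpose", ["r0_0"], ["Y"], args=("perm=[1, 0]",)))
# Y = op.Transpose(r0_0, perm=[1, 0], outputs=['Y'])
print(render_call("LinearRegression", ["X", "A", "B"], ["Y1"], domain="custom"))
# Y1 = op.LinearRegression(X, A, B, domain='custom', outputs=['Y1'])
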
+ soutputs = f", outputs={cleaned_outputs}" + sdomain = soutputs if not domain else f", domain={domain!r}{soutputs}" if args: sargs = ", ".join(args) if inps: @@ -186,3 +189,54 @@ def _clean_result_name(self, name): def _emit_node_type(self, op_type, domain): return op_type + + def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: + self.f_inputs = [] + self.f_outputs = [] + self.f_inits = [] + self.f_name = kwargs["name"] + self.f_domain = kwargs["domain"] + self.f_attributes = [] + self.f_opsets = kwargs["opsets"] + return [] + + def _emit_begin_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_end_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]: + self.f_call_name = f"make_{self.f_domain}_{self.f_name}" + return [ + "", + "", + f'def {self.f_call_name}(g: "GraphBuilder"):', + f" gr = GraphBuilder({self.f_opsets}, as_function=True)", + *[f" {name} = gr.make_tensor_input({name!r})" for name in self.f_inputs], + " op = gr.op", + ] + + def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]: + return [" return gr"] + + def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: + self.f_inputs.append(kwargs["name"]) + return [] + + def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]: + self.f_outputs.append(kwargs["name"]) + return [] + + def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError("Function attribute are not implemented yet.") + + def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]: + self.function_calls.append(f"{self.f_call_name}(g)") + return [ + *[f" gr.make_tensor_output({name})" for name in self.f_outputs], + " g.add_function(builder=gr)", + ] + + def _emit_begin_function_return(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_end_function_return(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] diff --git a/onnx_array_api/translate_api/translate.py b/onnx_array_api/translate_api/translate.py index aa78103..81d515a 100644 --- a/onnx_array_api/translate_api/translate.py +++ b/onnx_array_api/translate_api/translate.py @@ -77,6 +77,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: EventType.BEGIN_FUNCTION, name=self.proto_.name, domain=self.proto_.domain, + opsets={d.domain: d.version for d in self.proto_.opset_import}, ) ) elif isinstance(self.proto_, GraphProto): @@ -96,7 +97,13 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) ) - rows.extend(self.emitter(EventType.BEGIN_SIGNATURE)) + rows.extend( + self.emitter( + EventType.BEGIN_FUNCTION_SIGNATURE + if is_function + else EventType.BEGIN_SIGNATURE + ) + ) for i in inputs: if is_function: @@ -119,7 +126,13 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: self.emitter(EventType.FUNCTION_ATTRIBUTES, attributes=list(attributes)) ) - rows.extend(self.emitter(EventType.END_SIGNATURE)) + rows.extend( + self.emitter( + EventType.END_FUNCTION_SIGNATURE + if is_function + else EventType.END_SIGNATURE + ) + ) for node in nodes: atts = self.extract_attributes(node) @@ -134,7 +147,13 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) ) - rows.extend(self.emitter(EventType.BEGIN_RETURN)) + rows.extend( + self.emitter( + EventType.BEGIN_FUNCTION_RETURN + if is_function + else EventType.BEGIN_RETURN + ) + ) for o in outputs: if is_function: @@ -152,7 +171,11 @@ def export(self, as_str, single_line: 
@@ -152,7 +171,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
                 )
             )
 
-        rows.extend(self.emitter(EventType.END_RETURN))
+        rows.extend(
+            self.emitter(
+                EventType.END_FUNCTION_RETURN if is_function else EventType.END_RETURN
+            )
+        )
 
         if isinstance(self.proto_, (GraphProto, FunctionProto)):
             name = self.proto_.name
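
To summarize the dispatch added to `Translater.export`: when the walked proto is a `FunctionProto`, the translator now emits the `FUNCTION_` variants of the signature and return events, which is what lets `BuilderEmitter` open and close the standalone `make_<domain>_<name>` helper without disturbing graph emission. A compact sketch of the selection logic (a restatement of the ternaries above, not code from the module):

from onnx_array_api.translate_api.base_emitter import EventType

def phase_events(is_function: bool):
    """Events used for the signature and return phases, in order."""
    if is_function:
        return (
            EventType.BEGIN_FUNCTION_SIGNATURE,
            EventType.END_FUNCTION_SIGNATURE,
            EventType.BEGIN_FUNCTION_RETURN,
            EventType.END_FUNCTION_RETURN,
        )
    return (
        EventType.BEGIN_SIGNATURE,
        EventType.END_SIGNATURE,
        EventType.BEGIN_RETURN,
        EventType.END_RETURN,
    )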