diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index c3c667d..39aaea9 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.2.0 +++++ +* :pr:`60`: supports translation of local functions * :pr:`59`: add methods to update nodes in GraphAPI 0.1.3 diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst index 544b35f..15342c1 100644 --- a/_doc/api/light_api.rst +++ b/_doc/api/light_api.rst @@ -16,6 +16,13 @@ translate .. autofunction:: onnx_array_api.light_api.translate +make_helper ++++++++++++ + +.. autofunction:: onnx_array_api.light_api.make_helper.make_node_extended + +.. autofunction:: onnx_array_api.light_api.make_helper.make_ref_attribute + Classes for the Light API ========================= @@ -68,19 +75,13 @@ Classes for the Translater BaseEmitter +++++++++++ -.. autoclass:: onnx_array_api.light_api.emitter.BaseEmitter - :members: - -Emitter -+++++++ - -.. autoclass:: onnx_array_api.light_api.emitter.Emitter +.. autoclass:: onnx_array_api.light_api.base_emitter.BaseEmitter :members: EventType +++++++++ -.. autoclass:: onnx_array_api.light_api.translate.EventType +.. autoclass:: onnx_array_api.light_api.base_emitter.EventType :members: InnerEmitter @@ -89,6 +90,12 @@ InnerEmitter .. autoclass:: onnx_array_api.light_api.inner_emitter.InnerEmitter :members: +LightEmitter +++++++++++++ + +.. autoclass:: onnx_array_api.light_api.light_emitter.LightEmitter + :members: + Translater ++++++++++ diff --git a/_unittests/ut_light_api/_data/custom_ops_type_inference_fails_0.onnx b/_unittests/ut_light_api/_data/custom_ops_type_inference_fails_0.onnx new file mode 100644 index 0000000..8116ec3 Binary files /dev/null and b/_unittests/ut_light_api/_data/custom_ops_type_inference_fails_0.onnx differ diff --git a/_unittests/ut_light_api/test_backend_export.py b/_unittests/ut_light_api/test_backend_export.py index b0c1cbc..f597d21 100644 --- a/_unittests/ut_light_api/test_backend_export.py +++ b/_unittests/ut_light_api/test_backend_export.py @@ -1,3 +1,4 @@ +import sys import unittest from typing import Any, Dict, List, Optional from difflib import unified_diff @@ -17,12 +18,16 @@ make_opsetid, make_tensor_value_info, ) +from onnx.reference.op_run import to_array_extended from onnx.numpy_helper import from_array, to_array from onnx.backend.base import Device, DeviceType from onnx_array_api.reference import ExtendedReferenceEvaluator +from onnx_array_api.light_api.make_helper import make_node_extended from onnx_array_api.light_api import translate from onnx_array_api.plotting.text_plot import onnx_simple_text_plot +verbosity = 10 if "-v" in sys.argv or "--verbose" in sys.argv else 0 + class ReferenceImplementationError(RuntimeError): "Fails, export cannot be compared." @@ -34,7 +39,7 @@ class ExportWrapper: def __init__(self, model): self.model = model - self.expected_sess = ExtendedReferenceEvaluator(self.model) + self.expected_sess = ExtendedReferenceEvaluator(self.model, verbose=verbosity) @property def input_names(self): @@ -85,6 +90,7 @@ def run( locs = { "np": numpy, "to_array": to_array, + "to_array_extended": to_array_extended, "from_array": from_array, "TensorProto": TensorProto, "make_function": make_function, @@ -92,6 +98,7 @@ def run( "make_model": make_model, "make_graph": make_graph, "make_node": make_node, + "make_node_extended": make_node_extended, "make_tensor_value_info": make_tensor_value_info, } globs = locs.copy() @@ -105,7 +112,7 @@ def run( f"Unable to executed code for api {api!r}\n{new_code}" ) from e export_model = locs["model"] - ref = ExtendedReferenceEvaluator(export_model) + ref = ExtendedReferenceEvaluator(export_model, verbose=verbosity) try: got = ref.run(names, feeds) except (TypeError, AttributeError) as e: diff --git a/_unittests/ut_light_api/test_translate.py b/_unittests/ut_light_api/test_translate.py index c2b2c70..9974f81 100644 --- a/_unittests/ut_light_api/test_translate.py +++ b/_unittests/ut_light_api/test_translate.py @@ -6,7 +6,7 @@ from onnx.reference import ReferenceEvaluator from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.light_api import start, translate, g -from onnx_array_api.light_api.emitter import EventType +from onnx_array_api.light_api.base_emitter import EventType OPSET_API = min(19, onnx_opset_version() - 1) diff --git a/_unittests/ut_light_api/test_translate_classic.py b/_unittests/ut_light_api/test_translate_classic.py index cb7d6a4..4d52183 100644 --- a/_unittests/ut_light_api/test_translate_classic.py +++ b/_unittests/ut_light_api/test_translate_classic.py @@ -5,6 +5,7 @@ from onnx import ModelProto, TensorProto, load from onnx.defs import onnx_opset_version from onnx.reference import ReferenceEvaluator +from onnx.reference.op_run import OpRun from onnx.helper import ( make_tensor_value_info, make_node, @@ -68,7 +69,7 @@ def test_exp(self): functions = [] inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) nodes.append( - make_node( + make_node_extended( 'Exp', ['X'], ['Y'] @@ -144,14 +145,14 @@ def test_transpose(self): ) inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) nodes.append( - make_node( + make_node_extended( 'Reshape', ['X', 'r'], ['r0_0'] ) ) nodes.append( - make_node( + make_node_extended( 'Transpose', ['r0_0'], ['Y'], @@ -210,7 +211,7 @@ def test_topk_reverse(self): inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) inputs.append(make_tensor_value_info('K', TensorProto.INT64, shape=[])) nodes.append( - make_node( + make_node_extended( 'TopK', ['X', 'K'], ['Values', 'Indices'], @@ -264,7 +265,6 @@ def test_aionnxml(self): .to_onnx() ) code = translate(onx, api="onnx") - print(code) expected = dedent( """ opset_imports = [ @@ -285,14 +285,14 @@ def test_aionnxml(self): ) inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) nodes.append( - make_node( + make_node_extended( 'Reshape', ['X', 'r'], ['USE'] ) ) nodes.append( - make_node( + make_node_extended( 'Normalizer', ['USE'], ['Y'], @@ -318,7 +318,115 @@ def test_aionnxml(self): self.maxDiff = None self.assertEqual(expected, code) + @classmethod + def _code_line(cls, code): + lines = code.split("\n") + return "\n".join(f"{i+1:03d} {line}" for i, line in enumerate(lines)) + + @classmethod + def _run(cls, code): + try: + code_compiled = compile(code, "", mode="exec") + except Exception as e: + raise AssertionError( + f"Compilation failed due to {e}\n---\n{cls._code_line(code)}\n---\n{e}" + ) from e + + import onnx + import onnx.helper + import onnx.numpy_helper + import onnx_array_api.light_api.make_helper + import onnx.reference.custom_element_types + + def from_array_extended(tensor, name=None): + dt = tensor.dtype + if ( + dt == onnx.reference.custom_element_types.float8e4m3fn + and dt.descr[0][0] == "e4m3fn" + ): + to = TensorProto.FLOAT8E4M3FN + dt_to = np.uint8 + elif ( + dt == onnx.reference.custom_element_types.bfloat16 + and dt.descr[0][0] == "bfloat16" + ): + to = TensorProto.BFLOAT16 + dt_to = np.uint16 + else: + return onnx.numpy_helper.from_array(tensor, name) + + t = onnx.numpy_helper.from_array(tensor.astype(dt_to), name) + t.data_type = to + return t + + globs = onnx.__dict__.copy() + globs.update(onnx.helper.__dict__) + globs.update(onnx.numpy_helper.__dict__) + globs.update(onnx_array_api.light_api.make_helper.__dict__) + globs.update(onnx.reference.custom_element_types.__dict__) + globs["from_array_extended"] = from_array_extended + locs = {} + try: + exec(code_compiled, globs, locs) + except Exception as e: + raise AssertionError( + f"Execution failed due to {e}\n---\n{cls._code_line(code)}\n---\n{e}" + ) from e + return globs, locs + + def test_remove_nodes(self): + path = os.path.join( + os.path.dirname(__file__), "_data", "custom_ops_type_inference_fails_0.onnx" + ) + onx = load(path) + code = translate(onx, api="onnx") + _, locs = self._run(code) + self.assertIn("model", locs) + model = locs["model"] + x = np.arange(4).reshape((-1, 2)).astype(np.float32) + feeds = {"X": x} + + class CustomGemmFloat8E4M3FN(OpRun): + op_domain = "onnx_extented.ortops.tutorial.cpu" + + def _run( + self, + x, + y, + bias=None, + scale_x=None, + scale_y=None, + scale_z=None, + transA=False, + transB=False, + dtype=None, + rowMajor=None, + computeType=None, + ): + if scale_x is not None: + x = x * scale_x + if transA: + x = x.T + if scale_y is not None: + y = y * scale_y + if transB: + y = y.T + z = x @ y + if bias is not None: + z += bias + if scale_z is not None: + z = z / scale_z + return (z,) + + ref = ReferenceEvaluator(onx, new_ops=[CustomGemmFloat8E4M3FN]) + expected = ref.run(None, feeds)[0] + ref2 = ReferenceEvaluator(model, new_ops=[CustomGemmFloat8E4M3FN]) + got = ref2.run(None, feeds)[0] + self.assertEqualArray(expected, got) + + # with open("debug_test_remove_nodes.py", "w") as f: + # f.write(code) + if __name__ == "__main__": - # TestLightApi().test_topk() unittest.main(verbosity=2) diff --git a/onnx_array_api/light_api/__init__.py b/onnx_array_api/light_api/__init__.py index be6e9dd..558e626 100644 --- a/onnx_array_api/light_api/__init__.py +++ b/onnx_array_api/light_api/__init__.py @@ -67,7 +67,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light") :param single_line: as a single line or not :param api: API to export into, default is `"light"` and this is handle by class - :class:`onnx_array_api.light_api.emitter.Emitter`, + :class:`onnx_array_api.light_api.light_emitter.LightEmitter`, another value is `"onnx"` which is the inner API implemented in onnx package. :return: code diff --git a/onnx_array_api/light_api/emitter.py b/onnx_array_api/light_api/base_emitter.py similarity index 57% rename from onnx_array_api/light_api/emitter.py rename to onnx_array_api/light_api/base_emitter.py index a1b0e40..3a0dfb6 100644 --- a/onnx_array_api/light_api/emitter.py +++ b/onnx_array_api/light_api/base_emitter.py @@ -1,9 +1,8 @@ import inspect -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from enum import IntEnum import numpy as np from onnx import AttributeProto -from .annotations import ELEMENT_TYPE_NAME class EventType(IntEnum): @@ -11,13 +10,17 @@ class EventType(IntEnum): INPUT = 1 OUTPUT = 2 NODE = 3 - TO_ONNX = 4 + TO_ONNX_MODEL = 4 BEGIN_GRAPH = 5 END_GRAPH = 6 BEGIN_FUNCTION = 7 END_FUNCTION = 8 INITIALIZER = 9 SPARSE_INITIALIZER = 10 + FUNCTION_INPUT = 11 + FUNCTION_OUTPUT = 12 + FUNCTION_ATTRIBUTES = 13 + TO_ONNX_FUNCTION = 14 @classmethod def to_str(cls, self) -> str: @@ -54,8 +57,11 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]: if event == EventType.START: return self._emit_start(**kwargs) - if event == EventType.TO_ONNX: - return self._emit_to_onnx(**kwargs) + if event == EventType.TO_ONNX_MODEL: + return self._emit_to_onnx_model(**kwargs) + + if event == EventType.TO_ONNX_FUNCTION: + return self._emit_to_onnx_function(**kwargs) if event == EventType.BEGIN_GRAPH: return self._emit_begin_graph(**kwargs) @@ -63,6 +69,21 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]: if event == EventType.END_GRAPH: return self._emit_end_graph(**kwargs) + if event == EventType.BEGIN_FUNCTION: + return self._emit_begin_function(**kwargs) + + if event == EventType.END_FUNCTION: + return self._emit_end_function(**kwargs) + + if event == EventType.FUNCTION_INPUT: + return self._emit_function_input(**kwargs) + + if event == EventType.FUNCTION_OUTPUT: + return self._emit_function_output(**kwargs) + + if event == EventType.FUNCTION_ATTRIBUTES: + return self._emit_function_attributes(**kwargs) + raise ValueError(f"Unexpected event {EventType.to_str(event)}.") def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: @@ -104,11 +125,27 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: srows = ".".join(rows[:-1]) return [], f"g().{srows}" + if isinstance(value, tuple) and len(value) == 2 and value[1] is None: + # in a function, an attribute receiving a value from an attribute + v = value[0] + name = v.name + ref = v.ref_attr_name + dt = v.type + return [], self._make_attribute(name=name, ref_attr_name=ref, attr_type=dt) + raise ValueError( f"Unable to render an attribute {type(v)}, " f"attribute type={value[0].type}, " f"dtype={getattr(v, 'dtype', '-')}, " - f"shape={getattr(v, 'shape', '-')}, {value}." + f"shape={getattr(v, 'shape', '-')}, type(value)={type(value)}, " + f"value={value!r}." + ) + + def _make_attribute( + self, name: str, attr_type: int, ref_attr_name: Optional[str] = None + ) -> str: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." ) def join(self, rows: List[str], single_line: bool = False) -> str: @@ -121,7 +158,12 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." ) - def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]: + def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]: raise NotImplementedError( f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." ) @@ -161,100 +203,22 @@ def _emit_sparse_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." ) + def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) -class Emitter(BaseEmitter): - """ - Converts event into proper code. - """ - - def join(self, rows: List[str], single_line: bool = False) -> str: - "Join the rows" - if single_line: - return ".".join(rows) - return "".join(["(\n ", "\n .".join(rows), "\n)"]) - - def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: - opsets = kwargs.get("opsets", {}) - opset = opsets.get("", None) - if opset is not None: - del opsets[""] - args = [] - if opset: - args.append(f"opset={opset}") - if opsets: - args.append(f"opsets={opsets}") - return [f"start({', '.join(args)})"] - - def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]: - return ["to_onnx()"] - - def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: - return [] - - def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]: - return [] - - def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: - name = kwargs["name"] - value = kwargs["value"] - repl = {"bool": "bool_", "object": "object_", "str": "str_"} - sdtype = repl.get(str(value.dtype), str(str(value.dtype))) - return [ - f"cst(np.array({value.tolist()}, dtype=np.{sdtype}))", - f"rename({name!r})", - ] - - def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: - name = kwargs["name"] - elem_type = kwargs.get("elem_type", None) - shape = kwargs.get("shape", None) - if elem_type and shape: - return [ - f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})" - ] - if elem_type: - return [ - f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})" - ] - return [f"vin({name!r})"] + def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) - def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]: - inst = [] - if "name" in kwargs: - name = kwargs["name"] - inst.append(f"bring({name!r})") - elem_type = kwargs.get("elem_type", None) - shape = kwargs.get("shape", None) - if elem_type and shape: - inst.append( - f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})" - ) - elif elem_type: - inst.append(f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})") - else: - inst.append("vout()") - return inst + def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) - def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: - op_type = kwargs["op_type"] - inputs = kwargs["inputs"] - outputs = kwargs["outputs"] - if kwargs.get("domain", "") != "": - domain = kwargs["domain"] - op_type = f"{domain}.{op_type}" - atts = kwargs.get("atts", {}) - args = [] - for k, v in atts.items(): - before, vatt = self.render_attribute_value(v) - if before: - raise NotImplementedError("Graph attribute not supported yet.") - args.append(f"{k}={vatt}") - - str_inputs = ", ".join([f"{i!r}" for i in inputs]) - inst = [f"bring({str_inputs})", f"{op_type}({', '.join(args)})"] - if len(outputs) == 1: - inst.append(f"rename({outputs[0]!r})") - else: - str_outputs = ", ".join([f"{o!r}" for o in outputs]) - inst.append(f"rename({str_outputs})") - return inst + def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) diff --git a/onnx_array_api/light_api/inner_emitter.py b/onnx_array_api/light_api/inner_emitter.py index f5d5e4d..72ee725 100644 --- a/onnx_array_api/light_api/inner_emitter.py +++ b/onnx_array_api/light_api/inner_emitter.py @@ -1,7 +1,7 @@ -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from onnx import AttributeProto from .annotations import ELEMENT_TYPE_NAME -from .emitter import BaseEmitter +from .base_emitter import BaseEmitter from .translate import Translater @@ -31,6 +31,15 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: return super().render_attribute_value(value) + def _make_attribute( + self, name: str, attr_type: int, ref_attr_name: Optional[str] = None + ) -> str: + if ref_attr_name is None: + raise NotImplementedError( + f"Cannot create attribute with name={name!r}, attr_type={attr_type}." + ) + return f"make_ref_attribute(key={name!r}, attr_type={attr_type}, ref_attr_name={ref_attr_name!r})" + def join(self, rows: List[str], single_line: bool = False) -> str: "Returns the separators. `single_line` is unused." return "\n".join(rows) @@ -43,7 +52,7 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: lines.append("]") return lines - def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]: + def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: lines = [ "model = make_model(", " graph,", @@ -82,11 +91,22 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: name = kwargs["name"] value = kwargs["value"] repl = {"bool": "bool_", "object": "object_", "str": "str_"} - sdtype = repl.get(str(value.dtype), str(str(value.dtype))) + fra = "from_array" + sdtype = repl.get(str(value.dtype), str(value.dtype)) + if sdtype.startswith("("): + from onnx.reference.custom_element_types import float8e4m3fn + + if sdtype == str(float8e4m3fn): + sdtype = "float8e4m3fn" + fra = "from_array_extended" + else: + raise NotImplementedError(f"Unexpected dtype={sdtype}.") + else: + sdtype = f"np.{sdtype}" return [ "initializers.append(", - " from_array(", - f" np.array({value.tolist()}, dtype=np.{sdtype}),", + f" {fra}(", + f" np.array({value.tolist()}, dtype={sdtype}),", f" name={name!r}", " )", ")", @@ -124,7 +144,7 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: before_lines = [] lines = [ "nodes.append(", - " make_node(", + " make_node_extended(", f" {op_type!r},", f" {inputs},", f" {outputs},", @@ -140,3 +160,46 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: lines[-1] = lines[-1][:-1] lines.extend([" )", ")"]) return before_lines + lines + + def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: + lines = [ + "", + f"name_f = {kwargs['name']!r}", + f"domain_f = {kwargs['domain']!r}", + "nodes = []", + "inputs = []", + "outputs = []", + "atts = []", + ] + return lines + + def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: + return [f"inputs.append({kwargs['name']!r})"] + + def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]: + return [f"outputs.append({kwargs['name']!r})"] + + def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]: + atts = kwargs["attributes"] + if isinstance(atts, list) and all(map(lambda t: isinstance(t, str), atts)): + return [f"atts.extend({atts!r})"] + raise NotImplementedError(f"Unable to process function attributes {atts!r}.") + + def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]: + lines = [ + "functions.append(", + " make_function(", + " domain_f, ", + " name_f, ", + " inputs, ", + " outputs, ", + " nodes, ", + " attributes=atts, ", + " opset_imports=opset_imports,", + " )", + ")", + ] + return lines diff --git a/onnx_array_api/light_api/light_emitter.py b/onnx_array_api/light_api/light_emitter.py new file mode 100644 index 0000000..c2925b5 --- /dev/null +++ b/onnx_array_api/light_api/light_emitter.py @@ -0,0 +1,104 @@ +from typing import Any, Dict, List +from .annotations import ELEMENT_TYPE_NAME +from .base_emitter import BaseEmitter + + +class LightEmitter(BaseEmitter): + """ + Converts event into proper code. + """ + + def join(self, rows: List[str], single_line: bool = False) -> str: + "Join the rows" + if single_line: + return ".".join(rows) + return "".join(["(\n ", "\n .".join(rows), "\n)"]) + + def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: + opsets = kwargs.get("opsets", {}) + opset = opsets.get("", None) + if opset is not None: + del opsets[""] + args = [] + if opset: + args.append(f"opset={opset}") + if opsets: + args.append(f"opsets={opsets}") + return [f"start({', '.join(args)})"] + + def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: + return ["to_onnx()"] + + def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + + def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: + name = kwargs["name"] + value = kwargs["value"] + repl = {"bool": "bool_", "object": "object_", "str": "str_"} + sdtype = repl.get(str(value.dtype), str(str(value.dtype))) + return [ + f"cst(np.array({value.tolist()}, dtype=np.{sdtype}))", + f"rename({name!r})", + ] + + def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: + name = kwargs["name"] + elem_type = kwargs.get("elem_type", None) + shape = kwargs.get("shape", None) + if elem_type and shape: + return [ + f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})" + ] + if elem_type: + return [ + f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})" + ] + return [f"vin({name!r})"] + + def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]: + inst = [] + if "name" in kwargs: + name = kwargs["name"] + inst.append(f"bring({name!r})") + elem_type = kwargs.get("elem_type", None) + shape = kwargs.get("shape", None) + if elem_type and shape: + inst.append( + f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})" + ) + elif elem_type: + inst.append(f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})") + else: + inst.append("vout()") + return inst + + def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: + op_type = kwargs["op_type"] + inputs = kwargs["inputs"] + outputs = kwargs["outputs"] + if kwargs.get("domain", "") != "": + domain = kwargs["domain"] + op_type = f"{domain}.{op_type}" + atts = kwargs.get("atts", {}) + args = [] + for k, v in atts.items(): + before, vatt = self.render_attribute_value(v) + if before: + raise NotImplementedError("Graph attribute not supported yet.") + args.append(f"{k}={vatt}") + + str_inputs = ", ".join([f"{i!r}" for i in inputs]) + inst = [f"bring({str_inputs})", f"{op_type}({', '.join(args)})"] + if len(outputs) == 1: + inst.append(f"rename({outputs[0]!r})") + else: + str_outputs = ", ".join([f"{o!r}" for o in outputs]) + inst.append(f"rename({str_outputs})") + return inst diff --git a/onnx_array_api/light_api/make_helper.py b/onnx_array_api/light_api/make_helper.py new file mode 100644 index 0000000..8b2703c --- /dev/null +++ b/onnx_array_api/light_api/make_helper.py @@ -0,0 +1,65 @@ +from typing import Any, Optional, Sequence +from onnx import AttributeProto, NodeProto +from onnx.helper import make_attribute + + +def make_ref_attribute( + key: str, attr_type: int, ref_attr_name: Optional[str] = None +) -> AttributeProto: + """ + Creates an attribute. + + :param key: atttribute name + :param attr_type: attribute type + :param ref_attr_name: if not None, link this attribute + to a function attribute + :return: attribute + """ + att = AttributeProto() + att.name = key + att.type = attr_type + att.ref_attr_name = ref_attr_name + return att + + +def make_node_extended( + op_type: str, + inputs: Sequence[str], + outputs: Sequence[str], + name: Optional[str] = None, + doc_string: Optional[str] = None, + domain: Optional[str] = None, + **kwargs: Any, +) -> NodeProto: + """ + Constructs a NodeProto. + + :param op_type: The name of the operator to construct + :param inputs: list of input names + :param outputs: list of output names + :param name: optional unique identifier for NodeProto + :param doc_string: optional documentation string for NodeProto + :param domain: optional domain for NodeProto. + If it's None, we will just use default domain (which is empty) + :param kwargs: the attributes of the node. + :return: node proto + """ + node = NodeProto() + node.op_type = op_type + node.input.extend(inputs) + node.output.extend(outputs) + if name: + node.name = name + if doc_string: + node.doc_string = doc_string + if domain is not None: + node.domain = domain + if kwargs: + for key, value in sorted(kwargs.items()): + if value is None: + continue + if isinstance(value, AttributeProto): + node.attribute.append(value) + else: + node.attribute.append(make_attribute(key, value)) + return node diff --git a/onnx_array_api/light_api/translate.py b/onnx_array_api/light_api/translate.py index a61ce24..31c1bce 100644 --- a/onnx_array_api/light_api/translate.py +++ b/onnx_array_api/light_api/translate.py @@ -2,7 +2,9 @@ import numpy as np from onnx import AttributeProto, FunctionProto, GraphProto, ModelProto, NodeProto from onnx.numpy_helper import to_array -from .emitter import EventType, Emitter +from ..reference import to_array_extended +from .base_emitter import EventType +from .light_emitter import LightEmitter class Translater: @@ -13,10 +15,10 @@ class Translater: def __init__( self, proto: Union[ModelProto, FunctionProto, GraphProto], - emitter: Optional[Emitter] = None, + emitter: Optional[LightEmitter] = None, ): self.proto_ = proto - self.emitter = emitter or Emitter() + self.emitter = emitter or LightEmitter() def __repr__(self) -> str: return f"{self.__class__.__name__}(<{type(self.proto_)})" @@ -30,6 +32,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: :return: list of instructions """ rows = [] + last_event = None if isinstance(self.proto_, ModelProto): opsets = {d.domain: d.version for d in self.proto_.opset_import} rows.extend(self.emitter(EventType.START, opsets=opsets)) @@ -38,6 +41,9 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: nodes = self.proto_.graph.node initializers = self.proto_.graph.initializer sparse_initializers = self.proto_.graph.sparse_initializer + attributes = [] + last_event = EventType.TO_ONNX_MODEL + is_function = False elif isinstance(self.proto_, (FunctionProto, GraphProto)): inputs = self.proto_.input outputs = self.proto_.output @@ -48,30 +54,43 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: else: initializers = [] sparse_initializers = [] + attributes = ( + self.proto_.attribute if hasattr(self.proto_, "attribute") else [] + ) + is_function = isinstance(self.proto_, FunctionProto) + last_event = ( + EventType.TO_ONNX_FUNCTION if is_function else EventType.TO_ONNX_MODEL + ) else: raise ValueError(f"Unexpected type {type(self.proto_)} for proto.") if sparse_initializers: raise NotImplementedError("Sparse initializer not supported yet.") - rows.extend( - self.emitter( - EventType.BEGIN_FUNCTION - if isinstance(self.proto_, FunctionProto) - else EventType.BEGIN_GRAPH + if is_function: + rows.extend( + self.emitter( + EventType.BEGIN_FUNCTION, + name=self.proto_.name, + domain=self.proto_.domain, + ) ) - ) + else: + rows.extend(self.emitter(EventType.BEGIN_GRAPH)) for i in initializers: rows.extend( self.emitter( - EventType.INITIALIZER, name=i.name, init=i, value=to_array(i) + EventType.INITIALIZER, + name=i.name, + init=i, + value=to_array_extended(i), ) ) for i in inputs: - if isinstance(i, str): - rows.extend(self.emitter(EventType.INPUT, name=i)) + if is_function: + rows.extend(self.emitter(EventType.FUNCTION_INPUT, name=i)) else: rows.extend( self.emitter( @@ -85,6 +104,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) ) + if is_function and attributes: + rows.extend( + self.emitter(EventType.FUNCTION_ATTRIBUTES, attributes=list(attributes)) + ) + for node in nodes: atts = self.extract_attributes(node) rows.extend( @@ -99,8 +123,8 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) for o in outputs: - if isinstance(o, str): - rows.extend(self.emitter(EventType.INPUT, name=o)) + if is_function: + rows.extend(self.emitter(EventType.FUNCTION_OUTPUT, name=o)) else: rows.extend( self.emitter( @@ -117,19 +141,21 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: name = self.proto_.name else: name = self.proto_.graph.name + rows.extend( self.emitter( - EventType.END_FUNCTION - if isinstance(self.proto_, FunctionProto) - else EventType.END_GRAPH, + EventType.END_FUNCTION if is_function else EventType.END_GRAPH, name=name, ) ) if isinstance(self.proto_, ModelProto) and len(self.proto_.functions) > 0: - raise NotImplementedError("Local functions are not yet implemented.") + for fu in self.proto_.functions: + cl = self.__class__(fu, self.emitter) + text = cl.export(False, single_line=False) + rows.extend(text) - rows.extend(self.emitter(EventType.TO_ONNX)) + rows.extend(self.emitter(last_event)) if as_str: return self.emitter.join(rows, single_line=single_line) return rows