diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index d382b74..1c385ca 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.1.3 +++++ +* :pr:`46`: adds an export to convert an onnx graph into light API code * :pr:`45`: fixes light API for operators with two outputs 0.1.2 diff --git a/README.rst b/README.rst index 035911d..7d53c79 100644 --- a/README.rst +++ b/README.rst @@ -141,4 +141,4 @@ The euclidean distance looks like the following: The library is released on `pypi/onnx-array-api `_ and its documentation is published at -`(Numpy) Array API for ONNX `_. +`APIs to create ONNX Graphs `_. diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst index 471eb66..a50f050 100644 --- a/_doc/api/light_api.rst +++ b/_doc/api/light_api.rst @@ -2,33 +2,67 @@ onnx_array_api.light_api ======================== + +Main API +======== + start -===== ++++++ .. autofunction:: onnx_array_api.light_api.start +translate ++++++++++ + +.. autofunction:: onnx_array_api.light_api.translate + +Classes for the Light API +========================= + OnnxGraph -========= ++++++++++ .. autoclass:: onnx_array_api.light_api.OnnxGraph :members: BaseVar -======= ++++++++ .. autoclass:: onnx_array_api.light_api.var.BaseVar :members: Var -=== ++++ .. autoclass:: onnx_array_api.light_api.Var :members: :inherited-members: Vars -==== +++++ .. autoclass:: onnx_array_api.light_api.Vars :members: :inherited-members: + +Classes for the Translater +========================== + +Emitter ++++++++ + +.. autoclass:: onnx_array_api.light_api.translate.Emitter + :members: + +EventType ++++++++++ + +.. autoclass:: onnx_array_api.light_api.translate.EventType + :members: + +Translater +++++++++++ + +.. autoclass:: onnx_array_api.light_api.translate.Translater + :members: + diff --git a/_doc/index.rst b/_doc/index.rst index 52d2cf6..93ca000 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -45,7 +45,8 @@ The objective is to speed up the implementation of converter libraries. 
CHANGELOGS license -**Numpy API** +Numpy API ++++++++++ Sources available on `github/onnx-array-api `_. @@ -109,7 +110,8 @@ Sources available on res = jitted_myloss(x, y) print(to_dot(jitted_myloss.get_onnx())) -**Light API** +Light API ++++++++++ .. runpython:: :showcode: @@ -135,3 +137,9 @@ Sources available on ) print(onnx_simple_text_plot(model)) + + +Older versions +++++++++++++++ + +* `0.1.2 <../v0.1.2/index.html>`_ diff --git a/_unittests/ut_light_api/test_translate.py b/_unittests/ut_light_api/test_translate.py new file mode 100644 index 0000000..c1f63f9 --- /dev/null +++ b/_unittests/ut_light_api/test_translate.py @@ -0,0 +1,131 @@ +import unittest +from textwrap import dedent +import numpy as np +from onnx import ModelProto, TensorProto +from onnx.defs import onnx_opset_version +from onnx.reference import ReferenceEvaluator +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.light_api import start, translate + +OPSET_API = min(19, onnx_opset_version() - 1) + + +class TestTranslate(ExtTestCase): + def test_exp(self): + onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx() + self.assertIsInstance(onx, ModelProto) + self.assertIn("Exp", str(onx)) + ref = ReferenceEvaluator(onx) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(np.exp(a), got) + + code = translate(onx) + expected = dedent( + """ + ( + start(opset=19) + .vin('X', elem_type=TensorProto.FLOAT) + .bring('X') + .Exp() + .rename('Y') + .bring('Y') + .vout(elem_type=TensorProto.FLOAT) + .to_onnx() + )""" + ).strip("\n") + self.assertEqual(expected, code) + + onx2 = ( + start(opset=19) + .vin("X", elem_type=TensorProto.FLOAT) + .bring("X") + .Exp() + .rename("Y") + .bring("Y") + .vout(elem_type=TensorProto.FLOAT) + .to_onnx() + ) + ref = ReferenceEvaluator(onx2) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(np.exp(a), got) + + def test_transpose(self): + onx = ( + 
start(opset=19) + .vin("X") + .reshape((-1, 1)) + .Transpose(perm=[1, 0]) + .rename("Y") + .vout() + .to_onnx() + ) + self.assertIsInstance(onx, ModelProto) + self.assertIn("Transpose", str(onx)) + ref = ReferenceEvaluator(onx) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + self.assertEqualArray(a.reshape((-1, 1)).T, got) + + code = translate(onx) + expected = dedent( + """ + ( + start(opset=19) + .vin('X', elem_type=TensorProto.FLOAT) + .bring('X', 'r') + .Reshape() + .rename('r0_0') + .bring('r0_0') + .Transpose(perm=[1, 0]) + .rename('Y') + .bring('Y') + .vout(elem_type=TensorProto.FLOAT) + .to_onnx() + )""" + ).strip("\n") + self.assertEqual(expected, code) + + def test_topk_reverse(self): + onx = ( + start(opset=19) + .vin("X", np.float32) + .vin("K", np.int64) + .bring("X", "K") + .TopK(largest=0) + .rename("Values", "Indices") + .vout() + .to_onnx() + ) + self.assertIsInstance(onx, ModelProto) + ref = ReferenceEvaluator(onx) + x = np.array([[0, 1, 2, 3], [9, 8, 7, 6]], dtype=np.float32) + k = np.array([2], dtype=np.int64) + got = ref.run(None, {"X": x, "K": k}) + self.assertEqualArray(np.array([[0, 1], [6, 7]], dtype=np.float32), got[0]) + self.assertEqualArray(np.array([[0, 1], [3, 2]], dtype=np.int64), got[1]) + + code = translate(onx) + expected = dedent( + """ + ( + start(opset=19) + .vin('X', elem_type=TensorProto.FLOAT) + .vin('K', elem_type=TensorProto.INT64) + .bring('X', 'K') + .TopK(axis=-1, largest=0, sorted=1) + .rename('Values', 'Indices') + .bring('Values') + .vout(elem_type=TensorProto.FLOAT) + .bring('Indices') + .vout(elem_type=TensorProto.FLOAT) + .to_onnx() + )""" + ).strip("\n") + self.assertEqual(expected, code) + + +if __name__ == "__main__": + # TestLightApi().test_topk() + unittest.main(verbosity=2) diff --git a/onnx_array_api/light_api/__init__.py b/onnx_array_api/light_api/__init__.py index 272ea0d..5e549f9 100644 --- a/onnx_array_api/light_api/__init__.py +++ b/onnx_array_api/light_api/__init__.py @@ 
-1,5 +1,7 @@ from typing import Dict, Optional +from onnx import ModelProto from .model import OnnxGraph +from .translate import Translater from .var import Var, Vars @@ -34,8 +36,48 @@ def start( from onnx_array_api.light_api import start onx = ( - start().vin("X").vin("Y").bring("X", "Y").Add().rename("Z").vout().to_onnx() + start() + .vin("X") + .vin("Y") + .bring("X", "Y") + .Add() + .rename("Z") + .vout() + .to_onnx() ) print(onx) """ return OnnxGraph(opset=opset, opsets=opsets, is_function=is_function) + + +def translate(proto: ModelProto, single_line=False) -> str: + """ + Translates an ONNX proto into a code using :ref:`l-light-api` + to describe the ONNX graph. + + :param proto: model to translate + :param single_line: as a single line or not + :return: code + + .. runpython:: + :showcode: + + from onnx_array_api.light_api import start, translate + + onx = ( + start() + .vin("X") + .reshape((-1, 1)) + .Transpose(perm=[1, 0]) + .rename("Y") + .vout() + .to_onnx() + ) + code = translate(onx) + print(code) + """ + tr = Translater(proto) + rows = tr.export() + if single_line: + return ".".join(rows) + return "".join(["(\n ", "\n .".join(rows), "\n)"]) diff --git a/onnx_array_api/light_api/annotations.py b/onnx_array_api/light_api/annotations.py index 8d473fd..c975dab 100644 --- a/onnx_array_api/light_api/annotations.py +++ b/onnx_array_api/light_api/annotations.py @@ -12,7 +12,7 @@ ELEMENT_TYPE_NAME = { getattr(TensorProto, k): k for k in dir(TensorProto) - if isinstance(getattr(TensorProto, k), int) + if isinstance(getattr(TensorProto, k), int) and "_" not in k } _type_numpy = { diff --git a/onnx_array_api/light_api/translate.py b/onnx_array_api/light_api/translate.py new file mode 100644 index 0000000..db574df --- /dev/null +++ b/onnx_array_api/light_api/translate.py @@ -0,0 +1,260 @@ +from typing import Any, Dict, List, Optional, Tuple, Union +from enum import IntEnum +import numpy as np +from onnx import AttributeProto, FunctionProto, GraphProto, 
ModelProto, NodeProto +from onnx.numpy_helper import to_array +from .annotations import ELEMENT_TYPE_NAME + + +class EventType(IntEnum): + START = 0 + INPUT = 1 + OUTPUT = 2 + NODE = 3 + TO_ONNX = 4 + + +class Emitter: + """ + Converts event into proper code. + """ + + def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]: + """ + Converts an event into an instruction. + + :param event: event kind + :param kwargs: event parameters + :return: list of instructions + """ + if event == EventType.START: + opsets = kwargs.get("opsets", {}) + opset = opsets.get("", None) + if opset is not None: + del opsets[""] + args = [] + if opset: + args.append(f"opset={opset}") + if opsets: + args.append(f"opsets={opsets}") + return [f"start({', '.join(args)})"] + + if event == EventType.TO_ONNX: + return ["to_onnx()"] + + if event == EventType.INPUT: + name = kwargs["name"] + elem_type = kwargs.get("elem_type", None) + shape = kwargs.get("shape", None) + if elem_type and shape: + return [ + f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})" + ] + if elem_type: + return [ + f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})" + ] + return [f"vin({name!r})"] + + if event == EventType.OUTPUT: + inst = [] + if "name" in kwargs: + name = kwargs["name"] + inst.append(f"bring({name!r})") + elem_type = kwargs.get("elem_type", None) + shape = kwargs.get("shape", None) + if elem_type and shape: + inst.append( + f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})" + ) + elif elem_type: + inst.append( + f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})" + ) + else: + inst.append("vout()") + return inst + + if event == EventType.NODE: + op_type = kwargs["op_type"] + inputs = kwargs["inputs"] + outputs = kwargs["outputs"] + if kwargs.get("domain", "") != "": + domain = kwargs["domain"] + raise NotImplementedError(f"domain={domain!r} not supported yet.") + atts = kwargs.get("atts", {}) 
+ args = [] + for k, v in atts.items(): + args.append(f"{k}={self.render_attribute_value(v)}") + + str_inputs = ", ".join([f"{i!r}" for i in inputs]) + inst = [f"bring({str_inputs})", f"{op_type}({', '.join(args)})"] + if len(outputs) == 1: + inst.append(f"rename({outputs[0]!r})") + else: + str_outputs = ", ".join([f"{o!r}" for o in outputs]) + inst.append(f"rename({str_outputs})") + return inst + + raise ValueError(f"Unexpected EventType {event}.") + + def render_attribute_value(self, value: Any) -> str: + """ + Renders an attribute value into a string. + """ + v = value[-1] + if isinstance(v, (int, float, list)): + return str(v) + if isinstance(v, np.ndarray): + if len(v.shape) == 0: + return str(v) + if len(v.shape) == 1: + return str(v.tolist()) + raise ValueError(f"Unable to render an attribute {value}.") + + +class Translater: + """ + Translates an ONNX graph into a code following the light API. + """ + + def __init__( + self, + proto: Union[ModelProto, FunctionProto, GraphProto], + emitter: Optional[Emitter] = None, + ): + self.proto_ = proto + self.emit = emitter or Emitter() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(<{type(self.proto_)})" + + def export(self) -> List[str]: + """ + Exports into a code. 
+
+        :return: list of instructions
+        """
+        rows = []
+        if isinstance(self.proto_, ModelProto):
+            opsets = {d.domain: d.version for d in self.proto_.opset_import}
+            rows.extend(self.emit(EventType.START, opsets=opsets))
+            inputs = self.proto_.graph.input
+            outputs = self.proto_.graph.output
+            nodes = self.proto_.graph.node
+        elif isinstance(self.proto_, (FunctionProto, GraphProto)):
+            inputs = self.proto_.input
+            outputs = self.proto_.output
+            nodes = self.proto_.node
+        else:
+            raise ValueError(f"Unexpected type {type(self.proto_)} for proto.")
+
+        for i in inputs:
+            if isinstance(i, str):
+                rows.extend(self.emit(EventType.INPUT, name=i))
+            else:
+                rows.extend(
+                    self.emit(
+                        EventType.INPUT,
+                        name=i.name,
+                        elem_type=i.type.tensor_type.elem_type,
+                        shape=tuple(
+                            d.dim_value or d.dim_param
+                            for d in i.type.tensor_type.shape.dim
+                        ),
+                    )
+                )
+
+        for node in nodes:
+            atts = self.extract_attributes(node)
+            rows.extend(
+                self.emit(
+                    EventType.NODE,
+                    op_type=node.op_type,
+                    inputs=node.input,
+                    outputs=node.output,
+                    domain=node.domain,
+                    atts=atts,
+                )
+            )
+
+        for o in outputs:
+            if isinstance(o, str):
+                rows.extend(self.emit(EventType.OUTPUT, name=o))
+            else:
+                rows.extend(
+                    self.emit(
+                        EventType.OUTPUT,
+                        name=o.name,
+                        elem_type=o.type.tensor_type.elem_type,
+                        shape=tuple(
+                            d.dim_value or d.dim_param
+                            for d in o.type.tensor_type.shape.dim
+                        ),
+                    )
+                )
+
+        if isinstance(self.proto_, ModelProto) and len(self.proto_.functions) > 0:
+            raise NotImplementedError("Local functions are not yet implemented.")
+
+        rows.extend(self.emit(EventType.TO_ONNX))
+        return rows
+
+    def extract_attributes(
+        self, node: NodeProto
+    ) -> Dict[str, Tuple[AttributeProto, Any]]:
+        """
+        Extracts all attributes of a node.
+
+        :param node: node proto
+        :return: dictionary
+        """
+        atts: Dict[str, Tuple[AttributeProto, Any]] = {}
+        for att in node.attribute:
+            if hasattr(att, "ref_attr_name") and att.ref_attr_name:
+                atts[att.name] = (att, None)
+                continue
+            if att.type == AttributeProto.INT:
+                atts[att.name] = (att, att.i)
+                continue
+            if att.type == AttributeProto.FLOAT:
+                atts[att.name] = (att, att.f)
+                continue
+            if att.type == AttributeProto.INTS:
+                atts[att.name] = (att, np.array(att.ints))
+                continue
+            if att.type == AttributeProto.FLOATS:
+                atts[att.name] = (att, np.array(att.floats, dtype=np.float32))
+                continue
+            if (
+                att.type == AttributeProto.GRAPH
+                and hasattr(att, "g")
+                and att.g is not None
+            ):
+                atts[att.name] = (att, None)
+                continue
+            if att.type == AttributeProto.SPARSE_TENSOR:
+                atts[att.name] = (att, to_array(att.sparse_tensor))
+                continue
+            if att.type == AttributeProto.TENSOR:
+                atts[att.name] = (att, to_array(att.t))
+                continue
+            if att.type == AttributeProto.TENSORS:
+                atts[att.name] = (att, [to_array(t) for t in att.tensors])
+                continue
+            if att.type == AttributeProto.SPARSE_TENSORS:
+                atts[att.name] = (att, [to_array(t) for t in att.sparse_tensors])
+                continue
+            if att.type == AttributeProto.STRING:
+                atts[att.name] = (att, att.s.decode("utf-8"))
+                continue
+            if att.type == AttributeProto.STRINGS:
+                atts[att.name] = (
+                    att,
+                    np.array([s.decode("utf-8") for s in att.strings]),
+                )
+                continue
+            raise ValueError(
+                f"Attribute {att.name!r} with type {att.type} cannot be extracted yet."
+ ) + return atts diff --git a/onnx_array_api/light_api/var.py b/onnx_array_api/light_api/var.py index 2c8b375..ddcc7f5 100644 --- a/onnx_array_api/light_api/var.py +++ b/onnx_array_api/light_api/var.py @@ -128,11 +128,13 @@ def v(self, name: str) -> "Var": """ return self.parent.get_var(name) - def bring(self, *vars: List[Union[str, "Var"]]) -> "Vars": + def bring(self, *vars: List[Union[str, "Var"]]) -> Union["Var", "Vars"]: """ Creates a set of variable as an instance of :class:`onnx_array_api.light_api.Vars`. """ + if len(vars) == 1: + return Var(self.parent, vars[0]) return Vars(self.parent, *vars) def vout(self, **kwargs: Dict[str, Any]) -> Union["Var", "Vars"]: