diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index d382b74..1c385ca 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.1.3
+++++
+* :pr:`46`: adds an export to convert an onnx graph into light API code
* :pr:`45`: fixes light API for operators with two outputs
0.1.2
diff --git a/README.rst b/README.rst
index 035911d..7d53c79 100644
--- a/README.rst
+++ b/README.rst
@@ -141,4 +141,4 @@ The euclidean distance looks like the following:
The library is released on
`pypi/onnx-array-api `_
and its documentation is published at
-`(Numpy) Array API for ONNX `_.
+`APIs to create ONNX Graphs `_.
diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst
index 471eb66..a50f050 100644
--- a/_doc/api/light_api.rst
+++ b/_doc/api/light_api.rst
@@ -2,33 +2,67 @@
onnx_array_api.light_api
========================
+
+Main API
+========
+
start
-=====
++++++
.. autofunction:: onnx_array_api.light_api.start
+translate
++++++++++
+
+.. autofunction:: onnx_array_api.light_api.translate
+
+Classes for the Light API
+=========================
+
OnnxGraph
-=========
++++++++++
.. autoclass:: onnx_array_api.light_api.OnnxGraph
:members:
BaseVar
-=======
++++++++
.. autoclass:: onnx_array_api.light_api.var.BaseVar
:members:
Var
-===
++++
.. autoclass:: onnx_array_api.light_api.Var
:members:
:inherited-members:
Vars
-====
+++++
.. autoclass:: onnx_array_api.light_api.Vars
:members:
:inherited-members:
+
+Classes for the Translater
+==========================
+
+Emitter
++++++++
+
+.. autoclass:: onnx_array_api.light_api.translate.Emitter
+ :members:
+
+EventType
++++++++++
+
+.. autoclass:: onnx_array_api.light_api.translate.EventType
+ :members:
+
+Translater
+++++++++++
+
+.. autoclass:: onnx_array_api.light_api.translate.Translater
+ :members:
+
diff --git a/_doc/index.rst b/_doc/index.rst
index 52d2cf6..93ca000 100644
--- a/_doc/index.rst
+++ b/_doc/index.rst
@@ -45,7 +45,8 @@ The objective is to speed up the implementation of converter libraries.
CHANGELOGS
license
-**Numpy API**
+Numpy API
++++++++++
Sources available on
`github/onnx-array-api `_.
@@ -109,7 +110,8 @@ Sources available on
res = jitted_myloss(x, y)
print(to_dot(jitted_myloss.get_onnx()))
-**Light API**
+Light API
++++++++++
.. runpython::
:showcode:
@@ -135,3 +137,9 @@ Sources available on
)
print(onnx_simple_text_plot(model))
+
+
+Older versions
+++++++++++++++
+
+* `0.1.2 <../v0.1.2/index.html>`_
diff --git a/_unittests/ut_light_api/test_translate.py b/_unittests/ut_light_api/test_translate.py
new file mode 100644
index 0000000..c1f63f9
--- /dev/null
+++ b/_unittests/ut_light_api/test_translate.py
@@ -0,0 +1,131 @@
+import unittest
+from textwrap import dedent
+import numpy as np
+from onnx import ModelProto, TensorProto
+from onnx.defs import onnx_opset_version
+from onnx.reference import ReferenceEvaluator
+from onnx_array_api.ext_test_case import ExtTestCase
+from onnx_array_api.light_api import start, translate
+
+OPSET_API = min(19, onnx_opset_version() - 1)
+
+
+class TestTranslate(ExtTestCase):
+ def test_exp(self):
+ onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
+ self.assertIsInstance(onx, ModelProto)
+ self.assertIn("Exp", str(onx))
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a})[0]
+ self.assertEqualArray(np.exp(a), got)
+
+ code = translate(onx)
+ expected = dedent(
+ """
+ (
+ start(opset=19)
+ .vin('X', elem_type=TensorProto.FLOAT)
+ .bring('X')
+ .Exp()
+ .rename('Y')
+ .bring('Y')
+ .vout(elem_type=TensorProto.FLOAT)
+ .to_onnx()
+ )"""
+ ).strip("\n")
+ self.assertEqual(expected, code)
+
+ onx2 = (
+ start(opset=19)
+ .vin("X", elem_type=TensorProto.FLOAT)
+ .bring("X")
+ .Exp()
+ .rename("Y")
+ .bring("Y")
+ .vout(elem_type=TensorProto.FLOAT)
+ .to_onnx()
+ )
+ ref = ReferenceEvaluator(onx2)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a})[0]
+ self.assertEqualArray(np.exp(a), got)
+
+ def test_transpose(self):
+ onx = (
+ start(opset=19)
+ .vin("X")
+ .reshape((-1, 1))
+ .Transpose(perm=[1, 0])
+ .rename("Y")
+ .vout()
+ .to_onnx()
+ )
+ self.assertIsInstance(onx, ModelProto)
+ self.assertIn("Transpose", str(onx))
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a})[0]
+ self.assertEqualArray(a.reshape((-1, 1)).T, got)
+
+ code = translate(onx)
+ expected = dedent(
+ """
+ (
+ start(opset=19)
+ .vin('X', elem_type=TensorProto.FLOAT)
+ .bring('X', 'r')
+ .Reshape()
+ .rename('r0_0')
+ .bring('r0_0')
+ .Transpose(perm=[1, 0])
+ .rename('Y')
+ .bring('Y')
+ .vout(elem_type=TensorProto.FLOAT)
+ .to_onnx()
+ )"""
+ ).strip("\n")
+ self.assertEqual(expected, code)
+
+ def test_topk_reverse(self):
+ onx = (
+ start(opset=19)
+ .vin("X", np.float32)
+ .vin("K", np.int64)
+ .bring("X", "K")
+ .TopK(largest=0)
+ .rename("Values", "Indices")
+ .vout()
+ .to_onnx()
+ )
+ self.assertIsInstance(onx, ModelProto)
+ ref = ReferenceEvaluator(onx)
+ x = np.array([[0, 1, 2, 3], [9, 8, 7, 6]], dtype=np.float32)
+ k = np.array([2], dtype=np.int64)
+ got = ref.run(None, {"X": x, "K": k})
+ self.assertEqualArray(np.array([[0, 1], [6, 7]], dtype=np.float32), got[0])
+ self.assertEqualArray(np.array([[0, 1], [3, 2]], dtype=np.int64), got[1])
+
+ code = translate(onx)
+ expected = dedent(
+ """
+ (
+ start(opset=19)
+ .vin('X', elem_type=TensorProto.FLOAT)
+ .vin('K', elem_type=TensorProto.INT64)
+ .bring('X', 'K')
+ .TopK(axis=-1, largest=0, sorted=1)
+ .rename('Values', 'Indices')
+ .bring('Values')
+ .vout(elem_type=TensorProto.FLOAT)
+ .bring('Indices')
+ .vout(elem_type=TensorProto.FLOAT)
+ .to_onnx()
+ )"""
+ ).strip("\n")
+ self.assertEqual(expected, code)
+
+
+if __name__ == "__main__":
+ # TestLightApi().test_topk()
+ unittest.main(verbosity=2)
diff --git a/onnx_array_api/light_api/__init__.py b/onnx_array_api/light_api/__init__.py
index 272ea0d..5e549f9 100644
--- a/onnx_array_api/light_api/__init__.py
+++ b/onnx_array_api/light_api/__init__.py
@@ -1,5 +1,7 @@
from typing import Dict, Optional
+from onnx import ModelProto
from .model import OnnxGraph
+from .translate import Translater
from .var import Var, Vars
@@ -34,8 +36,48 @@ def start(
from onnx_array_api.light_api import start
onx = (
- start().vin("X").vin("Y").bring("X", "Y").Add().rename("Z").vout().to_onnx()
+ start()
+ .vin("X")
+ .vin("Y")
+ .bring("X", "Y")
+ .Add()
+ .rename("Z")
+ .vout()
+ .to_onnx()
)
print(onx)
"""
return OnnxGraph(opset=opset, opsets=opsets, is_function=is_function)
+
+
+def translate(proto: ModelProto, single_line=False) -> str:
+ """
+ Translates an ONNX proto into a code using :ref:`l-light-api`
+ to describe the ONNX graph.
+
+ :param proto: model to translate
+ :param single_line: as a single line or not
+ :return: code
+
+ .. runpython::
+ :showcode:
+
+ from onnx_array_api.light_api import start, translate
+
+ onx = (
+ start()
+ .vin("X")
+ .reshape((-1, 1))
+ .Transpose(perm=[1, 0])
+ .rename("Y")
+ .vout()
+ .to_onnx()
+ )
+ code = translate(onx)
+ print(code)
+ """
+ tr = Translater(proto)
+ rows = tr.export()
+ if single_line:
+ return ".".join(rows)
+ return "".join(["(\n ", "\n .".join(rows), "\n)"])
diff --git a/onnx_array_api/light_api/annotations.py b/onnx_array_api/light_api/annotations.py
index 8d473fd..c975dab 100644
--- a/onnx_array_api/light_api/annotations.py
+++ b/onnx_array_api/light_api/annotations.py
@@ -12,7 +12,7 @@
ELEMENT_TYPE_NAME = {
getattr(TensorProto, k): k
for k in dir(TensorProto)
- if isinstance(getattr(TensorProto, k), int)
+ if isinstance(getattr(TensorProto, k), int) and "_" not in k
}
_type_numpy = {
diff --git a/onnx_array_api/light_api/translate.py b/onnx_array_api/light_api/translate.py
new file mode 100644
index 0000000..db574df
--- /dev/null
+++ b/onnx_array_api/light_api/translate.py
@@ -0,0 +1,260 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+from enum import IntEnum
+import numpy as np
+from onnx import AttributeProto, FunctionProto, GraphProto, ModelProto, NodeProto
+from onnx.numpy_helper import to_array
+from .annotations import ELEMENT_TYPE_NAME
+
+
+class EventType(IntEnum):
+ START = 0
+ INPUT = 1
+ OUTPUT = 2
+ NODE = 3
+ TO_ONNX = 4
+
+
+class Emitter:
+ """
+ Converts event into proper code.
+ """
+
+ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
+ """
+ Converts an event into an instruction.
+
+ :param event: event kind
+ :param kwargs: event parameters
+ :return: list of instructions
+ """
+ if event == EventType.START:
+ opsets = kwargs.get("opsets", {})
+ opset = opsets.get("", None)
+ if opset is not None:
+ del opsets[""]
+ args = []
+ if opset:
+ args.append(f"opset={opset}")
+ if opsets:
+ args.append(f"opsets={opsets}")
+ return [f"start({', '.join(args)})"]
+
+ if event == EventType.TO_ONNX:
+ return ["to_onnx()"]
+
+ if event == EventType.INPUT:
+ name = kwargs["name"]
+ elem_type = kwargs.get("elem_type", None)
+ shape = kwargs.get("shape", None)
+ if elem_type and shape:
+ return [
+ f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})"
+ ]
+ if elem_type:
+ return [
+ f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})"
+ ]
+ return [f"vin({name!r})"]
+
+ if event == EventType.OUTPUT:
+ inst = []
+ if "name" in kwargs:
+ name = kwargs["name"]
+ inst.append(f"bring({name!r})")
+ elem_type = kwargs.get("elem_type", None)
+ shape = kwargs.get("shape", None)
+ if elem_type and shape:
+ inst.append(
+ f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})"
+ )
+ elif elem_type:
+ inst.append(
+ f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})"
+ )
+ else:
+ inst.append("vout()")
+ return inst
+
+ if event == EventType.NODE:
+ op_type = kwargs["op_type"]
+ inputs = kwargs["inputs"]
+ outputs = kwargs["outputs"]
+ if kwargs.get("domain", "") != "":
+ domain = kwargs["domain"]
+ raise NotImplementedError(f"domain={domain!r} not supported yet.")
+ atts = kwargs.get("atts", {})
+ args = []
+ for k, v in atts.items():
+ args.append(f"{k}={self.render_attribute_value(v)}")
+
+ str_inputs = ", ".join([f"{i!r}" for i in inputs])
+ inst = [f"bring({str_inputs})", f"{op_type}({', '.join(args)})"]
+ if len(outputs) == 1:
+ inst.append(f"rename({outputs[0]!r})")
+ else:
+ str_outputs = ", ".join([f"{o!r}" for o in outputs])
+ inst.append(f"rename({str_outputs})")
+ return inst
+
+ raise ValueError(f"Unexpected EventType {event}.")
+
+ def render_attribute_value(self, value: Any) -> str:
+ """
+ Renders an attribute value into a string.
+ """
+ v = value[-1]
+ if isinstance(v, (int, float, list)):
+ return str(v)
+ if isinstance(v, np.ndarray):
+ if len(v.shape) == 0:
+ return str(v)
+ if len(v.shape) == 1:
+ return str(v.tolist())
+ raise ValueError(f"Unable to render an attribute {value}.")
+
+
+class Translater:
+ """
+ Translates an ONNX graph into a code following the light API.
+ """
+
+ def __init__(
+ self,
+ proto: Union[ModelProto, FunctionProto, GraphProto],
+ emitter: Optional[Emitter] = None,
+ ):
+ self.proto_ = proto
+ self.emit = emitter or Emitter()
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(<{type(self.proto_)})"
+
+ def export(self) -> List[str]:
+ """
+ Exports into a code.
+
+ :return: list of instructions
+ """
+ rows = []
+ if isinstance(self.proto_, ModelProto):
+ opsets = {d.domain: d.version for d in self.proto_.opset_import}
+ rows.extend(self.emit(EventType.START, opsets=opsets))
+ inputs = self.proto_.graph.input
+ outputs = self.proto_.graph.output
+ nodes = self.proto_.graph.node
+ elif isinstance(self.proto_, (FunctionProto, GraphProto)):
+ inputs = self.proto_.input
+ outputs = self.proto_.output
+ nodes = self.proto_.node
+ else:
+ raise ValueError(f"Unexpected type {type(self.proto_)} for proto.")
+
+ for i in inputs:
+ if isinstance(i, str):
+ rows.extend(self.emit(EventType.INPUT, name=i))
+ else:
+ rows.extend(
+ self.emit(
+ EventType.INPUT,
+ name=i.name,
+ elem_type=i.type.tensor_type.elem_type,
+ shape=tuple(
+ d.dim_value or d.dim_param
+ for d in i.type.tensor_type.shape.dim
+ ),
+ )
+ )
+
+ for node in nodes:
+ atts = self.extract_attributes(node)
+ rows.extend(
+ self.emit(
+ EventType.NODE,
+ op_type=node.op_type,
+ inputs=node.input,
+ outputs=node.output,
+ domain=node.domain,
+ atts=atts,
+ )
+ )
+
+ for o in outputs:
+ if isinstance(i, str):
+ rows.extend(self.emit(EventType.INPUT, name=o))
+ else:
+ rows.extend(
+ self.emit(
+ EventType.OUTPUT,
+ name=o.name,
+ elem_type=o.type.tensor_type.elem_type,
+ shape=tuple(
+ d.dim_value or d.dim_param
+ for d in o.type.tensor_type.shape.dim
+ ),
+ )
+ )
+
+ if isinstance(self.proto_, ModelProto) and len(self.proto_.functions) > 0:
+ raise NotImplementedError("Local functions are not yet implemented.")
+
+ rows.extend(self.emit(EventType.TO_ONNX))
+ return rows
+
+ def extract_attributes(
+ self, node: NodeProto
+ ) -> Dict[str, Tuple[AttributeProto, Any]]:
+ """
+ Extracts all atributes of a node.
+
+ :param node: node proto
+ :return: dictionary
+ """
+ atts: Dict[str, Tuple[AttributeProto, Any]] = {}
+ for att in node.attribute:
+ if hasattr(att, "ref_attr_name") and att.ref_attr_name:
+ atts[att.name] = (att, None)
+ continue
+ if att.type == AttributeProto.INT:
+ atts[att.name] = (att, att.i)
+ continue
+ if att.type == AttributeProto.FLOAT:
+ atts[att.name] = (att, att.f)
+ continue
+ if att.type == AttributeProto.INTS:
+ atts[att.name] = (att, np.array(att.ints))
+ continue
+ if att.type == AttributeProto.FLOATS:
+ atts[att.name] = (att, np.array(att.floats, dtype=np.float32))
+ continue
+ if (
+ att.type == AttributeProto.GRAPH
+ and hasattr(att, "g")
+ and att.g is not None
+ ):
+ atts[att.name] = (att, None)
+ continue
+ if att.type == AttributeProto.SPARSE_TENSORS:
+ atts[att.name] = (att, to_array(att.sparse_tensor))
+ continue
+ if att.type == AttributeProto.TENSOR:
+ atts[att.name] = (att, to_array(att.t))
+ continue
+ if att.type == AttributeProto.TENSORS:
+ atts[att.name] = (att, [to_array(t) for t in att.tensors])
+ continue
+ if att.type == AttributeProto.SPARSE_TENSORS:
+ atts[att.name] = (att, [to_array(t) for t in att.sparse_tensors])
+ continue
+ if att.type == AttributeProto.STRING:
+ atts[att.name] = (att, att.s.decode("utf-8"))
+ continue
+ if att.type == AttributeProto.STRINGS:
+ atts[att.name] = (
+ att,
+ np.array([s.decode("utf-8") for s in att.strings]),
+ )
+ continue
+ raise ValueError(
+ f"Attribute {att.name!r} with type {att.type} cannot be extracted yet."
+ )
+ return atts
diff --git a/onnx_array_api/light_api/var.py b/onnx_array_api/light_api/var.py
index 2c8b375..ddcc7f5 100644
--- a/onnx_array_api/light_api/var.py
+++ b/onnx_array_api/light_api/var.py
@@ -128,11 +128,13 @@ def v(self, name: str) -> "Var":
"""
return self.parent.get_var(name)
- def bring(self, *vars: List[Union[str, "Var"]]) -> "Vars":
+ def bring(self, *vars: List[Union[str, "Var"]]) -> Union["Var", "Vars"]:
"""
Creates a set of variable as an instance of
:class:`onnx_array_api.light_api.Vars`.
"""
+ if len(vars) == 1:
+ return Var(self.parent, vars[0])
return Vars(self.parent, *vars)
def vout(self, **kwargs: Dict[str, Any]) -> Union["Var", "Vars"]: