Skip to content

First draft to export to GraphBuilder #83

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion CHANGELOGS.rst
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
Change Logs
===========

0.2.0
0.3.0
+++++

* :pr:`79`: first draft to export to GraphBuilder
* :pr:`77`: supports ConcatOfShape and Slice with the light API

0.2.0
+++++

* :pr:`76`, :pr:`79`: add a mode to compare models without execution
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
* :pr:`71`: adds tools to compare two onnx graphs
Expand Down
1 change: 0 additions & 1 deletion _unittests/ut_translate_api/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,5 +221,4 @@ def test_aionnxml(self):


if __name__ == "__main__":
TestTranslate().test_export_if()
unittest.main(verbosity=2)
122 changes: 122 additions & 0 deletions _unittests/ut_translate_api/test_translate_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import unittest
from textwrap import dedent
import numpy as np
from onnx import ModelProto, TensorProto
from onnx.checker import check_model
from onnx.defs import onnx_opset_version
from onnx.reference import ReferenceEvaluator
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.light_api import start
from onnx_array_api.graph_api import GraphBuilder
from onnx_array_api.translate_api import translate


OPSET_API = min(19, onnx_opset_version() - 1)


class TestTranslateBuilder(ExtTestCase):
def setUp(self):
self.maxDiff = None

def test_exp(self):
onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
self.assertIsInstance(onx, ModelProto)
self.assertIn("Exp", str(onx))
ref = ReferenceEvaluator(onx)
a = np.arange(10).astype(np.float32)
got = ref.run(None, {"X": a})[0]
self.assertEqualArray(np.exp(a), got)

code = translate(onx, api="builder")
expected = dedent(
"""
def light_api(
op: "GraphBuilder",
X: "FLOAT[]",
):
Y = op.Exp(X)
op.Identity(Y, outputs=["Y"])
return Y

g = GraphBuilder({'': 19})
g.make_tensor_input("X", TensorProto.FLOAT, ())
light_api(g.op, "X")
g.make_tensor_output("Y", TensorProto.FLOAT, ())
model = g.to_onnx()
"""
).strip("\n")
self.assertEqual(expected, code.strip("\n"))

def light_api(
op: "GraphBuilder",
X: "FLOAT[]", # noqa: F722
):
Y = op.Exp(X)
op.Identity(Y, outputs=["Y"])
return Y

g2 = GraphBuilder({"": 19})
g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
light_api(g2.op, "X")
g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
onx2 = g2.to_onnx()

ref = ReferenceEvaluator(onx2)
a = np.arange(10).astype(np.float32)
got = ref.run(None, {"X": a})[0]
self.assertEqualArray(np.exp(a), got)

def test_zdoc(self):
onx = (
start(opset=19)
.vin("X")
.reshape((-1, 1))
.Transpose(perm=[1, 0])
.rename("Y")
.vout()
.to_onnx()
)
code = translate(onx, api="builder")
expected = dedent(
"""
def light_api(
op: "GraphBuilder",
X: "FLOAT[]",
):
r = np.array([-1, 1], dtype=np.int64)
r0_0 = op.Reshape(X, r)
Y = op.Transpose(r0_0, perm=[1, 0])
op.Identity(Y, outputs=["Y"])
return Y

g = GraphBuilder({'': 19})
g.make_tensor_input("X", TensorProto.FLOAT, ())
light_api(g.op, "X")
g.make_tensor_output("Y", TensorProto.FLOAT, ())
model = g.to_onnx()
"""
).strip("\n")
self.maxDiff = None
self.assertEqual(expected, code.strip("\n"))

def light_api(
op: "GraphBuilder",
X: "FLOAT[]", # noqa: F722
):
r = np.array([-1, 1], dtype=np.int64)
r0_0 = op.Reshape(X, r)
Y = op.Transpose(r0_0, perm=[1, 0])
op.Identity(Y, outputs=["Y"])
return Y

g = GraphBuilder({"": 21})
X = g.make_tensor_input("X", TensorProto.FLOAT, ())
light_api(g.op, X)
g.make_tensor_output("Y", TensorProto.FLOAT, ())
model = g.to_onnx()
self.assertNotEmpty(model)
check_model(model)


if __name__ == "__main__":
unittest.main(verbosity=2)
12 changes: 12 additions & 0 deletions onnx_array_api/graph_api/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,18 @@ def __getattr__(self, name):
except AttributeError as e:
raise AttributeError(f"Unable to access attribute {name!r}.") from e

def Initializer(
self, init: Union[TensorProto, np.ndarray], name: Optional[str] = None
) -> str:
"""
Creates an initializer.

:param init: value
:param name: name if value is not a TensorProto
:return: its name
"""
return self.builder.make_initializer(init, name=name, exists=True)

def make_node(
self,
op_type: str,
Expand Down
30 changes: 28 additions & 2 deletions onnx_array_api/translate_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from onnx import ModelProto
from .translate import Translater
from .inner_emitter import InnerEmitter
from .builder_emitter import BuilderEmitter


def translate(proto: ModelProto, single_line: bool = False, api: str = "light") -> str:
Expand All @@ -14,7 +15,8 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
default is `"light"` and this is handle by class
:class:`onnx_array_api.translate_api.light_emitter.LightEmitter`,
another value is `"onnx"` which is the inner API implemented
in onnx package.
in onnx package, `"builder"` follows the syntax for the
class :class:`onnx_array_api.graph_api.GraphBuilder`
:return: code

.. runpython::
Expand All @@ -35,7 +37,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
code = translate(onx)
print(code)

The inner API from onnx packahe is also available.
The inner API from onnx package is also available.

.. runpython::
:showcode:
Expand All @@ -54,11 +56,35 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
)
code = translate(onx, api="onnx")
print(code)

The :class:`GraphBuilder
<onnx_array_api.graph_api.GraphBuilder>` API returns this:

.. runpython::
:showcode:

from onnx_array_api.light_api import start
from onnx_array_api.translate_api import translate

onx = (
start()
.vin("X")
.reshape((-1, 1))
.Transpose(perm=[1, 0])
.rename("Y")
.vout()
.to_onnx()
)
code = translate(onx, api="builder")
print(code)
"""
if api == "light":
tr = Translater(proto)
return tr.export(single_line=single_line, as_str=True)
if api == "onnx":
tr = Translater(proto, emitter=InnerEmitter())
return tr.export(as_str=True)
if api == "builder":
tr = Translater(proto, emitter=BuilderEmitter())
return tr.export(as_str=True)
raise ValueError(f"Unexpected value {api!r} for api.")
28 changes: 28 additions & 0 deletions onnx_array_api/translate_api/base_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class EventType(IntEnum):
FUNCTION_OUTPUT = 12
FUNCTION_ATTRIBUTES = 13
TO_ONNX_FUNCTION = 14
BEGIN_SIGNATURE = 15
END_SIGNATURE = 16
BEGIN_RETURN = 17
END_RETURN = 18

@classmethod
def to_str(cls, self) -> str:
Expand Down Expand Up @@ -84,6 +88,18 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
if event == EventType.FUNCTION_ATTRIBUTES:
return self._emit_function_attributes(**kwargs)

if event == EventType.BEGIN_SIGNATURE:
return self._emit_begin_signature(**kwargs)

if event == EventType.END_SIGNATURE:
return self._emit_end_signature(**kwargs)

if event == EventType.BEGIN_RETURN:
return self._emit_begin_return(**kwargs)

if event == EventType.END_RETURN:
return self._emit_end_return(**kwargs)

raise ValueError(f"Unexpected event {EventType.to_str(event)}.")

def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
Expand Down Expand Up @@ -222,3 +238,15 @@ def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]:
raise NotImplementedError(
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
)

def _emit_begin_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
return []

def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
return []

def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:
return []

def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
return []
Loading