Skip to content

Commit 75d62a0

Browse files
authored
Add an export to convert an onnx graph into light API code (#46)
* Add an export to convert an onnx graph into light API code * fix unit tests * fix annotations * fix documentation * doc
1 parent dd11424 commit 75d62a0

File tree

9 files changed

+489
-11
lines changed

9 files changed

+489
-11
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.1.3
55
+++++
66

7+
* :pr:`46`: adds an export to convert an onnx graph into light API code
78
* :pr:`45`: fixes light API for operators with two outputs
89

910
0.1.2

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,4 @@ The euclidean distance looks like the following:
141141
The library is released on
142142
`pypi/onnx-array-api <https://pypi.org/project/onnx-array-api/>`_
143143
and its documentation is published at
144-
`(Numpy) Array API for ONNX <https://sdpython.github.io/doc/onnx-array-api/dev/>`_.
144+
`APIs to create ONNX Graphs <https://sdpython.github.io/doc/onnx-array-api/dev/>`_.

_doc/api/light_api.rst

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,67 @@
22
onnx_array_api.light_api
33
========================
44

5+
6+
Main API
7+
========
8+
59
start
6-
=====
10+
+++++
711

812
.. autofunction:: onnx_array_api.light_api.start
913

14+
translate
15+
+++++++++
16+
17+
.. autofunction:: onnx_array_api.light_api.translate
18+
19+
Classes for the Light API
20+
=========================
21+
1022
OnnxGraph
11-
=========
23+
+++++++++
1224

1325
.. autoclass:: onnx_array_api.light_api.OnnxGraph
1426
:members:
1527

1628
BaseVar
17-
=======
29+
+++++++
1830

1931
.. autoclass:: onnx_array_api.light_api.var.BaseVar
2032
:members:
2133

2234
Var
23-
===
35+
+++
2436

2537
.. autoclass:: onnx_array_api.light_api.Var
2638
:members:
2739
:inherited-members:
2840

2941
Vars
30-
====
42+
++++
3143

3244
.. autoclass:: onnx_array_api.light_api.Vars
3345
:members:
3446
:inherited-members:
47+
48+
Classes for the Translater
49+
==========================
50+
51+
Emitter
52+
+++++++
53+
54+
.. autoclass:: onnx_array_api.light_api.translate.Emitter
55+
:members:
56+
57+
EventType
58+
+++++++++
59+
60+
.. autoclass:: onnx_array_api.light_api.translate.EventType
61+
:members:
62+
63+
Translater
64+
++++++++++
65+
66+
.. autoclass:: onnx_array_api.light_api.translate.Translater
67+
:members:
68+

_doc/index.rst

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ The objective is to speed up the implementation of converter libraries.
4545
CHANGELOGS
4646
license
4747

48-
**Numpy API**
48+
Numpy API
49+
+++++++++
4950

5051
Sources available on
5152
`github/onnx-array-api <https://github.com/sdpython/onnx-array-api>`_.
@@ -109,7 +110,8 @@ Sources available on
109110
res = jitted_myloss(x, y)
110111
print(to_dot(jitted_myloss.get_onnx()))
111112

112-
**Light API**
113+
Light API
114+
+++++++++
113115

114116
.. runpython::
115117
:showcode:
@@ -135,3 +137,9 @@ Sources available on
135137
)
136138

137139
print(onnx_simple_text_plot(model))
140+
141+
142+
Older versions
143+
++++++++++++++
144+
145+
* `0.1.2 <../v0.1.2/index.html>`_
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import unittest
2+
from textwrap import dedent
3+
import numpy as np
4+
from onnx import ModelProto, TensorProto
5+
from onnx.defs import onnx_opset_version
6+
from onnx.reference import ReferenceEvaluator
7+
from onnx_array_api.ext_test_case import ExtTestCase
8+
from onnx_array_api.light_api import start, translate
9+
10+
OPSET_API = min(19, onnx_opset_version() - 1)
11+
12+
13+
class TestTranslate(ExtTestCase):
14+
def test_exp(self):
15+
onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
16+
self.assertIsInstance(onx, ModelProto)
17+
self.assertIn("Exp", str(onx))
18+
ref = ReferenceEvaluator(onx)
19+
a = np.arange(10).astype(np.float32)
20+
got = ref.run(None, {"X": a})[0]
21+
self.assertEqualArray(np.exp(a), got)
22+
23+
code = translate(onx)
24+
expected = dedent(
25+
"""
26+
(
27+
start(opset=19)
28+
.vin('X', elem_type=TensorProto.FLOAT)
29+
.bring('X')
30+
.Exp()
31+
.rename('Y')
32+
.bring('Y')
33+
.vout(elem_type=TensorProto.FLOAT)
34+
.to_onnx()
35+
)"""
36+
).strip("\n")
37+
self.assertEqual(expected, code)
38+
39+
onx2 = (
40+
start(opset=19)
41+
.vin("X", elem_type=TensorProto.FLOAT)
42+
.bring("X")
43+
.Exp()
44+
.rename("Y")
45+
.bring("Y")
46+
.vout(elem_type=TensorProto.FLOAT)
47+
.to_onnx()
48+
)
49+
ref = ReferenceEvaluator(onx2)
50+
a = np.arange(10).astype(np.float32)
51+
got = ref.run(None, {"X": a})[0]
52+
self.assertEqualArray(np.exp(a), got)
53+
54+
def test_transpose(self):
55+
onx = (
56+
start(opset=19)
57+
.vin("X")
58+
.reshape((-1, 1))
59+
.Transpose(perm=[1, 0])
60+
.rename("Y")
61+
.vout()
62+
.to_onnx()
63+
)
64+
self.assertIsInstance(onx, ModelProto)
65+
self.assertIn("Transpose", str(onx))
66+
ref = ReferenceEvaluator(onx)
67+
a = np.arange(10).astype(np.float32)
68+
got = ref.run(None, {"X": a})[0]
69+
self.assertEqualArray(a.reshape((-1, 1)).T, got)
70+
71+
code = translate(onx)
72+
expected = dedent(
73+
"""
74+
(
75+
start(opset=19)
76+
.vin('X', elem_type=TensorProto.FLOAT)
77+
.bring('X', 'r')
78+
.Reshape()
79+
.rename('r0_0')
80+
.bring('r0_0')
81+
.Transpose(perm=[1, 0])
82+
.rename('Y')
83+
.bring('Y')
84+
.vout(elem_type=TensorProto.FLOAT)
85+
.to_onnx()
86+
)"""
87+
).strip("\n")
88+
self.assertEqual(expected, code)
89+
90+
def test_topk_reverse(self):
91+
onx = (
92+
start(opset=19)
93+
.vin("X", np.float32)
94+
.vin("K", np.int64)
95+
.bring("X", "K")
96+
.TopK(largest=0)
97+
.rename("Values", "Indices")
98+
.vout()
99+
.to_onnx()
100+
)
101+
self.assertIsInstance(onx, ModelProto)
102+
ref = ReferenceEvaluator(onx)
103+
x = np.array([[0, 1, 2, 3], [9, 8, 7, 6]], dtype=np.float32)
104+
k = np.array([2], dtype=np.int64)
105+
got = ref.run(None, {"X": x, "K": k})
106+
self.assertEqualArray(np.array([[0, 1], [6, 7]], dtype=np.float32), got[0])
107+
self.assertEqualArray(np.array([[0, 1], [3, 2]], dtype=np.int64), got[1])
108+
109+
code = translate(onx)
110+
expected = dedent(
111+
"""
112+
(
113+
start(opset=19)
114+
.vin('X', elem_type=TensorProto.FLOAT)
115+
.vin('K', elem_type=TensorProto.INT64)
116+
.bring('X', 'K')
117+
.TopK(axis=-1, largest=0, sorted=1)
118+
.rename('Values', 'Indices')
119+
.bring('Values')
120+
.vout(elem_type=TensorProto.FLOAT)
121+
.bring('Indices')
122+
.vout(elem_type=TensorProto.FLOAT)
123+
.to_onnx()
124+
)"""
125+
).strip("\n")
126+
self.assertEqual(expected, code)
127+
128+
129+
if __name__ == "__main__":
130+
# TestLightApi().test_topk()
131+
unittest.main(verbosity=2)

onnx_array_api/light_api/__init__.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Dict, Optional
2+
from onnx import ModelProto
23
from .model import OnnxGraph
4+
from .translate import Translater
35
from .var import Var, Vars
46

57

@@ -34,8 +36,48 @@ def start(
3436
from onnx_array_api.light_api import start
3537
3638
onx = (
37-
start().vin("X").vin("Y").bring("X", "Y").Add().rename("Z").vout().to_onnx()
39+
start()
40+
.vin("X")
41+
.vin("Y")
42+
.bring("X", "Y")
43+
.Add()
44+
.rename("Z")
45+
.vout()
46+
.to_onnx()
3847
)
3948
print(onx)
4049
"""
4150
return OnnxGraph(opset=opset, opsets=opsets, is_function=is_function)
51+
52+
53+
def translate(proto: ModelProto, single_line=False) -> str:
54+
"""
55+
Translates an ONNX proto into a code using :ref:`l-light-api`
56+
to describe the ONNX graph.
57+
58+
:param proto: model to translate
59+
:param single_line: as a single line or not
60+
:return: code
61+
62+
.. runpython::
63+
:showcode:
64+
65+
from onnx_array_api.light_api import start, translate
66+
67+
onx = (
68+
start()
69+
.vin("X")
70+
.reshape((-1, 1))
71+
.Transpose(perm=[1, 0])
72+
.rename("Y")
73+
.vout()
74+
.to_onnx()
75+
)
76+
code = translate(onx)
77+
print(code)
78+
"""
79+
tr = Translater(proto)
80+
rows = tr.export()
81+
if single_line:
82+
return ".".join(rows)
83+
return "".join(["(\n ", "\n .".join(rows), "\n)"])

onnx_array_api/light_api/annotations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
ELEMENT_TYPE_NAME = {
1313
getattr(TensorProto, k): k
1414
for k in dir(TensorProto)
15-
if isinstance(getattr(TensorProto, k), int)
15+
if isinstance(getattr(TensorProto, k), int) and "_" not in k
1616
}
1717

1818
_type_numpy = {

0 commit comments

Comments
 (0)