Skip to content

Commit 9351aca

Browse files
committed
ch
1 parent e29df50 commit 9351aca

File tree

2 files changed

+61
-4
lines changed

2 files changed

+61
-4
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Change Logs
44
0.3.1
55
+++++
66

7-
* :pr:`94`: improves translation to GraphBuilder
7+
* :pr:`95`: improves translation to GraphBuilder
88

99
0.3.0
1010
+++++

_unittests/ut_translate_api/test_translate_builder.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from onnx_array_api.ext_test_case import ExtTestCase
99
from onnx_array_api.light_api import start
1010
from onnx_array_api.graph_api import GraphBuilder
11-
from onnx_array_api.translate_api import translate
11+
from onnx_array_api.translate_api import translate, Translater
12+
from onnx_array_api.translate_api.builder_emitter import BuilderEmitter
1213

1314

1415
OPSET_API = min(19, onnx_opset_version() - 1)
@@ -38,7 +39,7 @@ def light_api(
3839
op.Identity(Y, outputs=["Y"])
3940
return Y
4041
41-
g = GraphBuilder({'': 19})
42+
g = GraphBuilder({'': 19}, ir_version=11)
4243
g.make_tensor_input("X", TensorProto.FLOAT, ())
4344
light_api(g.op, "X")
4445
g.make_tensor_output("Y", TensorProto.FLOAT, ())
@@ -89,7 +90,7 @@ def light_api(
8990
op.Identity(Y, outputs=["Y"])
9091
return Y
9192
92-
g = GraphBuilder({'': 19})
93+
g = GraphBuilder({'': 19}, ir_version=11)
9394
g.make_tensor_input("X", TensorProto.FLOAT, ())
9495
light_api(g.op, "X")
9596
g.make_tensor_output("Y", TensorProto.FLOAT, ())
@@ -117,6 +118,62 @@ def light_api(
117118
self.assertNotEmpty(model)
118119
check_model(model)
119120

121+
def test_exp_f(self):
122+
onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
123+
self.assertIsInstance(onx, ModelProto)
124+
self.assertIn("Exp", str(onx))
125+
ref = ReferenceEvaluator(onx)
126+
a = np.arange(10).astype(np.float32)
127+
got = ref.run(None, {"X": a})[0]
128+
self.assertEqualArray(np.exp(a), got)
129+
130+
tr = Translater(onx, emitter=BuilderEmitter("mm"))
131+
code = tr.export(as_str=True)
132+
133+
expected = dedent(
134+
"""
135+
def light_api(
136+
op: "GraphBuilder",
137+
X: "FLOAT[]",
138+
):
139+
Y = op.Exp(X)
140+
op.Identity(Y, outputs=["Y"])
141+
return Y
142+
143+
144+
def mm() -> "ModelProto":
145+
g = GraphBuilder({'': 19}, ir_version=11)
146+
g.make_tensor_input("X", TensorProto.FLOAT, ())
147+
light_api(g.op, "X")
148+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
149+
model = g.to_onnx()
150+
return model
151+
152+
153+
model = mm()
154+
"""
155+
).strip("\n")
156+
self.assertEqual(expected, code.strip("\n"))
157+
158+
def light_api(
159+
op: "GraphBuilder",
160+
X: "FLOAT[]", # noqa: F722
161+
):
162+
Y = op.Exp(X)
163+
op.Identity(Y, outputs=["Y"])
164+
return Y
165+
166+
g2 = GraphBuilder({"": 19})
167+
g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
168+
light_api(g2.op, "X")
169+
g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
170+
onx2 = g2.to_onnx()
171+
172+
ref = ReferenceEvaluator(onx2)
173+
a = np.arange(10).astype(np.float32)
174+
got = ref.run(None, {"X": a})[0]
175+
self.assertEqualArray(np.exp(a), got)
176+
120177

121178
if __name__ == "__main__":
122179
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)