Skip to content

Commit f513e4b

Browse files
committed
more coverge
1 parent 08a3cf3 commit f513e4b

File tree

3 files changed

+333
-136
lines changed

3 files changed

+333
-136
lines changed
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
import contextlib
2+
import io
3+
import unittest
4+
import numpy as np
5+
import onnx
6+
from onnx.reference import ReferenceEvaluator
7+
from onnx_array_api.ext_test_case import ExtTestCase
8+
from onnx_array_api.graph_api.graph_builder import GraphBuilder
9+
10+
11+
class TestGraphBuilder(ExtTestCase):
12+
def call_optimizer(self, onx):
13+
gr = GraphBuilder(onx)
14+
gr.remove_unused()
15+
return gr.to_onnx()
16+
17+
def test_remove_unused_nodes(self):
18+
model = onnx.parser.parse_model(
19+
"""
20+
<ir_version: 8, opset_import: [ "": 18]>
21+
agraph (float[N] x) => (float[N] z) {
22+
two = Constant <value_float=2.0> ()
23+
four = Add(two, two)
24+
z = Mul(x, x)
25+
}"""
26+
)
27+
onx = self.call_optimizer(model)
28+
self.assertEqual(len(onx.graph.node), 1)
29+
self.assertEqual(onx.graph.node[0].op_type, "Mul")
30+
31+
def test_initializers(self):
32+
model = onnx.parser.parse_model(
33+
"""
34+
<ir_version: 8, opset_import: [ "": 18]>
35+
agraph (float[N] x) => (float[N] z)
36+
<float two = {2.0}> {
37+
four = Add(two, two)
38+
z = Mul(x, x)
39+
}"""
40+
)
41+
self.assertEqual(len(model.graph.initializer), 1)
42+
onx = self.call_optimizer(model)
43+
self.assertEqual(len(onx.graph.node), 1)
44+
self.assertEqual(onx.graph.node[0].op_type, "Mul")
45+
self.assertEqual(len(onx.graph.initializer), 0)
46+
47+
def test_keep_unused_outputs(self):
48+
model = onnx.parser.parse_model(
49+
"""
50+
<ir_version: 8, opset_import: [ "": 18]>
51+
agraph (float[N] x) => (float[M] z) {
52+
w1, w2, w3 = Split (x)
53+
z = Mul(w3, w3)
54+
}"""
55+
)
56+
onx = self.call_optimizer(model)
57+
self.assertEqual(len(onx.graph.node), 2)
58+
self.assertEqual(onx.graph.node[0].op_type, "Split")
59+
60+
def test_exc(self):
61+
self.assertRaise(lambda: GraphBuilder([]), NotImplementedError)
62+
63+
def test_simple(self):
64+
with contextlib.redirect_stdout(io.StringIO()):
65+
g = GraphBuilder(verbose=10)
66+
67+
shape = (10, 4)
68+
w = np.random.randn(*shape).astype(np.float32)
69+
70+
x = g.make_tensor_input("X", np.float32, shape)
71+
weight = g.make_initializer(w)
72+
one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
73+
transposed = g.make_node("Transpose", [weight], perm=[1, 0])
74+
res = g.op.MatMul(x, transposed)
75+
g.op.Reshape(res, one, outputs="y")
76+
g.make_tensor_output("y", np.float32, (10, 1))
77+
onx = g.to_onnx()
78+
ref = ReferenceEvaluator(onx)
79+
x = np.random.randn(*shape).astype(np.float32)
80+
expected = (x @ w.T).reshape((-1, 1))
81+
feeds = {"X": x}
82+
got = ref.run(None, feeds)
83+
self.assertEqualArray(expected, got[0])
84+
85+
def test_simple_big(self):
86+
with contextlib.redirect_stdout(io.StringIO()):
87+
g = GraphBuilder(verbose=10)
88+
89+
shape = (30, 40)
90+
w = np.random.randn(*shape).astype(np.float32)
91+
92+
x = g.make_tensor_input("X", np.float32, shape)
93+
weight = g.make_initializer(w)
94+
one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
95+
transposed = g.make_node("Transpose", [weight], perm=[1, 0])
96+
res = g.op.MatMul(x, transposed)
97+
g.op.Reshape(res, one, outputs="y")
98+
g.make_tensor_output("y", np.float32, (30, 1))
99+
onx = g.to_onnx()
100+
ref = ReferenceEvaluator(onx)
101+
x = np.random.randn(*shape).astype(np.float32)
102+
expected = (x @ w.T).reshape((-1, 1))
103+
feeds = {"X": x}
104+
got = ref.run(None, feeds)
105+
self.assertEqualArray(expected, got[0])
106+
107+
def test_constant_folding(self):
108+
with contextlib.redirect_stdout(io.StringIO()):
109+
g = GraphBuilder(verbose=10)
110+
111+
shape = (10, 4)
112+
w = np.random.randn(*shape).astype(np.float32)
113+
x = g.make_tensor_input("X", np.float32, shape)
114+
weight = g.make_initializer(w)
115+
one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
116+
transposed = g.make_node("Transpose", [weight], perm=[1, 0])
117+
res = g.op.MatMul(x, transposed)
118+
g.op.Reshape(res, one, outputs="y")
119+
g.make_tensor_output("y", np.float32, (10, 1))
120+
121+
g.constant_folding()
122+
123+
onx = g.to_onnx()
124+
node_types = [n.op_type for n in onx.graph.node]
125+
self.assertNotIn("Transpose", node_types)
126+
ref = ReferenceEvaluator(onx)
127+
x = np.random.randn(*shape).astype(np.float32)
128+
expected = (x @ w.T).reshape((-1, 1))
129+
feeds = {"X": x}
130+
got = ref.run(None, feeds)
131+
self.assertEqualArray(expected, got[0])
132+
133+
def test_remove_identity(self):
134+
with contextlib.redirect_stdout(io.StringIO()):
135+
g = GraphBuilder(verbose=10)
136+
137+
shape = (10, 4)
138+
w = np.random.randn(*shape).astype(np.float32)
139+
x = g.make_tensor_input("X", np.float32, shape)
140+
weight = g.make_initializer(w)
141+
one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
142+
transposed = g.make_node("Transpose", [weight], perm=[1, 0])
143+
res = g.op.Identity(g.op.MatMul(x, transposed))
144+
g.op.Reshape(res, one, outputs="y")
145+
g.make_tensor_output("y", np.float32, (10, 1))
146+
147+
g.remove_identity_nodes()
148+
149+
onx = g.to_onnx()
150+
node_types = [n.op_type for n in onx.graph.node]
151+
self.assertNotIn("Identity", node_types)
152+
ref = ReferenceEvaluator(onx)
153+
x = np.random.randn(*shape).astype(np.float32)
154+
expected = (x @ w.T).reshape((-1, 1))
155+
feeds = {"X": x}
156+
got = ref.run(None, feeds)
157+
self.assertEqualArray(expected, got[0])
158+
159+
def test_remove_identity_input(self):
160+
with contextlib.redirect_stdout(io.StringIO()):
161+
g = GraphBuilder(verbose=10)
162+
163+
shape = (10, 4)
164+
w = np.random.randn(*shape).astype(np.float32)
165+
x = g.make_tensor_input("X", np.float32, shape)
166+
x = g.op.Identity(x)
167+
weight = g.make_initializer(w)
168+
one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
169+
transposed = g.make_node("Transpose", [weight], perm=[1, 0])
170+
res = g.op.MatMul(x, transposed)
171+
g.op.Reshape(res, one, outputs="y")
172+
g.make_tensor_output("y", np.float32, (10, 1))
173+
174+
g.remove_identity_nodes()
175+
176+
onx = g.to_onnx()
177+
node_types = [n.op_type for n in onx.graph.node]
178+
self.assertNotIn("Identity", node_types)
179+
ref = ReferenceEvaluator(onx)
180+
x = np.random.randn(*shape).astype(np.float32)
181+
expected = (x @ w.T).reshape((-1, 1))
182+
feeds = {"X": x}
183+
got = ref.run(None, feeds)
184+
self.assertEqualArray(expected, got[0])
185+
186+
def test_remove_identity_output(self):
187+
with contextlib.redirect_stdout(io.StringIO()):
188+
g = GraphBuilder(verbose=10)
189+
190+
shape = (10, 4)
191+
w = np.random.randn(*shape).astype(np.float32)
192+
x = g.make_tensor_input("X", np.float32, shape)
193+
weight = g.make_initializer(w)
194+
one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
195+
transposed = g.make_node("Transpose", [weight], perm=[1, 0])
196+
res = g.op.MatMul(x, transposed)
197+
r = g.op.Reshape(res, one)
198+
g.op.Identity(r, outputs=["y"])
199+
g.make_tensor_output("y", np.float32, (10, 1))
200+
201+
g.remove_identity_nodes()
202+
203+
onx = g.to_onnx()
204+
node_types = [n.op_type for n in onx.graph.node]
205+
self.assertNotIn("Identity", node_types)
206+
ref = ReferenceEvaluator(onx)
207+
x = np.random.randn(*shape).astype(np.float32)
208+
expected = (x @ w.T).reshape((-1, 1))
209+
feeds = {"X": x}
210+
got = ref.run(None, feeds)
211+
self.assertEqualArray(expected, got[0])
212+
213+
def test_remove_unused_nodes_simple(self):
214+
with contextlib.redirect_stdout(io.StringIO()):
215+
g = GraphBuilder(verbose=10)
216+
217+
shape = (10, 4)
218+
w = np.random.randn(*shape).astype(np.float32)
219+
x = g.make_tensor_input("X", np.float32, shape)
220+
weight = g.make_initializer(w)
221+
cst = g.make_initializer(np.array([2], dtype=np.float32))
222+
one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
223+
transposed = g.make_node("Transpose", [weight], perm=[1, 0])
224+
res = g.op.MatMul(x, transposed)
225+
g.op.Add(res, cst)
226+
g.op.Reshape(res, one, outputs=["y"])
227+
g.make_tensor_output("y", np.float32, (10, 1))
228+
229+
g.remove_identity_nodes()
230+
231+
onx = g.to_onnx()
232+
node_types = [n.op_type for n in onx.graph.node]
233+
self.assertNotIn("Add", node_types)
234+
ref = ReferenceEvaluator(onx)
235+
x = np.random.randn(*shape).astype(np.float32)
236+
expected = (x @ w.T).reshape((-1, 1))
237+
feeds = {"X": x}
238+
got = ref.run(None, feeds)
239+
self.assertEqualArray(expected, got[0])
240+
241+
242+
if __name__ == "__main__":
243+
unittest.main(verbosity=2)

_unittests/ut_graph_api/test_graph_builder_optim.py

Lines changed: 2 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,56 +6,8 @@
66
from onnx_array_api.graph_api.graph_builder import GraphBuilder
77

88

9-
class TestGraphSimplification(ExtTestCase):
10-
def call_optimizer(self, onx):
11-
gr = GraphBuilder(onx)
12-
gr.remove_unused()
13-
return gr.to_onnx()
14-
15-
def test_remove_unused_nodes(self):
16-
model = onnx.parser.parse_model(
17-
"""
18-
<ir_version: 8, opset_import: [ "": 18]>
19-
agraph (float[N] x) => (float[N] z) {
20-
two = Constant <value_float=2.0> ()
21-
four = Add(two, two)
22-
z = Mul(x, x)
23-
}"""
24-
)
25-
onx = self.call_optimizer(model)
26-
self.assertEqual(len(onx.graph.node), 1)
27-
self.assertEqual(onx.graph.node[0].op_type, "Mul")
28-
29-
def test_initializers(self):
30-
model = onnx.parser.parse_model(
31-
"""
32-
<ir_version: 8, opset_import: [ "": 18]>
33-
agraph (float[N] x) => (float[N] z)
34-
<float two = {2.0}> {
35-
four = Add(two, two)
36-
z = Mul(x, x)
37-
}"""
38-
)
39-
self.assertEqual(len(model.graph.initializer), 1)
40-
onx = self.call_optimizer(model)
41-
self.assertEqual(len(onx.graph.node), 1)
42-
self.assertEqual(onx.graph.node[0].op_type, "Mul")
43-
self.assertEqual(len(onx.graph.initializer), 0)
44-
45-
def test_keep_unused_outputs(self):
46-
model = onnx.parser.parse_model(
47-
"""
48-
<ir_version: 8, opset_import: [ "": 18]>
49-
agraph (float[N] x) => (float[M] z) {
50-
w1, w2, w3 = Split (x)
51-
z = Mul(w3, w3)
52-
}"""
53-
)
54-
onx = self.call_optimizer(model)
55-
self.assertEqual(len(onx.graph.node), 2)
56-
self.assertEqual(onx.graph.node[0].op_type, "Split")
57-
58-
def test_check_afiles(self):
9+
class TestGraphBuilderOptim(ExtTestCase):
10+
def test_wcheck_afiles(self):
5911
import onnxruntime
6012

6113
data = os.path.join(os.path.dirname(__file__), "data")

0 commit comments

Comments
 (0)