Commit 718a0a5

Adds tools to compare models (#11)
* Adds tools to compare models
* update path
1 parent ac28cb9 commit 718a0a5

File tree: 16 files changed, +447 -8 lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
@@ -15,3 +15,7 @@ _doc/examples/plot_*.png
 _doc/_static/require.js
 _doc/_static/viz.js
 _unittests/ut__main/*.png
+_doc/examples/data/*.optimized.onnx
+_doc/examples/*.html
+_unittests/ut__main/_cache/*
+_unittests/ut__main/*.html

_doc/api/index.rst

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ API
     npx_jit
     npx_annot
     npx_numpy
+    onnx_tools
     ort
     plotting
     tools
-

_doc/api/onnx_tools.rst

Lines changed: 16 additions & 0 deletions
.. _l-api-onnx-tools:

onnx tools
==========

Differences
+++++++++++

.. autofunction:: onnx_array_api.validation.diff.html_diff

.. autofunction:: onnx_array_api.validation.diff.text_diff

Protos
++++++

.. autofunction:: onnx_array_api.validation.tools.randomize_proto
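A minimal usage sketch for the two comparison helpers documented above, based on the unit test and the example added in this commit (the two model file names are hypothetical placeholders; both functions also accept the text produced by onnx_simple_text_plot):

from onnx import load
from onnx_array_api.validation.diff import html_diff, text_diff

# hypothetical input files: any two ONNX models can be compared
with open("model_a.onnx", "rb") as f:
    model_a = load(f)
with open("model_b.onnx", "rb") as f:
    model_b = load(f)

# text_diff returns a string marking the lines that differ
print(text_diff(model_a, model_b))

# html_diff returns a complete HTML page with the same comparison
with open("diff.html", "w", encoding="utf-8") as f:
    f.write(html_diff(model_a, model_b))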

_doc/api/tools.rst

Lines changed: 5 additions & 2 deletions
@@ -8,6 +8,11 @@ Benchmark

 .. autofunction:: onnx_array_api.ext_test_case.measure_time

+Examples
+++++++++
+
+.. autofunction:: onnx_array_api.ext_test_case.example_path
+
 Profiling
 +++++++++

@@ -25,5 +30,3 @@ Unit tests

 .. autoclass:: onnx_array_api.ext_test_case.ExtTestCase
     :members:
-
-
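measure_time, documented under Benchmark, is the helper the new example relies on for its benchmark section; a short sketch of how it is called there (the measured callable below is a placeholder, the repeat/number values follow plot_optimization.py):

from onnx_array_api.ext_test_case import measure_time

def workload():
    # placeholder computation standing in for InferenceSession.run
    return sum(i * i for i in range(10_000))

stats = measure_time(lambda: workload(), repeat=25, number=25)
# measure_time returns a dictionary; the example reads its "average" and
# "deviation" entries to build the benchmark DataFrame and the bar plot
print(stats)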

_doc/examples/data/small.onnx

315 KB
Binary file not shown.

_doc/examples/plot_optimization.py

Lines changed: 121 additions & 0 deletions
"""

.. _l-onnx-array-onnxruntime-optimization:

Optimization with onnxruntime
=============================


Optimize a model with onnxruntime
+++++++++++++++++++++++++++++++++
"""
import os
from pprint import pprint
import numpy
from pandas import DataFrame
import matplotlib.pyplot as plt
from onnx import load
from onnx_array_api.ext_test_case import example_path
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from onnx_array_api.validation.diff import text_diff, html_diff
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
from onnx_array_api.ext_test_case import measure_time
from onnx_array_api.ort.ort_optimizers import ort_optimized_model


filename = example_path("data/small.onnx")
optimized = filename + ".optimized.onnx"

if not os.path.exists(optimized):
    ort_optimized_model(filename, output=optimized)
print(optimized)

#############################
# Output comparison
# +++++++++++++++++

so = SessionOptions()
so.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
img = numpy.random.random((1, 3, 112, 112)).astype(numpy.float32)

sess = InferenceSession(filename, so)
sess_opt = InferenceSession(optimized, so)
input_name = sess.get_inputs()[0].name
out = sess.run(None, {input_name: img})[0]
out_opt = sess_opt.run(None, {input_name: img})[0]
if out.shape != out_opt.shape:
    print(f"ERROR: shapes are different {out.shape} != {out_opt.shape}")
diff = numpy.abs(out - out_opt).max()
print(f"Differences: {diff}")

####################################
# Difference
# ++++++++++
#
# Unoptimized model.

with open(filename, "rb") as f:
    model = load(f)
print("first model to text...")
text1 = onnx_simple_text_plot(model, indent=False)
print(text1)

#####################################
# Optimized model.


with open(optimized, "rb") as f:
    model = load(f)
print("second model to text...")
text2 = onnx_simple_text_plot(model, indent=False)
print(text2)

########################################
# Differences

print("differences...")
print(text_diff(text1, text2))

#####################################
# HTML version.

print("html differences...")
output = html_diff(text1, text2)
with open("diff_html.html", "w", encoding="utf-8") as f:
    f.write(output)
print("done.")

#####################################
# Benchmark
# +++++++++

img = numpy.random.random((1, 3, 112, 112)).astype(numpy.float32)

t1 = measure_time(lambda: sess.run(None, {input_name: img}), repeat=25, number=25)
t1["name"] = "original"
print("Original model")
pprint(t1)

t2 = measure_time(lambda: sess_opt.run(None, {input_name: img}), repeat=25, number=25)
t2["name"] = "optimized"
print("Optimized")
pprint(t2)


############################
# Plots
# +++++


fig, ax = plt.subplots(1, 1, figsize=(12, 4))

df = DataFrame([t1, t2]).set_index("name")
print(df)

print(df["average"].values)
print((df["average"] - df["deviation"]).values)

ax.bar(df.index, df["average"].values, yerr=df["deviation"].values, capsize=6)
ax.set_title("Measure performance of optimized model\nlower is better")
plt.grid()
fig.savefig("plot_optimization.png")
315 KB
Binary file not shown.

_unittests/ut_validation/test_diff.py

Lines changed: 23 additions & 0 deletions
import unittest
from onnx import load
from onnx.checker import check_model
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.ort.ort_optimizers import ort_optimized_model
from onnx_array_api.validation.diff import text_diff, html_diff


class TestDiff(ExtTestCase):
    def test_diff_optimized(self):
        data = self.relative_path(__file__, "data", "small.onnx")
        with open(data, "rb") as f:
            model = load(f)
        optimized = ort_optimized_model(model)
        check_model(optimized)
        diff = text_diff(model, optimized)
        self.assertIn("^^^^^^^^^^^^^^^^", diff)
        ht = html_diff(model, optimized)
        self.assertIn("<html><body>", ht)


if __name__ == "__main__":
    unittest.main(verbosity=2)

Lines changed: 20 additions & 0 deletions
import unittest
from onnx import load
from onnx.checker import check_model
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.validation.tools import randomize_proto


class TestTools(ExtTestCase):
    def test_randomize_proto(self):
        data = self.relative_path(__file__, "data", "small.onnx")
        with open(data, "rb") as f:
            model = load(f)
        check_model(model)
        rnd = randomize_proto(model)
        self.assertEqual(len(model.SerializeToString()), len(rnd.SerializeToString()))
        check_model(rnd)


if __name__ == "__main__":
    unittest.main(verbosity=2)

onnx_array_api/ext_test_case.py

Lines changed: 30 additions & 0 deletions
@@ -1,3 +1,4 @@
+import os
 import sys
 import unittest
 import warnings
@@ -30,6 +31,20 @@ def call_f(self):
     return wrapper


+def example_path(path: str) -> str:
+    """
+    Fixes a path for the examples.
+    Helps running the example within a unit test.
+    """
+    if os.path.exists(path):
+        return path
+    this = os.path.abspath(os.path.dirname(__file__))
+    full = os.path.join(this, "..", "_doc", "examples", path)
+    if os.path.exists(full):
+        return full
+    raise FileNotFoundError(f"Unable to find path {path!r} or {full!r}.")
+
+
 def measure_time(
     stmt: Callable,
     context: Optional[Dict[str, Any]] = None,
@@ -207,3 +222,18 @@ def capture(self, fct: Callable):
             with redirect_stderr(serr):
                 res = fct()
         return res, sout.getvalue(), serr.getvalue()
+
+    def relative_path(self, filename: str, *names: List[str]) -> str:
+        """
+        Returns a path relative to the folder *filename*
+        is in. The function checks the path existence.
+
+        :param filename: filename
+        :param names: additional path pieces
+        :return: new path
+        """
+        dir = os.path.abspath(os.path.dirname(filename))
+        name = os.path.join(dir, *names)
+        if not os.path.exists(name):
+            raise FileNotFoundError(f"Path {name!r} does not exist.")
+        return name
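A short sketch of how the two new helpers are meant to be used, mirroring the example and the unit tests in this commit (the test class below is illustrative only):

from onnx_array_api.ext_test_case import ExtTestCase, example_path

# returns "data/small.onnx" as-is when it exists in the current folder,
# otherwise resolves it under _doc/examples/ (raises FileNotFoundError if absent)
filename = example_path("data/small.onnx")


class DummyTest(ExtTestCase):
    def test_data_file(self):
        # resolves a path next to this test file and checks it exists
        data = self.relative_path(__file__, "data", "small.onnx")
        self.assertTrue(data.endswith("small.onnx"))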

onnx_array_api/ort/ort_optimizers.py

Lines changed: 12 additions & 4 deletions
@@ -1,12 +1,14 @@
-from typing import Union
+from typing import Union, Optional
 from onnx import ModelProto, load
 from onnxruntime import InferenceSession, SessionOptions
 from onnxruntime.capi._pybind_state import GraphOptimizationLevel
 from ..cache import get_cache_file


 def ort_optimized_model(
-    onx: Union[str, ModelProto], level: str = "ORT_ENABLE_ALL"
+    onx: Union[str, ModelProto],
+    level: str = "ORT_ENABLE_ALL",
+    output: Optional[str] = None,
 ) -> Union[str, ModelProto]:
     """
     Returns the optimized model used by onnxruntime before
@@ -15,6 +17,7 @@ def ort_optimized_model(
     :param onx: ModelProto
     :param level: optimization level, `'ORT_ENABLE_BASIC'`,
         `'ORT_ENABLE_EXTENDED'`, `'ORT_ENABLE_ALL'`
+    :param output: output file if the proposed cache is not wanted
     :return: optimized model
     """
     glevel = getattr(GraphOptimizationLevel, level, None)
@@ -23,13 +26,18 @@ def ort_optimized_model(
             f"Unrecognized level {level!r} among {dir(GraphOptimizationLevel)}."
         )

-    cache = get_cache_file("ort_optimized_model.onnx", remove=True)
+    if output is not None:
+        cache = output
+    else:
+        cache = get_cache_file("ort_optimized_model.onnx", remove=True)
     so = SessionOptions()
     so.graph_optimization_level = glevel
     so.optimized_model_filepath = str(cache)
     InferenceSession(onx if isinstance(onx, str) else onx.SerializeToString(), so)
-    if not cache.exists():
+    if output is None and not cache.exists():
         raise RuntimeError(f"The optimized model {str(cache)!r} not found.")
+    if output is not None:
+        return output
     if isinstance(onx, str):
         return str(cache)
     opt_onx = load(str(cache))
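With the new output argument, the optimized graph is written next to the original file instead of going through the package cache, which is how plot_optimization.py uses it; a minimal sketch with the example's file names:

from onnx_array_api.ort.ort_optimizers import ort_optimized_model

filename = "data/small.onnx"
optimized = filename + ".optimized.onnx"

# writes the onnxruntime-optimized model to `optimized` and returns that path
path = ort_optimized_model(filename, level="ORT_ENABLE_ALL", output=optimized)
print(path)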
