Skip to content

Adds tools to compare models #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ _doc/examples/plot_*.png
_doc/_static/require.js
_doc/_static/viz.js
_unittests/ut__main/*.png
_doc/examples/data/*.optimized.onnx
_doc/examples/*.html
_unittests/ut__main/_cache/*
_unittests/ut__main/*.html
2 changes: 1 addition & 1 deletion _doc/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ API
npx_jit
npx_annot
npx_numpy
onnx_tools
ort
plotting
tools

16 changes: 16 additions & 0 deletions _doc/api/onnx_tools.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
.. _l-api-onnx-tools:

onnx tools
==========

Differences
+++++++++++

.. autofunction:: onnx_array_api.validation.diff.html_diff

.. autofunction:: onnx_array_api.validation.diff.text_diff

Protos
++++++

.. autofunction:: onnx_array_api.validation.tools.randomize_proto
7 changes: 5 additions & 2 deletions _doc/api/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ Benchmark

.. autofunction:: onnx_array_api.ext_test_case.measure_time

Examples
++++++++

.. autofunction:: onnx_array_api.ext_test_case.example_path

Profiling
+++++++++

Expand All @@ -25,5 +30,3 @@ Unit tests

.. autoclass:: onnx_array_api.ext_test_case.ExtTestCase
:members:


Binary file added _doc/examples/data/small.onnx
Binary file not shown.
121 changes: 121 additions & 0 deletions _doc/examples/plot_optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""

.. _l-onnx-array-onnxruntime-optimization:

Optimization with onnxruntime
=============================


Optimize a model with onnxruntime
+++++++++++++++++++++++++++++++++
"""
import os
from pprint import pprint
import numpy
from pandas import DataFrame
import matplotlib.pyplot as plt
from onnx import load
from onnx_array_api.ext_test_case import example_path
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from onnx_array_api.validation.diff import text_diff, html_diff
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
from onnx_array_api.ext_test_case import measure_time
from onnx_array_api.ort.ort_optimizers import ort_optimized_model


filename = example_path("data/small.onnx")
optimized = filename + ".optimized.onnx"

if not os.path.exists(optimized):
ort_optimized_model(filename, output=optimized)
print(optimized)

#############################
# Output comparison
# +++++++++++++++++

so = SessionOptions()
so.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
img = numpy.random.random((1, 3, 112, 112)).astype(numpy.float32)

sess = InferenceSession(filename, so)
sess_opt = InferenceSession(optimized, so)
input_name = sess.get_inputs()[0].name
out = sess.run(None, {input_name: img})[0]
out_opt = sess_opt.run(None, {input_name: img})[0]
if out.shape != out_opt.shape:
print("ERROR shape are different {out.shape} != {out_opt.shape}")
diff = numpy.abs(out - out_opt).max()
print(f"Differences: {diff}")

####################################
# Difference
# ++++++++++
#
# Unoptimized model.

with open(filename, "rb") as f:
model = load(f)
print("first model to text...")
text1 = onnx_simple_text_plot(model, indent=False)
print(text1)

#####################################
# Optimized model.


with open(optimized, "rb") as f:
model = load(f)
print("second model to text...")
text2 = onnx_simple_text_plot(model, indent=False)
print(text2)

########################################
# Differences

print("differences...")
print(text_diff(text1, text2))

#####################################
# HTML version.

print("html differences...")
output = html_diff(text1, text2)
with open("diff_html.html", "w", encoding="utf-8") as f:
f.write(output)
print("done.")

#####################################
# Benchmark
# +++++++++

img = numpy.random.random((1, 3, 112, 112)).astype(numpy.float32)

t1 = measure_time(lambda: sess.run(None, {input_name: img}), repeat=25, number=25)
t1["name"] = "original"
print("Original model")
pprint(t1)

t2 = measure_time(lambda: sess_opt.run(None, {input_name: img}), repeat=25, number=25)
t2["name"] = "optimized"
print("Optimized")
pprint(t2)


############################
# Plots
# +++++


fig, ax = plt.subplots(1, 1, figsize=(12, 4))

df = DataFrame([t1, t2]).set_index("name")
print(df)

print(df["average"].values)
print((df["average"] - df["deviation"]).values)

ax.bar(df.index, df["average"].values, yerr=df["deviation"].values, capsize=6)
ax.set_title("Measure performance of optimized model\nlower is better")
plt.grid()
fig.savefig("plot_optimization.png")
Binary file added _unittests/ut_validation/data/small.onnx
Binary file not shown.
23 changes: 23 additions & 0 deletions _unittests/ut_validation/test_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import unittest
from onnx import load
from onnx.checker import check_model
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.ort.ort_optimizers import ort_optimized_model
from onnx_array_api.validation.diff import text_diff, html_diff


class TestDiff(ExtTestCase):
def test_diff_optimized(self):
data = self.relative_path(__file__, "data", "small.onnx")
with open(data, "rb") as f:
model = load(f)
optimized = ort_optimized_model(model)
check_model(optimized)
diff = text_diff(model, optimized)
self.assertIn("^^^^^^^^^^^^^^^^", diff)
ht = html_diff(model, optimized)
self.assertIn("<html><body>", ht)


if __name__ == "__main__":
unittest.main(verbosity=2)
20 changes: 20 additions & 0 deletions _unittests/ut_validation/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import unittest
from onnx import load
from onnx.checker import check_model
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.validation.tools import randomize_proto


class TestTools(ExtTestCase):
def test_randomize_proto(self):
data = self.relative_path(__file__, "data", "small.onnx")
with open(data, "rb") as f:
model = load(f)
check_model(model)
rnd = randomize_proto(model)
self.assertEqual(len(model.SerializeToString()), len(rnd.SerializeToString()))
check_model(rnd)


if __name__ == "__main__":
unittest.main(verbosity=2)
30 changes: 30 additions & 0 deletions onnx_array_api/ext_test_case.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import sys
import unittest
import warnings
Expand Down Expand Up @@ -30,6 +31,20 @@ def call_f(self):
return wrapper


def example_path(path: str) -> str:
"""
Fixes a path for the examples.
Helps running the example within a unit test.
"""
if os.path.exists(path):
return path
this = os.path.abspath(os.path.dirname(__file__))
full = os.path.join(this, "..", "_doc", "examples", path)
if os.path.exists(full):
return full
raise FileNotFoundError(f"Unable to find path {path!r} or {full!r}.")


def measure_time(
stmt: Callable,
context: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -207,3 +222,18 @@ def capture(self, fct: Callable):
with redirect_stderr(serr):
res = fct()
return res, sout.getvalue(), serr.getvalue()

def relative_path(self, filename: str, *names: List[str]) -> str:
"""
Returns a path relative to the folder *filename*
is in. The function checks the path existence.

:param filename: filename
:param names: additional path pieces
:return: new path
"""
dir = os.path.abspath(os.path.dirname(filename))
name = os.path.join(dir, *names)
if not os.path.exists(name):
raise FileNotFoundError(f"Path {name!r} does not exists.")
return name
16 changes: 12 additions & 4 deletions onnx_array_api/ort/ort_optimizers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Union
from typing import Union, Optional
from onnx import ModelProto, load
from onnxruntime import InferenceSession, SessionOptions
from onnxruntime.capi._pybind_state import GraphOptimizationLevel
from ..cache import get_cache_file


def ort_optimized_model(
onx: Union[str, ModelProto], level: str = "ORT_ENABLE_ALL"
onx: Union[str, ModelProto],
level: str = "ORT_ENABLE_ALL",
output: Optional[str] = None,
) -> Union[str, ModelProto]:
"""
Returns the optimized model used by onnxruntime before
Expand All @@ -15,6 +17,7 @@ def ort_optimized_model(
:param onx: ModelProto
:param level: optimization level, `'ORT_ENABLE_BASIC'`,
`'ORT_ENABLE_EXTENDED'`, `'ORT_ENABLE_ALL'`
:param output: output file if the proposed cache is not wanted
:return: optimized model
"""
glevel = getattr(GraphOptimizationLevel, level, None)
Expand All @@ -23,13 +26,18 @@ def ort_optimized_model(
f"Unrecognized level {level!r} among {dir(GraphOptimizationLevel)}."
)

cache = get_cache_file("ort_optimized_model.onnx", remove=True)
if output is not None:
cache = output
else:
cache = get_cache_file("ort_optimized_model.onnx", remove=True)
so = SessionOptions()
so.graph_optimization_level = glevel
so.optimized_model_filepath = str(cache)
InferenceSession(onx if isinstance(onx, str) else onx.SerializeToString(), so)
if not cache.exists():
if output is None and not cache.exists():
raise RuntimeError(f"The optimized model {str(cache)!r} not found.")
if output is not None:
return output
if isinstance(onx, str):
return str(cache)
opt_onx = load(str(cache))
Expand Down
Loading