Skip to content

Commit ebafa26

Browse files
authored
Adds function to plot onnx model as graphs (#61)
* Add methods to draw onnx plots * improve versatility * doc * disable test when graphviz not installed * documentation * add missing function
1 parent 7895c27 commit ebafa26

File tree

5 files changed

+307
-0
lines changed

5 files changed

+307
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ build/*
1414
*egg-info/*
1515
onnxruntime_profile*
1616
prof
17+
test*.png
1718
_doc/sg_execution_times.rst
1819
_doc/auto_examples/*
1920
_doc/examples/_cache/*

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
* :pr:`61`: adds function to plot onnx model as graphs
78
* :pr:`60`: supports translation of local functions
89
* :pr:`59`: add methods to update nodes in GraphAPI
910

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import os
2+
import unittest
3+
import onnx.parser
4+
from onnx_array_api.ext_test_case import (
5+
ExtTestCase,
6+
skipif_ci_windows,
7+
skipif_ci_apple,
8+
)
9+
from onnx_array_api.plotting.dot_plot import to_dot
10+
from onnx_array_api.plotting.graphviz_helper import draw_graph_graphviz, plot_dot
11+
12+
13+
class TestGraphviz(ExtTestCase):
14+
@classmethod
15+
def _get_graph(cls):
16+
return onnx.parser.parse_model(
17+
"""
18+
<ir_version: 8, opset_import: [ "": 18]>
19+
agraph (float[N] x) => (float[N] z) {
20+
two = Constant <value_float=2.0> ()
21+
four = Add(two, two)
22+
z = Mul(x, x)
23+
}"""
24+
)
25+
26+
@skipif_ci_windows("graphviz not installed")
27+
@skipif_ci_apple("graphviz not installed")
28+
def test_draw_graph_graphviz(self):
29+
fout = "test_draw_graph_graphviz.png"
30+
dot = to_dot(self._get_graph())
31+
draw_graph_graphviz(dot, image=fout)
32+
self.assertExists(os.path.exists(fout))
33+
34+
@skipif_ci_windows("graphviz not installed")
35+
@skipif_ci_apple("graphviz not installed")
36+
def test_draw_graph_graphviz_proto(self):
37+
fout = "test_draw_graph_graphviz_proto.png"
38+
dot = self._get_graph()
39+
draw_graph_graphviz(dot, image=fout)
40+
self.assertExists(os.path.exists(fout))
41+
42+
@skipif_ci_windows("graphviz not installed")
43+
@skipif_ci_apple("graphviz not installed")
44+
def test_plot_dot(self):
45+
dot = to_dot(self._get_graph())
46+
ax = plot_dot(dot)
47+
ax.get_figure().savefig("test_plot_dot.png")
48+
49+
50+
if __name__ == "__main__":
51+
unittest.main(verbosity=2)

onnx_array_api/ext_test_case.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ def is_windows() -> bool:
1919
return sys.platform == "win32"
2020

2121

22+
def is_apple() -> bool:
23+
return sys.platform == "darwin"
24+
25+
2226
def skipif_ci_windows(msg) -> Callable:
2327
"""
2428
Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.
@@ -29,6 +33,16 @@ def skipif_ci_windows(msg) -> Callable:
2933
return lambda x: x
3034

3135

36+
def skipif_ci_apple(msg) -> Callable:
37+
"""
38+
Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.
39+
"""
40+
if is_apple() and is_azure():
41+
msg = f"Test does not work on azure pipeline (Apple). {msg}"
42+
return unittest.skip(msg)
43+
return lambda x: x
44+
45+
3246
def ignore_warnings(warns: List[Warning]) -> Callable:
3347
"""
3448
Catches warnings.
@@ -230,6 +244,10 @@ def assertEmpty(self, value: Any):
230244
return
231245
raise AssertionError(f"value is not empty: {value!r}.")
232246

247+
def assertExists(self, name):
248+
if not os.path.exists(name):
249+
raise AssertionError(f"File or folder {name!r} does not exists.")
250+
233251
def assertHasAttr(self, cls: type, name: str):
234252
if not hasattr(cls, name):
235253
raise AssertionError(f"Class {cls} has no attribute {name!r}.")
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
import os
2+
import subprocess
3+
import sys
4+
import tempfile
5+
from typing import List, Optional, Tuple, Union
6+
import numpy as np
7+
from onnx import ModelProto
8+
9+
10+
def _find_in_PATH(prog: str) -> Optional[str]:
11+
"""
12+
Looks into every path mentioned in ``%PATH%`` a specific file,
13+
it raises an exception if not found.
14+
15+
:param prog: program to look for
16+
:return: path
17+
"""
18+
sep = ";" if sys.platform.startswith("win") else ":"
19+
path = os.environ["PATH"]
20+
for p in path.split(sep):
21+
f = os.path.join(p, prog)
22+
if os.path.exists(f):
23+
return p
24+
return None
25+
26+
27+
def _find_graphviz_dot(exc: bool = True) -> str:
28+
"""
29+
Determines the path to graphviz (on Windows),
30+
the function tests the existence of versions 34 to 45
31+
assuming it was installed in a standard folder:
32+
``C:\\Program Files\\MiKTeX 2.9\\miktex\\bin\\x64``.
33+
34+
:param exc: raise exception of be silent
35+
:return: path to dot
36+
:raises FileNotFoundError: if graphviz not found
37+
"""
38+
if sys.platform.startswith("win"):
39+
version = list(range(34, 60))
40+
version.extend([f"{v}.1" for v in version])
41+
for v in version:
42+
graphviz_dot = f"C:\\Program Files (x86)\\Graphviz2.{v}\\bin\\dot.exe"
43+
if os.path.exists(graphviz_dot):
44+
return graphviz_dot
45+
extra = ["build/update_modules/Graphviz/bin"]
46+
for ext in extra:
47+
graphviz_dot = os.path.join(ext, "dot.exe")
48+
if os.path.exists(graphviz_dot):
49+
return graphviz_dot
50+
p = _find_in_PATH("dot.exe")
51+
if p is None:
52+
if exc:
53+
raise FileNotFoundError(
54+
f"Unable to find graphviz, look into paths such as {graphviz_dot}."
55+
)
56+
return None
57+
return os.path.join(p, "dot.exe")
58+
# linux
59+
return "dot"
60+
61+
62+
def _run_subprocess(
63+
args: List[str],
64+
cwd: Optional[str] = None,
65+
):
66+
assert not isinstance(
67+
args, str
68+
), "args should be a sequence of strings, not a string."
69+
70+
p = subprocess.Popen(
71+
args,
72+
cwd=cwd,
73+
shell=False,
74+
env=os.environ,
75+
stdout=subprocess.PIPE,
76+
stderr=subprocess.STDOUT,
77+
)
78+
raise_exception = False
79+
output = ""
80+
while True:
81+
output = p.stdout.readline().decode(errors="ignore")
82+
if output == "" and p.poll() is not None:
83+
break
84+
if output:
85+
if (
86+
"fatal error" in output
87+
or "CMake Error" in output
88+
or "gmake: ***" in output
89+
or "): error C" in output
90+
or ": error: " in output
91+
):
92+
raise_exception = True
93+
p.poll()
94+
p.stdout.close()
95+
if raise_exception:
96+
raise RuntimeError(
97+
"An error was found in the output. The build is stopped.\n{output}"
98+
)
99+
return output
100+
101+
102+
def _run_graphviz(filename: str, image: str, engine: str = "dot") -> str:
103+
"""
104+
Run :epkg:`Graphviz`.
105+
106+
:param filename: filename which contains the graph definition
107+
:param image: output image
108+
:param engine: *dot* or *neato*
109+
:return: output of graphviz
110+
"""
111+
ext = os.path.splitext(image)[-1]
112+
assert ext in {
113+
".png",
114+
".bmp",
115+
".fig",
116+
".gif",
117+
".ico",
118+
".jpg",
119+
".jpeg",
120+
".pdf",
121+
".ps",
122+
".svg",
123+
".vrml",
124+
".tif",
125+
".tiff",
126+
".wbmp",
127+
}, f"Unexpected extension {ext!r} for {image!r}."
128+
if sys.platform.startswith("win"):
129+
bin_ = os.path.dirname(_find_graphviz_dot())
130+
# if bin not in os.environ["PATH"]:
131+
# os.environ["PATH"] = os.environ["PATH"] + ";" + bin
132+
exe = os.path.join(bin_, engine)
133+
else:
134+
exe = engine
135+
if os.path.exists(image):
136+
os.remove(image)
137+
output = _run_subprocess([exe, f"-T{ext[1:]}", filename, "-o", image])
138+
assert os.path.exists(image), f"Graphviz failed due to {output}"
139+
return output
140+
141+
142+
def draw_graph_graphviz(
143+
dot: Union[str, ModelProto],
144+
image: str,
145+
engine: str = "dot",
146+
) -> str:
147+
"""
148+
Draws a graph using :epkg:`Graphviz`.
149+
150+
:param dot: dot graph or ModelProto
151+
:param image: output image, None, just returns the output
152+
:param engine: *dot* or *neato*
153+
:return: :epkg:`Graphviz` output or
154+
the dot text if *image* is None
155+
156+
The function creates a temporary file to store the dot file if *image* is not None.
157+
"""
158+
if isinstance(dot, ModelProto):
159+
from .dot_plot import to_dot
160+
161+
sdot = to_dot(dot)
162+
else:
163+
sdot = dot
164+
with tempfile.NamedTemporaryFile(delete=False) as fp:
165+
fp.write(sdot.encode("utf-8"))
166+
fp.close()
167+
168+
filename = fp.name
169+
assert os.path.exists(
170+
filename
171+
), f"File {filename!r} cannot be created to store the graph."
172+
out = _run_graphviz(filename, image, engine=engine)
173+
assert os.path.exists(
174+
image
175+
), f"Graphviz failed with no reason, {image!r} not found, output is {out}."
176+
os.remove(filename)
177+
return out
178+
179+
180+
def plot_dot(
181+
dot: Union[str, ModelProto],
182+
ax: Optional["matplotlib.axis.Axis"] = None, # noqa: F821
183+
engine: str = "dot",
184+
figsize: Optional[Tuple[int, int]] = None,
185+
) -> "matplotlib.axis.Axis": # noqa: F821
186+
"""
187+
Draws a dot graph into a matplotlib graph.
188+
189+
:param dot: dot graph or ModelProto
190+
:param image: output image, None, just returns the output
191+
:param engine: *dot* or *neato*
192+
:param figsize: figsize of ax is None
193+
:return: :epkg:`Graphviz` output or
194+
the dot text if *image* is None
195+
196+
.. plot::
197+
198+
import matplotlib.pyplot as plt
199+
import onnx.parser
200+
201+
model = onnx.parser.parse_model(
202+
'''
203+
<ir_version: 8, opset_import: [ "": 18]>
204+
agraph (float[N] x) => (float[N] z) {
205+
two = Constant <value_float=2.0> ()
206+
four = Add(two, two)
207+
z = Mul(four, four)
208+
}''')
209+
ax = plot_dot(dot)
210+
ax.set_title("Dummy graph")
211+
plt.show()
212+
"""
213+
if ax is None:
214+
import matplotlib.pyplot as plt
215+
216+
_, ax = plt.subplots(1, 1, figsize=figsize)
217+
clean = True
218+
else:
219+
clean = False
220+
221+
from PIL import Image
222+
223+
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as fp:
224+
fp.close()
225+
226+
draw_graph_graphviz(dot, fp.name, engine=engine)
227+
img = np.asarray(Image.open(fp.name))
228+
os.remove(fp.name)
229+
230+
ax.imshow(img)
231+
232+
if clean:
233+
ax.get_xaxis().set_visible(False)
234+
ax.get_yaxis().set_visible(False)
235+
ax.get_figure().tight_layout()
236+
return ax

0 commit comments

Comments
 (0)