Skip to content

Adds function to plot onnx model as graphs #61

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ build/*
*egg-info/*
onnxruntime_profile*
prof
test*.png
_doc/sg_execution_times.rst
_doc/auto_examples/*
_doc/examples/_cache/*
Expand Down
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.2.0
+++++

* :pr:`61`: adds function to plot onnx model as graphs
* :pr:`60`: supports translation of local functions
* :pr:`59`: add methods to update nodes in GraphAPI

Expand Down
51 changes: 51 additions & 0 deletions _unittests/ut_plotting/test_graphviz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
import unittest
import onnx.parser
from onnx_array_api.ext_test_case import (
ExtTestCase,
skipif_ci_windows,
skipif_ci_apple,
)
from onnx_array_api.plotting.dot_plot import to_dot
from onnx_array_api.plotting.graphviz_helper import draw_graph_graphviz, plot_dot


class TestGraphviz(ExtTestCase):
@classmethod
def _get_graph(cls):
return onnx.parser.parse_model(
"""
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
z = Mul(x, x)
}"""
)

@skipif_ci_windows("graphviz not installed")
@skipif_ci_apple("graphviz not installed")
def test_draw_graph_graphviz(self):
fout = "test_draw_graph_graphviz.png"
dot = to_dot(self._get_graph())
draw_graph_graphviz(dot, image=fout)
self.assertExists(os.path.exists(fout))

@skipif_ci_windows("graphviz not installed")
@skipif_ci_apple("graphviz not installed")
def test_draw_graph_graphviz_proto(self):
fout = "test_draw_graph_graphviz_proto.png"
dot = self._get_graph()
draw_graph_graphviz(dot, image=fout)
self.assertExists(os.path.exists(fout))

@skipif_ci_windows("graphviz not installed")
@skipif_ci_apple("graphviz not installed")
def test_plot_dot(self):
dot = to_dot(self._get_graph())
ax = plot_dot(dot)
ax.get_figure().savefig("test_plot_dot.png")


if __name__ == "__main__":
unittest.main(verbosity=2)
18 changes: 18 additions & 0 deletions onnx_array_api/ext_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def is_windows() -> bool:
return sys.platform == "win32"


def is_apple() -> bool:
return sys.platform == "darwin"


def skipif_ci_windows(msg) -> Callable:
"""
Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.
Expand All @@ -29,6 +33,16 @@ def skipif_ci_windows(msg) -> Callable:
return lambda x: x


def skipif_ci_apple(msg) -> Callable:
"""
Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.
"""
if is_apple() and is_azure():
msg = f"Test does not work on azure pipeline (Apple). {msg}"
return unittest.skip(msg)
return lambda x: x


def ignore_warnings(warns: List[Warning]) -> Callable:
"""
Catches warnings.
Expand Down Expand Up @@ -230,6 +244,10 @@ def assertEmpty(self, value: Any):
return
raise AssertionError(f"value is not empty: {value!r}.")

def assertExists(self, name):
if not os.path.exists(name):
raise AssertionError(f"File or folder {name!r} does not exists.")

def assertHasAttr(self, cls: type, name: str):
if not hasattr(cls, name):
raise AssertionError(f"Class {cls} has no attribute {name!r}.")
Expand Down
236 changes: 236 additions & 0 deletions onnx_array_api/plotting/graphviz_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
import os
import subprocess
import sys
import tempfile
from typing import List, Optional, Tuple, Union
import numpy as np
from onnx import ModelProto


def _find_in_PATH(prog: str) -> Optional[str]:
"""
Looks into every path mentioned in ``%PATH%`` a specific file,
it raises an exception if not found.

:param prog: program to look for
:return: path
"""
sep = ";" if sys.platform.startswith("win") else ":"
path = os.environ["PATH"]
for p in path.split(sep):
f = os.path.join(p, prog)
if os.path.exists(f):
return p
return None


def _find_graphviz_dot(exc: bool = True) -> str:
"""
Determines the path to graphviz (on Windows),
the function tests the existence of versions 34 to 45
assuming it was installed in a standard folder:
``C:\\Program Files\\MiKTeX 2.9\\miktex\\bin\\x64``.

:param exc: raise exception of be silent
:return: path to dot
:raises FileNotFoundError: if graphviz not found
"""
if sys.platform.startswith("win"):
version = list(range(34, 60))
version.extend([f"{v}.1" for v in version])
for v in version:
graphviz_dot = f"C:\\Program Files (x86)\\Graphviz2.{v}\\bin\\dot.exe"
if os.path.exists(graphviz_dot):
return graphviz_dot
extra = ["build/update_modules/Graphviz/bin"]
for ext in extra:
graphviz_dot = os.path.join(ext, "dot.exe")
if os.path.exists(graphviz_dot):
return graphviz_dot
p = _find_in_PATH("dot.exe")
if p is None:
if exc:
raise FileNotFoundError(
f"Unable to find graphviz, look into paths such as {graphviz_dot}."
)
return None
return os.path.join(p, "dot.exe")
# linux
return "dot"


def _run_subprocess(
args: List[str],
cwd: Optional[str] = None,
):
assert not isinstance(
args, str
), "args should be a sequence of strings, not a string."

p = subprocess.Popen(
args,
cwd=cwd,
shell=False,
env=os.environ,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
raise_exception = False
output = ""
while True:
output = p.stdout.readline().decode(errors="ignore")
if output == "" and p.poll() is not None:
break
if output:
if (
"fatal error" in output
or "CMake Error" in output
or "gmake: ***" in output
or "): error C" in output
or ": error: " in output
):
raise_exception = True
p.poll()
p.stdout.close()
if raise_exception:
raise RuntimeError(
"An error was found in the output. The build is stopped.\n{output}"
)
return output


def _run_graphviz(filename: str, image: str, engine: str = "dot") -> str:
"""
Run :epkg:`Graphviz`.

:param filename: filename which contains the graph definition
:param image: output image
:param engine: *dot* or *neato*
:return: output of graphviz
"""
ext = os.path.splitext(image)[-1]
assert ext in {
".png",
".bmp",
".fig",
".gif",
".ico",
".jpg",
".jpeg",
".pdf",
".ps",
".svg",
".vrml",
".tif",
".tiff",
".wbmp",
}, f"Unexpected extension {ext!r} for {image!r}."
if sys.platform.startswith("win"):
bin_ = os.path.dirname(_find_graphviz_dot())
# if bin not in os.environ["PATH"]:
# os.environ["PATH"] = os.environ["PATH"] + ";" + bin
exe = os.path.join(bin_, engine)
else:
exe = engine
if os.path.exists(image):
os.remove(image)
output = _run_subprocess([exe, f"-T{ext[1:]}", filename, "-o", image])
assert os.path.exists(image), f"Graphviz failed due to {output}"
return output


def draw_graph_graphviz(
dot: Union[str, ModelProto],
image: str,
engine: str = "dot",
) -> str:
"""
Draws a graph using :epkg:`Graphviz`.

:param dot: dot graph or ModelProto
:param image: output image, None, just returns the output
:param engine: *dot* or *neato*
:return: :epkg:`Graphviz` output or
the dot text if *image* is None

The function creates a temporary file to store the dot file if *image* is not None.
"""
if isinstance(dot, ModelProto):
from .dot_plot import to_dot

sdot = to_dot(dot)
else:
sdot = dot
with tempfile.NamedTemporaryFile(delete=False) as fp:
fp.write(sdot.encode("utf-8"))
fp.close()

filename = fp.name
assert os.path.exists(
filename
), f"File {filename!r} cannot be created to store the graph."
out = _run_graphviz(filename, image, engine=engine)
assert os.path.exists(
image
), f"Graphviz failed with no reason, {image!r} not found, output is {out}."
os.remove(filename)
return out


def plot_dot(
dot: Union[str, ModelProto],
ax: Optional["matplotlib.axis.Axis"] = None, # noqa: F821
engine: str = "dot",
figsize: Optional[Tuple[int, int]] = None,
) -> "matplotlib.axis.Axis": # noqa: F821
"""
Draws a dot graph into a matplotlib graph.

:param dot: dot graph or ModelProto
:param image: output image, None, just returns the output
:param engine: *dot* or *neato*
:param figsize: figsize of ax is None
:return: :epkg:`Graphviz` output or
the dot text if *image* is None

.. plot::

import matplotlib.pyplot as plt
import onnx.parser

model = onnx.parser.parse_model(
'''
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
z = Mul(four, four)
}''')
ax = plot_dot(dot)
ax.set_title("Dummy graph")
plt.show()
"""
if ax is None:
import matplotlib.pyplot as plt

_, ax = plt.subplots(1, 1, figsize=figsize)
clean = True
else:
clean = False

from PIL import Image

with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as fp:
fp.close()

draw_graph_graphviz(dot, fp.name, engine=engine)
img = np.asarray(Image.open(fp.name))
os.remove(fp.name)

ax.imshow(img)

if clean:
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax.get_figure().tight_layout()
return ax