From 470f8451349a3da823665a8d001c6c9915770d21 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 12:34:53 +0100 Subject: [PATCH 1/6] Add methods to draw onnx plots --- _unittests/ut_plotting/test_graphviz.py | 35 ++++ onnx_array_api/ext_test_case.py | 4 + onnx_array_api/plotting/graphviz_helper.py | 213 +++++++++++++++++++++ 3 files changed, 252 insertions(+) create mode 100644 _unittests/ut_plotting/test_graphviz.py create mode 100644 onnx_array_api/plotting/graphviz_helper.py diff --git a/_unittests/ut_plotting/test_graphviz.py b/_unittests/ut_plotting/test_graphviz.py new file mode 100644 index 0000000..3a61b59 --- /dev/null +++ b/_unittests/ut_plotting/test_graphviz.py @@ -0,0 +1,35 @@ +import os +import unittest +import onnx.parser +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.plotting.dot_plot import to_dot +from onnx_array_api.plotting.graphviz_helper import draw_graph_graphviz, plot_dot + + +class TestGraphviz(ExtTestCase): + @classmethod + def _get_graph(cls): + return onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + four = Add(two, two) + z = Mul(x, x) + }""" + ) + + def test_draw_graph_graphviz(self): + fout = "test_draw_graph_graphviz.png" + dot = to_dot(self._get_graph()) + draw_graph_graphviz(dot, image=fout) + self.assertExists(os.path.exists(fout)) + + def test_plot_dot(self): + dot = to_dot(self._get_graph()) + ax = plot_dot(dot) + ax.get_figure().savefig("test_plot_dot.png") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_array_api/ext_test_case.py b/onnx_array_api/ext_test_case.py index 1068bda..7555cb5 100644 --- a/onnx_array_api/ext_test_case.py +++ b/onnx_array_api/ext_test_case.py @@ -230,6 +230,10 @@ def assertEmpty(self, value: Any): return raise AssertionError(f"value is not empty: {value!r}.") + def assertExists(self, name): + if not os.path.exists(name): + raise AssertionError(f"File or folder {name!r} does not exists.") + def assertHasAttr(self, cls: type, name: str): if not hasattr(cls, name): raise AssertionError(f"Class {cls} has no attribute {name!r}.") diff --git a/onnx_array_api/plotting/graphviz_helper.py b/onnx_array_api/plotting/graphviz_helper.py new file mode 100644 index 0000000..813b694 --- /dev/null +++ b/onnx_array_api/plotting/graphviz_helper.py @@ -0,0 +1,213 @@ +import os +import subprocess +import sys +import tempfile +from typing import List, Optional, Tuple +import numpy as np + + +def _find_in_PATH(prog: str) -> Optional[str]: + """ + Looks into every path mentioned in ``%PATH%`` a specific file, + it raises an exception if not found. + + :param prog: program to look for + :return: path + """ + sep = ";" if sys.platform.startswith("win") else ":" + path = os.environ["PATH"] + for p in path.split(sep): + f = os.path.join(p, prog) + if os.path.exists(f): + return p + return None + + +def _find_graphviz_dot(exc: bool = True) -> str: + """ + Determines the path to graphviz (on Windows), + the function tests the existence of versions 34 to 45 + assuming it was installed in a standard folder: + ``C:\\Program Files\\MiKTeX 2.9\\miktex\\bin\\x64``. + + :param exc: raise exception of be silent + :return: path to dot + :raises FileNotFoundError: if graphviz not found + """ + if sys.platform.startswith("win"): + version = list(range(34, 60)) + version.extend([f"{v}.1" for v in version]) + for v in version: + graphviz_dot = f"C:\\Program Files (x86)\\Graphviz2.{v}\\bin\\dot.exe" + if os.path.exists(graphviz_dot): + return graphviz_dot + extra = ["build/update_modules/Graphviz/bin"] + for ext in extra: + graphviz_dot = os.path.join(ext, "dot.exe") + if os.path.exists(graphviz_dot): + return graphviz_dot + p = _find_in_PATH("dot.exe") + if p is None: + if exc: + raise FileNotFoundError( + f"Unable to find graphviz, look into paths such as {graphviz_dot}." + ) + return None + return os.path.join(p, "dot.exe") + # linux + return "dot" + + +def _run_subprocess( + args: List[str], + cwd: Optional[str] = None, +): + assert not isinstance( + args, str + ), "args should be a sequence of strings, not a string." + + p = subprocess.Popen( + args, + cwd=cwd, + shell=False, + env=os.environ, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + raise_exception = False + output = "" + while True: + output = p.stdout.readline().decode(errors="ignore") + if output == "" and p.poll() is not None: + break + if output: + if ( + "fatal error" in output + or "CMake Error" in output + or "gmake: ***" in output + or "): error C" in output + or ": error: " in output + ): + raise_exception = True + p.poll() + p.stdout.close() + if raise_exception: + raise RuntimeError( + "An error was found in the output. The build is stopped.\n{output}" + ) + return output + + +def _run_graphviz(filename: str, image: str, engine: str = "dot") -> str: + """ + Run :epkg:`Graphviz`. + + :param filename: filename which contains the graph definition + :param image: output image + :param engine: *dot* or *neato* + :return: output of graphviz + """ + ext = os.path.splitext(image)[-1] + assert ext in { + ".png", + ".bmp", + ".fig", + ".gif", + ".ico", + ".jpg", + ".jpeg", + ".pdf", + ".ps", + ".svg", + ".vrml", + ".tif", + ".tiff", + ".wbmp", + }, f"Unexpected extension {ext!r} for {image!r}." + if sys.platform.startswith("win"): + bin_ = os.path.dirname(_find_graphviz_dot()) + # if bin not in os.environ["PATH"]: + # os.environ["PATH"] = os.environ["PATH"] + ";" + bin + exe = os.path.join(bin_, engine) + else: + exe = engine + if os.path.exists(image): + os.remove(image) + output = _run_subprocess([exe, f"-T{ext[1:]}", filename, "-o", image]) + assert os.path.exists(image), f"Graphviz failed due to {output}" + return output + + +def draw_graph_graphviz( + dot: str, + image: str, + engine: str = "dot", +) -> str: + """ + Draws a graph using :epkg:`Graphviz`. + + :param dot: dot graph + :param image: output image, None, just returns the output + :param engine: *dot* or *neato* + :return: :epkg:`Graphviz` output or + the dot text if *image* is None + + The function creates a temporary file to store the dot file if *image* is not None. + """ + with tempfile.NamedTemporaryFile(delete=False) as fp: + fp.write(dot.encode("utf-8")) + fp.seek(0) + fp.close() + + filename = fp.name + assert os.path.exists( + filename + ), f"File {filename!r} cannot be created to store the graph." + out = _run_graphviz(filename, image, engine=engine) + assert os.path.exists( + image + ), f"Graphviz failed with no reason, {image!r} not found, output is {out}." + os.remove(filename) + return out + + +def plot_dot( + dot: str, + ax: Optional["matplotlib.axis.Axis"] = None, # noqa: F821 + engine: str = "dot", + figsize: Optional[Tuple[int, int]] = None, +) -> "matplotlib.axis.Axis": # noqa: F821 + """ + Draws a dot graph into a matplotlib graph. + + :param dot: dot graph + :param image: output image, None, just returns the output + :param engine: *dot* or *neato* + :param figsize: figsize of ax is None + :return: :epkg:`Graphviz` output or + the dot text if *image* is None + """ + if ax is None: + import matplotlib.pyplot as plt + + _, ax = plt.subplots(1, 1, figsize=figsize) + clean = True + else: + clean = False + + from PIL import Image + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as fp: + fp.close() + + draw_graph_graphviz(dot, fp.name, engine=engine) + img = np.asarray(Image.open(fp.name)) + os.remove(fp.name) + + ax.imshow(img) + + if clean: + ax.get_xaxis().set_visible(False) + ax.get_yaxis().set_visible(False) + ax.get_figure().tight_layout() + return ax From 8436ea447246d57e782866e4f099d93313db263e Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 12:39:56 +0100 Subject: [PATCH 2/6] improve versatility --- .gitignore | 1 + _unittests/ut_plotting/test_graphviz.py | 6 ++++++ onnx_array_api/plotting/graphviz_helper.py | 20 +++++++++++++------- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index ca8ce49..64d45d6 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ build/* *egg-info/* onnxruntime_profile* prof +test*.png _doc/sg_execution_times.rst _doc/auto_examples/* _doc/examples/_cache/* diff --git a/_unittests/ut_plotting/test_graphviz.py b/_unittests/ut_plotting/test_graphviz.py index 3a61b59..374b85e 100644 --- a/_unittests/ut_plotting/test_graphviz.py +++ b/_unittests/ut_plotting/test_graphviz.py @@ -25,6 +25,12 @@ def test_draw_graph_graphviz(self): draw_graph_graphviz(dot, image=fout) self.assertExists(os.path.exists(fout)) + def test_draw_graph_graphviz_proto(self): + fout = "test_draw_graph_graphviz_proto.png" + dot = self._get_graph() + draw_graph_graphviz(dot, image=fout) + self.assertExists(os.path.exists(fout)) + def test_plot_dot(self): dot = to_dot(self._get_graph()) ax = plot_dot(dot) diff --git a/onnx_array_api/plotting/graphviz_helper.py b/onnx_array_api/plotting/graphviz_helper.py index 813b694..98e2de8 100644 --- a/onnx_array_api/plotting/graphviz_helper.py +++ b/onnx_array_api/plotting/graphviz_helper.py @@ -2,8 +2,9 @@ import subprocess import sys import tempfile -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np +from onnx import ModelProto def _find_in_PATH(prog: str) -> Optional[str]: @@ -139,14 +140,14 @@ def _run_graphviz(filename: str, image: str, engine: str = "dot") -> str: def draw_graph_graphviz( - dot: str, + dot: Union[str, ModelProto], image: str, engine: str = "dot", ) -> str: """ Draws a graph using :epkg:`Graphviz`. - :param dot: dot graph + :param dot: dot graph or ModelProto :param image: output image, None, just returns the output :param engine: *dot* or *neato* :return: :epkg:`Graphviz` output or @@ -154,9 +155,14 @@ def draw_graph_graphviz( The function creates a temporary file to store the dot file if *image* is not None. """ + if isinstance(dot, ModelProto): + from .dot_plot import to_dot + + sdot = to_dot(dot) + else: + sdot = dot with tempfile.NamedTemporaryFile(delete=False) as fp: - fp.write(dot.encode("utf-8")) - fp.seek(0) + fp.write(sdot.encode("utf-8")) fp.close() filename = fp.name @@ -172,7 +178,7 @@ def draw_graph_graphviz( def plot_dot( - dot: str, + dot: Union[str, ModelProto], ax: Optional["matplotlib.axis.Axis"] = None, # noqa: F821 engine: str = "dot", figsize: Optional[Tuple[int, int]] = None, @@ -180,7 +186,7 @@ def plot_dot( """ Draws a dot graph into a matplotlib graph. - :param dot: dot graph + :param dot: dot graph or ModelProto :param image: output image, None, just returns the output :param engine: *dot* or *neato* :param figsize: figsize of ax is None From 1247949a7cafd4115470082b8478357444e434ec Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 12:41:09 +0100 Subject: [PATCH 3/6] doc --- CHANGELOGS.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 39aaea9..dad0930 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.2.0 +++++ +* :pr:`61`: adds function to plot onnx model as graphs * :pr:`60`: supports translation of local functions * :pr:`59`: add methods to update nodes in GraphAPI From abc4da725c42f3636de441e8021aa8aa75f500bd Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 12:56:53 +0100 Subject: [PATCH 4/6] disable test when graphviz not installed --- _unittests/ut_plotting/test_graphviz.py | 13 ++++++++++++- onnx_array_api/ext_test_case.py | 10 ++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_plotting/test_graphviz.py b/_unittests/ut_plotting/test_graphviz.py index 374b85e..d1c2545 100644 --- a/_unittests/ut_plotting/test_graphviz.py +++ b/_unittests/ut_plotting/test_graphviz.py @@ -1,7 +1,12 @@ import os import unittest import onnx.parser -from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.ext_test_case import ( + ExtTestCase, + skipci_apple, + skipif_ci_windows, + skipif_ci_apple, +) from onnx_array_api.plotting.dot_plot import to_dot from onnx_array_api.plotting.graphviz_helper import draw_graph_graphviz, plot_dot @@ -19,18 +24,24 @@ def _get_graph(cls): }""" ) + @skipif_ci_windows("graphviz not installed") + @skipif_ci_apple("graphviz not installed") def test_draw_graph_graphviz(self): fout = "test_draw_graph_graphviz.png" dot = to_dot(self._get_graph()) draw_graph_graphviz(dot, image=fout) self.assertExists(os.path.exists(fout)) + @skipif_ci_windows("graphviz not installed") + @skipif_ci_apple("graphviz not installed") def test_draw_graph_graphviz_proto(self): fout = "test_draw_graph_graphviz_proto.png" dot = self._get_graph() draw_graph_graphviz(dot, image=fout) self.assertExists(os.path.exists(fout)) + @skipif_ci_windows("graphviz not installed") + @skipif_ci_apple("graphviz not installed") def test_plot_dot(self): dot = to_dot(self._get_graph()) ax = plot_dot(dot) diff --git a/onnx_array_api/ext_test_case.py b/onnx_array_api/ext_test_case.py index 7555cb5..2f28a97 100644 --- a/onnx_array_api/ext_test_case.py +++ b/onnx_array_api/ext_test_case.py @@ -29,6 +29,16 @@ def skipif_ci_windows(msg) -> Callable: return lambda x: x +def skipif_ci_apple(msg) -> Callable: + """ + Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`. + """ + if is_apple() and is_azure(): + msg = f"Test does not work on azure pipeline (Apple). {msg}" + return unittest.skip(msg) + return lambda x: x + + def ignore_warnings(warns: List[Warning]) -> Callable: """ Catches warnings. From 65aebf1c6e02598826ba1ce81372f363bca66174 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 13:14:54 +0100 Subject: [PATCH 5/6] documentation --- _unittests/ut_plotting/test_graphviz.py | 1 - onnx_array_api/plotting/graphviz_helper.py | 17 +++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_plotting/test_graphviz.py b/_unittests/ut_plotting/test_graphviz.py index d1c2545..420779e 100644 --- a/_unittests/ut_plotting/test_graphviz.py +++ b/_unittests/ut_plotting/test_graphviz.py @@ -3,7 +3,6 @@ import onnx.parser from onnx_array_api.ext_test_case import ( ExtTestCase, - skipci_apple, skipif_ci_windows, skipif_ci_apple, ) diff --git a/onnx_array_api/plotting/graphviz_helper.py b/onnx_array_api/plotting/graphviz_helper.py index 98e2de8..2dd93c2 100644 --- a/onnx_array_api/plotting/graphviz_helper.py +++ b/onnx_array_api/plotting/graphviz_helper.py @@ -192,6 +192,23 @@ def plot_dot( :param figsize: figsize of ax is None :return: :epkg:`Graphviz` output or the dot text if *image* is None + + .. plot:: + + import matplotlib.pyplot as plt + import onnx.parser + + model = onnx.parser.parse_model( + ''' + + agraph (float[N] x) => (float[N] z) { + two = Constant () + four = Add(two, two) + z = Mul(four, four) + }''') + ax = plot_dot(dot) + ax.set_title("Dummy graph") + plt.show() """ if ax is None: import matplotlib.pyplot as plt From 2e027034438209ab10e5b7af44f52db3d3de1d5f Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 13:43:56 +0100 Subject: [PATCH 6/6] add missing function --- onnx_array_api/ext_test_case.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onnx_array_api/ext_test_case.py b/onnx_array_api/ext_test_case.py index 2f28a97..3c12e65 100644 --- a/onnx_array_api/ext_test_case.py +++ b/onnx_array_api/ext_test_case.py @@ -19,6 +19,10 @@ def is_windows() -> bool: return sys.platform == "win32" +def is_apple() -> bool: + return sys.platform == "darwin" + + def skipif_ci_windows(msg) -> Callable: """ Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.