From 46aa8860340fb090de72ec2785b47a232c13c686 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 13 Nov 2023 12:16:34 +0100 Subject: [PATCH 1/2] Adds command line to translate de model into code --- _unittests/ut_xrun_doc/test_command_lines1.py | 75 +++++++++++++++ onnx_array_api/__main__.py | 4 + onnx_array_api/_command_lines_parser.py | 94 +++++++++++++++++++ 3 files changed, 173 insertions(+) create mode 100644 _unittests/ut_xrun_doc/test_command_lines1.py create mode 100644 onnx_array_api/__main__.py create mode 100644 onnx_array_api/_command_lines_parser.py diff --git a/_unittests/ut_xrun_doc/test_command_lines1.py b/_unittests/ut_xrun_doc/test_command_lines1.py new file mode 100644 index 0000000..8aa17ee --- /dev/null +++ b/_unittests/ut_xrun_doc/test_command_lines1.py @@ -0,0 +1,75 @@ +import os +import tempfile +import unittest +from contextlib import redirect_stdout +from io import StringIO +from onnx import TensorProto +from onnx.helper import ( + make_graph, + make_model, + make_node, + make_opsetid, + make_tensor_value_info, +) +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api._command_lines_parser import ( + get_main_parser, + get_parser_translate, + main, +) + + +class TestCommandLines1(ExtTestCase): + def test_main_parser(self): + st = StringIO() + with redirect_stdout(st): + get_main_parser().print_help() + text = st.getvalue() + self.assertIn("translate", text) + + def test_parser_translate(self): + st = StringIO() + with redirect_stdout(st): + get_parser_translate().print_help() + text = st.getvalue() + self.assertIn("model", text) + + def test_command_translate(self): + X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None]) + Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6]) + Z = make_tensor_value_info("Z", TensorProto.FLOAT, [None, None]) + graph = make_graph( + [ + make_node("Add", ["X", "Y"], ["res"]), + make_node("Cos", ["res"], ["Z"]), + ], + "g", + [X, Y], + [Z], + ) + onnx_model = make_model(graph, opset_imports=[make_opsetid("", 18)]) + + with tempfile.TemporaryDirectory() as root: + model_file = os.path.join(root, "model.onnx") + with open(model_file, "wb") as f: + f.write(onnx_model.SerializeToString()) + + args = ["translate", "-m", model_file] + st = StringIO() + with redirect_stdout(st): + main(args) + + code = st.getvalue() + self.assertIn("model = make_model(", code) + + args = ["translate", "-m", model_file, "-a", "light"] + st = StringIO() + with redirect_stdout(st): + main(args) + + code = st.getvalue() + self.assertIn("start(opset=", code) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_array_api/__main__.py b/onnx_array_api/__main__.py new file mode 100644 index 0000000..1fb5c0c --- /dev/null +++ b/onnx_array_api/__main__.py @@ -0,0 +1,4 @@ +from ._command_lines_parser import main + +if __name__ == "__main__": + main() diff --git a/onnx_array_api/_command_lines_parser.py b/onnx_array_api/_command_lines_parser.py new file mode 100644 index 0000000..3860f18 --- /dev/null +++ b/onnx_array_api/_command_lines_parser.py @@ -0,0 +1,94 @@ +import sys +import onnx +from typing import Any, List, Optional +from argparse import ArgumentParser +from textwrap import dedent + + +def get_main_parser() -> ArgumentParser: + parser = ArgumentParser( + prog="onnx-array-api", + description="onnx-array-api main command line.", + epilog="Type 'python -m onnx_array_api --help' " + "to get help for a specific command.", + ) + parser.add_argument( + "cmd", + choices=["translate"], + help=dedent( + """ + Selects a command. + + 'translate' exports an onnx graph into a piece of code replicating it. + """ + ), + ) + return parser + + +def get_parser_translate() -> ArgumentParser: + parser = ArgumentParser( + prog="translate", + description=dedent( + """ + Translates an onnx model into a piece of code to replicate it. + The result is printed on the standard output. + """ + ), + epilog="This is mostly used to write unit tests without adding " + "an onnx file to the repository.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + required=True, + help="onnx model to translate", + ) + parser.add_argument( + "-a", + "--api", + choices=["onnx", "light"], + default="onnx", + help="API to choose, API from onnx package or light API.", + ) + return parser + + +def _cmd_translate(argv: List[Any]): + from .light_api import translate + + parser = get_parser_translate() + args = parser.parse_args(argv[1:]) + onx = onnx.load(args.model) + code = translate(onx, api=args.api) + print(code) + + +def main(argv: Optional[List[Any]] = None): + fcts = dict(translate=_cmd_translate) + + if argv is None: + argv = sys.argv[1:] + if (len(argv) <= 1 and argv[0] not in fcts) or argv[-1] in ("--help", "-h"): + if len(argv) < 2: + parser = get_main_parser() + parser.parse_args(argv) + else: + parsers = dict(translate=get_parser_translate) + cmd = argv[0] + if cmd not in parsers: + raise ValueError( + f"Unknown command {cmd!r}, it should be in {list(sorted(parsers))}." + ) + parser = parsers[cmd]() + parser.parse_args(argv[1:]) + raise RuntimeError("The programme should have exited before.") + + cmd = argv[0] + if cmd in fcts: + fcts[cmd](argv) + else: + raise ValueError( + f"Unknown command {cmd!r}, use --help to get the list of known command." + ) From 7c3e39d849bf9bb70b0ee81737eefd7bb24b2cf0 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 13 Nov 2023 12:17:29 +0100 Subject: [PATCH 2/2] doc --- CHANGELOGS.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 055a05e..a8138bf 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.1.3 +++++ +* :pr:`49`: adds command line to export a model into code * :pr:`47`: extends export onnx to code to support inner API * :pr:`46`: adds an export to convert an onnx graph into light API code * :pr:`45`: fixes light API for operators with two outputs