diff --git a/.github/workflows/check-urls.yml b/.github/workflows/check-urls.yml
index 67d7731..d56adba 100644
--- a/.github/workflows/check-urls.yml
+++ b/.github/workflows/check-urls.yml
@@ -42,6 +42,6 @@ jobs:
print_all: false
timeout: 2
retry_count# : 2
- exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document
- exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/
+ exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document,https://github.com/onnx/tensorflow-onnx
+ exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://github.com/onnx/tensorflow-onnx
# force_pass : true
diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml
index ba80296..3ad7c7c 100644
--- a/.github/workflows/documentation.yml
+++ b/.github/workflows/documentation.yml
@@ -21,7 +21,7 @@ jobs:
- uses: actions/setup-python@v4
with:
- python-version: '3.11'
+ python-version: '3.12'
- uses: tlylt/install-graphviz@v1
@@ -35,7 +35,7 @@ jobs:
run: python -m pip install -r requirements-dev.txt
- name: Cache pip
- uses: actions/cache@v2
+ uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }}
@@ -57,7 +57,7 @@ jobs:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
- name: Install
- run: python setup.py install
+ run: python -m pip install -e . -v
- name: Copy license, changelogs
run: |
@@ -83,6 +83,6 @@ jobs:
exit 1
fi
- - uses: actions/upload-artifact@v3
+ - uses: actions/upload-artifact@v4
with:
path: ./dist/html/**
diff --git a/.github/workflows/wheels-any.yml b/.github/workflows/wheels-any.yml
index c20a15d..4bf89c7 100644
--- a/.github/workflows/wheels-any.yml
+++ b/.github/workflows/wheels-any.yml
@@ -19,11 +19,11 @@ jobs:
- uses: actions/setup-python@v4
with:
- python-version: '3.11'
+ python-version: '3.12'
- name: build wheel
run: python -m pip wheel .
- - uses: actions/upload-artifact@v3
+ - uses: actions/upload-artifact@v4
with:
path: ./onnx_array_api*.whl
diff --git a/.gitignore b/.gitignore
index 303cd33..64d45d6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,8 @@ build/*
*egg-info/*
onnxruntime_profile*
prof
+test*.png
+_doc/sg_execution_times.rst
_doc/auto_examples/*
_doc/examples/_cache/*
_doc/examples/onnxruntime_profile*
diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 9fb4ed8..8a91bbe 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -1,9 +1,36 @@
Change Logs
===========
+0.3.1
++++++
+
+* :pr:`100`: updates requirements, add 3.12
+* :pr:`96`: supports local functions in translator
+* :pr:`95`: improves translation to GraphBuilder
+
+0.3.0
++++++
+
+* :pr:`93`: fixes evaluator type in ``compare_onnx_execution``
+* :pr:`92`: avoids recursion errors in profiling
+* :pr:`87`: adds command line to replace contant by ConstantOfShape
+* :pr:`79`: first draft to export to GraphBuilder
+* :pr:`77`: supports ConcatOfShape and Slice with the light API
+
+0.2.0
++++++
+
+* :pr:`76`, :pr:`79`: add a mode to compare models without execution
+* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
+* :pr:`71`: adds tools to compare two onnx graphs
+* :pr:`61`: adds function to plot onnx model as graphs
+* :pr:`60`: supports translation of local functions
+* :pr:`59`: add methods to update nodes in GraphAPI
+
0.1.3
+++++
+* :pr:`57`: implements GraphBuilder
* :pr:`49`: adds command line to export a model into code
* :pr:`48`: support for subgraph in light API
* :pr:`47`: extends export onnx to code to support inner API
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000..b4e1709
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,15 @@
+# Code of Conduct
+
+We are a community based on openness, as well as friendly and didactic discussions.
+
+We aspire to treat everybody equally, and value their contributions.
+
+Decisions are made based on technical merit and consensus.
+
+Code is not the only way to help the project. Reviewing pull requests,
+answering questions to help others on mailing lists or issues, organizing and
+teaching tutorials, working on the website, improving the documentation, are
+all priceless contributions.
+
+We abide by the principles of openness, respect, and consideration of others of
+the Python Software Foundation: https://www.python.org/psf/codeofconduct/
diff --git a/LICENSE.txt b/LICENSE.txt
index fa034ef..1a46a8e 100644
--- a/LICENSE.txt
+++ b/LICENSE.txt
@@ -1,4 +1,4 @@
-Copyright (c) 2023, Xavier Dupré
+Copyright (c) 2023-2025, Xavier Dupré
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
diff --git a/README.rst b/README.rst
index b24b73d..f7b387f 100644
--- a/README.rst
+++ b/README.rst
@@ -31,6 +31,14 @@ onnx-array-api: APIs to create ONNX Graphs
**onnx-array-api** implements APIs to create custom ONNX graphs.
The objective is to speed up the implementation of converter libraries.
+The library is released on
+`pypi/onnx-array-api `_
+and its documentation is published at
+`APIs to create ONNX Graphs `_.
+
+Numpy API
++++++++++
+
The first one matches **numpy API**.
It gives the user the ability to convert functions written
following the numpy API to convert that function into ONNX as
@@ -113,10 +121,15 @@ It supports eager mode as well:
l2_loss=[0.002]
[0.042]
+Light API
++++++++++
+
The second API or **Light API** tends to do every thing in one line.
+It is inspired from the `Reverse Polish Notation
+`_.
The euclidean distance looks like the following:
-::
+.. code-block:: python
import numpy as np
from onnx_array_api.light_api import start
@@ -138,7 +151,29 @@ The euclidean distance looks like the following:
.to_onnx()
)
-The library is released on
-`pypi/onnx-array-api `_
-and its documentation is published at
-`APIs to create ONNX Graphs `_.
+GraphBuilder API
+++++++++++++++++
+
+Almost every converting library (converting a machine learned model to ONNX) is implementing
+its own graph builder and customizes it for its needs.
+It handles some frequent tasks such as giving names to intermediate
+results, loading, saving onnx models. It can be used as well to extend an existing graph.
+
+.. code-block:: python
+
+ import numpy as np
+ from onnx_array_api.graph_api import GraphBuilder
+
+ g = GraphBuilder()
+ g.make_tensor_input("X", np.float32, (None, None))
+ g.make_tensor_input("Y", np.float32, (None, None))
+ r1 = g.make_node("Sub", ["X", "Y"]) # the name given to the output is given by the class,
+ # it ensures the name is unique
+ init = g.make_initializer(np.array([2], dtype=np.int64)) # the class automatically
+ # converts the array to a tensor
+ r2 = g.make_node("Pow", [r1, init])
+ g.make_node("ReduceSum", [r2], outputs=["Z"]) # the output name is given because
+ # the user wants to choose the name
+ g.make_tensor_output("Z", np.float32, (None, None))
+
+ onx = g.to_onnx() # final conversion to onnx
diff --git a/_doc/api/graph_api.rst b/_doc/api/graph_api.rst
new file mode 100644
index 0000000..f618b7b
--- /dev/null
+++ b/_doc/api/graph_api.rst
@@ -0,0 +1,22 @@
+========================
+onnx_array_api.graph_api
+========================
+
+
+GraphBuilder
+============
+
+.. autoclass:: onnx_array_api.graph_api.GraphBuilder
+ :members:
+
+NodePattern
+===========
+
+.. autoclass:: onnx_array_api.graph_api.NodePattern
+ :members:
+
+OptimizationOptions
+===================
+
+.. autoclass:: onnx_array_api.graph_api.graph_builder.OptimizationOptions
+ :members:
diff --git a/_doc/api/index.rst b/_doc/api/index.rst
index 0f595f0..8cfe033 100644
--- a/_doc/api/index.rst
+++ b/_doc/api/index.rst
@@ -7,7 +7,9 @@ API
:maxdepth: 1
array_api
+ graph_api
light_api
+ translate_api
npx_core_api
npx_functions
npx_jit_eager
diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst
index 544b35f..e2a2d32 100644
--- a/_doc/api/light_api.rst
+++ b/_doc/api/light_api.rst
@@ -11,10 +11,10 @@ start
.. autofunction:: onnx_array_api.light_api.start
-translate
-+++++++++
+g
++
-.. autofunction:: onnx_array_api.light_api.translate
+.. autofunction:: onnx_array_api.light_api.g
Classes for the Light API
=========================
@@ -62,39 +62,6 @@ Vars
:members:
:inherited-members:
-Classes for the Translater
-==========================
-
-BaseEmitter
-+++++++++++
-
-.. autoclass:: onnx_array_api.light_api.emitter.BaseEmitter
- :members:
-
-Emitter
-+++++++
-
-.. autoclass:: onnx_array_api.light_api.emitter.Emitter
- :members:
-
-EventType
-+++++++++
-
-.. autoclass:: onnx_array_api.light_api.translate.EventType
- :members:
-
-InnerEmitter
-++++++++++++
-
-.. autoclass:: onnx_array_api.light_api.inner_emitter.InnerEmitter
- :members:
-
-Translater
-++++++++++
-
-.. autoclass:: onnx_array_api.light_api.translate.Translater
- :members:
-
Available operators
===================
diff --git a/_doc/api/plotting.rst b/_doc/api/plotting.rst
index 830cc86..db6076c 100644
--- a/_doc/api/plotting.rst
+++ b/_doc/api/plotting.rst
@@ -6,6 +6,8 @@ Dot
.. autofunction:: onnx_array_api.plotting.dot_plot.to_dot
+.. autofunction:: onnx_array_api.plotting.graphviz_helper.plot_dot
+
Statistics
++++++++++
diff --git a/_doc/api/reference.rst b/_doc/api/reference.rst
index acbf90a..3b4ae7d 100644
--- a/_doc/api/reference.rst
+++ b/_doc/api/reference.rst
@@ -5,3 +5,33 @@ ExtendedReferenceEvaluator
++++++++++++++++++++++++++
.. autoclass:: onnx_array_api.reference.ExtendedReferenceEvaluator
+ :members:
+
+ResultType
+++++++++++
+
+.. autoclass:: onnx_array_api.reference.ResultType
+ :members:
+
+ResultExecution
++++++++++++++++
+
+.. autoclass:: onnx_array_api.reference.ResultExecution
+ :members:
+
+YieldEvaluator
+++++++++++++++
+
+.. autoclass:: onnx_array_api.reference.YieldEvaluator
+ :members:
+
+DistanceExecution
++++++++++++++++++
+
+.. autoclass:: onnx_array_api.reference.DistanceExecution
+ :members:
+
+compare_onnx_execution
+++++++++++++++++++++++
+
+.. autofunction:: onnx_array_api.reference.compare_onnx_execution
diff --git a/_doc/api/tools.rst b/_doc/api/tools.rst
index ef161e0..e0450dc 100644
--- a/_doc/api/tools.rst
+++ b/_doc/api/tools.rst
@@ -6,6 +6,11 @@ Benchmark
.. autofunction:: onnx_array_api.ext_test_case.measure_time
+Manipulations
++++++++++++++
+
+.. autofunction:: onnx_array_api.tools.replace_constants.replace_initializer_by_constant_of_shape
+
Examples
++++++++
diff --git a/_doc/api/translate_api.rst b/_doc/api/translate_api.rst
new file mode 100644
index 0000000..f2d90df
--- /dev/null
+++ b/_doc/api/translate_api.rst
@@ -0,0 +1,58 @@
+============================
+onnx_array_api.translate_api
+============================
+
+
+Main API
+========
+
+translate
++++++++++
+
+.. autofunction:: onnx_array_api.translate_api.translate
+
+make_helper
++++++++++++
+
+.. autofunction:: onnx_array_api.translate_api.make_helper.make_node_extended
+
+.. autofunction:: onnx_array_api.translate_api.make_helper.make_ref_attribute
+
+Classes for the Translater
+==========================
+
+BaseEmitter
++++++++++++
+
+.. autoclass:: onnx_array_api.translate_api.base_emitter.BaseEmitter
+ :members:
+
+EventType
++++++++++
+
+.. autoclass:: onnx_array_api.translate_api.base_emitter.EventType
+ :members:
+
+InnerEmitter
+++++++++++++
+
+.. autoclass:: onnx_array_api.translate_api.inner_emitter.InnerEmitter
+ :members:
+
+InnerEmitterShortInitializer
+++++++++++++++++++++++++++++
+
+.. autoclass:: onnx_array_api.translate_api.inner_emitter.InnerEmitterShortInitializer
+ :members:
+
+LightEmitter
+++++++++++++
+
+.. autoclass:: onnx_array_api.translate_api.light_emitter.LightEmitter
+ :members:
+
+Translater
+++++++++++
+
+.. autoclass:: onnx_array_api.translate_api.translate.Translater
+ :members:
diff --git a/_doc/command_lines.rst b/_doc/command_lines.rst
new file mode 100644
index 0000000..38ca5f2
--- /dev/null
+++ b/_doc/command_lines.rst
@@ -0,0 +1,52 @@
+=============
+command lines
+=============
+
+compare
+=======
+
+The function convers an onnx file into some code.
+
+::
+
+ python -m compare -m1 model1.onnx -m2 model2.onnx -v 1
+
+Output example::
+
+ [compare_onnx_execution] got 2 inputs
+ [compare_onnx_execution] execute first model
+ [compare_onnx_execution] got 5 results
+ [compare_onnx_execution] execute second model
+ [compare_onnx_execution] got 5 results
+ [compare_onnx_execution] compute edit distance
+ [compare_onnx_execution] got 4 pairs
+ [compare_onnx_execution] done
+ = | INPUT float32 5x6 AAAA X | INPUT float32 5x6 AAAA X
+ = | INPUT float32 5x6 AAAA Y | INPUT float32 5x6 AAAA Y
+ = | RESULT float32 5x6 AABB Add res | RESULT float32 5x6 AABB Add res
+ = | RESULT float32 5x6 AAAA Cos Z | RESULT float32 5x6 AAAA Cos Z
+
+.. runpython::
+
+ from onnx_array_api._command_lines_parser import get_parser_compare
+ get_parser_compare().print_help()
+
+See function :func:`onnx_array_api.reference.compare_onnx_execution`.
+
+translate
+=========
+
+The function convers an onnx file into some code.
+
+::
+
+ python -m translate ...
+
+Output example::
+
+ not yet ready
+
+.. runpython::
+
+ from onnx_array_api._command_lines_parser import get_parser_translate
+ get_parser_translate().print_help()
diff --git a/_doc/conf.py b/_doc/conf.py
index d942076..eaf8eb1 100644
--- a/_doc/conf.py
+++ b/_doc/conf.py
@@ -35,7 +35,7 @@
source_suffix = ".rst"
master_doc = "index"
project = "onnx-array-api"
-copyright = "2023, Xavier Dupré"
+copyright = "2023-2024, Xavier Dupré"
author = "Xavier Dupré"
version = __version__
release = __version__
@@ -117,10 +117,11 @@
"ast": "https://docs.python.org/3/library/ast.html",
"cProfile.Profile": "https://docs.python.org/3/library/profile.html#profile.Profile",
"DOT": "https://graphviz.org/doc/info/lang.html",
+ "Graphviz": "https://graphviz.org/",
"inner API": "https://onnx.ai/onnx/intro/python.html",
"JIT": "https://en.wikipedia.org/wiki/Just-in-time_compilation",
"onnx": "https://onnx.ai/onnx/",
- "onnx-graphsurgeon": "https://docs.nvidia.com/deeplearning/tensorrt/onnx-graphsurgeon/docs/index.html",
+ "onnx-graphsurgeon": "https://github.com/NVIDIA/TensorRT/tree/main/tools/onnx-graphsurgeon",
"onnx.helper": "https://onnx.ai/onnx/api/helper.html",
"ONNX": "https://onnx.ai/",
"ONNX Operators": "https://onnx.ai/onnx/operators/",
@@ -145,11 +146,9 @@
"torch.onnx": "https://pytorch.org/docs/stable/onnx.html",
#
"C_OrtValue": (
- "http://www.xavierdupre.fr/app/onnxcustom/helpsphinx/"
- "api/onnxruntime_python/ortvalue.html#c-class-ortvalue-or-c-ortvalue"
+ "https://onnxruntime.ai/docs/api/csharp/api/Microsoft.ML.OnnxRuntime.OrtValue.html"
),
"OrtValue": (
- "http://www.xavierdupre.fr/app/onnxcustom/helpsphinx/"
- "api/onnxruntime_python/ortvalue.html#onnxruntime.OrtValue"
+ "https://onnxruntime.ai/docs/api/python/api_summary.html#onnxruntime.OrtValue"
),
}
diff --git a/_doc/examples/plot_benchmark_rf.py b/_doc/examples/plot_benchmark_rf.py
index 8b62e3d..c1ce486 100644
--- a/_doc/examples/plot_benchmark_rf.py
+++ b/_doc/examples/plot_benchmark_rf.py
@@ -12,6 +12,7 @@
import and registration of necessary converters
++++++++++++++++++++++++++++++++++++++++++++++++
"""
+
import pickle
import os
import time
@@ -21,8 +22,6 @@
import numpy
import pandas
from lightgbm import LGBMRegressor
-from onnxmltools.convert.lightgbm.operator_converters.LightGbm import convert_lightgbm
-from onnxmltools.convert.xgboost.operator_converters.XGBoost import convert_xgboost
from onnxruntime import InferenceSession, SessionOptions
from psutil import cpu_count
from sphinx_runpython.runpython import run_cmd
@@ -32,14 +31,16 @@
from sklearn.ensemble import RandomForestRegressor
from tqdm import tqdm
from xgboost import XGBRegressor
+from onnxmltools.convert.xgboost.operator_converters.XGBoost import convert_xgboost
def skl2onnx_convert_lightgbm(scope, operator, container):
+ from onnxmltools.convert.lightgbm.operator_converters.LightGbm import (
+ convert_lightgbm,
+ )
+
options = scope.get_options(operator.raw_operator)
- if "split" in options:
- operator.split = options["split"]
- else:
- operator.split = None
+ operator.split = options.get("split", None)
convert_lightgbm(scope, operator, container)
@@ -99,7 +100,7 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
:return: number of runs, sum of the time, average, median
"""
times = []
- for n in range(repeat):
+ for _n in range(repeat):
perf = time.perf_counter()
fct(X)
delta = time.perf_counter() - perf
@@ -237,7 +238,10 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
# onnxruntime
bar.set_description(f"J={n_j} E={n_estimators} D={max_depth} predictO")
r, t, mean, med = measure_inference(
- lambda x: sess.run(None, {"X": x}), X, repeat=repeat, max_time=max_time
+ lambda x, sess=sess: sess.run(None, {"X": x}),
+ X,
+ repeat=repeat,
+ max_time=max_time,
)
o2 = obs.copy()
o2.update(dict(avg=mean, med=med, n_runs=r, ttime=t, name="ort_"))
diff --git a/_doc/examples/plot_onnx_diff.py b/_doc/examples/plot_onnx_diff.py
new file mode 100644
index 0000000..7b6ecdf
--- /dev/null
+++ b/_doc/examples/plot_onnx_diff.py
@@ -0,0 +1,69 @@
+"""
+
+.. _l-onnx-diff-example:
+
+Compares the conversions of the same model with different options
+=================================================================
+
+The script compares two onnx models obtained with the same trained
+scikit-learn models but converted with different options.
+
+A model
++++++++
+"""
+
+from sklearn.mixture import GaussianMixture
+from sklearn.datasets import load_iris
+from sklearn.model_selection import train_test_split
+from skl2onnx import to_onnx
+from onnx_array_api.reference import compare_onnx_execution
+from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+
+data = load_iris()
+X_train, X_test = train_test_split(data.data)
+model = GaussianMixture()
+model.fit(X_train)
+
+#################################
+# Conversion to onnx
+# ++++++++++++++++++
+
+onx = to_onnx(
+ model, X_train[:1], options={id(model): {"score_samples": True}}, target_opset=12
+)
+
+print(onnx_simple_text_plot(onx))
+
+##################################
+# Conversion to onnx without ReduceLogSumExp
+# ++++++++++++++++++++++++++++++++++++++++++
+
+onx2 = to_onnx(
+ model,
+ X_train[:1],
+ options={id(model): {"score_samples": True}},
+ black_op={"ReduceLogSumExp"},
+ target_opset=12,
+)
+
+print(onnx_simple_text_plot(onx2))
+
+
+#############################################
+# Differences
+# +++++++++++
+#
+# Function :func:`onnx_array_api.reference.compare_onnx_execution`
+# compares the intermediate results of two onnx models. Then it finds
+# the best alignmet between the two models using an edit distance.
+
+res1, res2, align, dc = compare_onnx_execution(onx, onx2, verbose=1)
+print("------------")
+text = dc.to_str(res1, res2, align)
+print(text)
+
+###############################
+# See :ref:`l-long-output-compare_onnx_execution` for a better view.
+# The display shows that ReduceSumSquare was replaced by Mul + ReduceSum,
+# and ReduceLogSumExp by ReduceMax + Sub + Exp + Log + Add.
diff --git a/_doc/examples/plot_onnxruntime.py b/_doc/examples/plot_onnxruntime.py
index fcace3e..0aba6ac 100644
--- a/_doc/examples/plot_onnxruntime.py
+++ b/_doc/examples/plot_onnxruntime.py
@@ -87,14 +87,14 @@ def loop(n=1000):
x = np.random.randn(n, 2).astype(np.float32)
y = np.random.randn(n, 2).astype(np.float32)
- obs = measure_time(lambda: myloss(x, y))
+ obs = measure_time(lambda x=x, y=y: myloss(x, y))
obs["name"] = "numpy"
obs["n"] = n
data.append(obs)
xort = OrtTensor.from_array(x)
yort = OrtTensor.from_array(y)
- obs = measure_time(lambda: ort_myloss(xort, yort))
+ obs = measure_time(lambda xort=xort, yort=yort: ort_myloss(xort, yort))
obs["name"] = "ort"
obs["n"] = n
data.append(obs)
diff --git a/_doc/examples/plot_optimization.py b/_doc/examples/plot_optimization.py
index 466fac0..c78419b 100644
--- a/_doc/examples/plot_optimization.py
+++ b/_doc/examples/plot_optimization.py
@@ -15,6 +15,7 @@
Optimize a model with onnxruntime
+++++++++++++++++++++++++++++++++
"""
+
import os
from pprint import pprint
import numpy
diff --git a/_doc/examples/plot_profiling.py b/_doc/examples/plot_profiling.py
index 7a61b68..201de95 100644
--- a/_doc/examples/plot_profiling.py
+++ b/_doc/examples/plot_profiling.py
@@ -15,6 +15,7 @@
Optimize a model with onnxruntime
+++++++++++++++++++++++++++++++++
"""
+
import os
import numpy
import matplotlib.pyplot as plt
diff --git a/_doc/index.rst b/_doc/index.rst
index 93ca000..9bdc4e2 100644
--- a/_doc/index.rst
+++ b/_doc/index.rst
@@ -36,6 +36,7 @@ The objective is to speed up the implementation of converter libraries.
tutorial/index
api/index
tech/index
+ command_lines
auto_examples/index
.. toctree::
@@ -44,12 +45,85 @@ The objective is to speed up the implementation of converter libraries.
CHANGELOGS
license
+ long_outputs
+
+Sources available on
+`github/onnx-array-api `_.
+
+GraphBuilder API
+++++++++++++++++
+
+Almost every converting library (converting a machine learned model to ONNX) is implementing
+its own graph builder and customizes it for its needs.
+It handles some frequent tasks such as giving names to intermediate
+results, loading, saving onnx models. It can be used as well to extend an existing graph.
+See :ref:`l-graph-api`.
+
+.. runpython::
+ :showcode:
+
+ import numpy as np
+ from onnx_array_api.graph_api import GraphBuilder
+ from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+ g = GraphBuilder()
+ g.make_tensor_input("X", np.float32, (None, None))
+ g.make_tensor_input("Y", np.float32, (None, None))
+ r1 = g.make_node("Sub", ["X", "Y"]) # the name given to the output is given by the class,
+ # it ensures the name is unique
+ init = g.make_initializer(np.array([2], dtype=np.int64)) # the class automatically
+ # converts the array to a tensor
+ r2 = g.make_node("Pow", [r1, init])
+ g.make_node("ReduceSum", [r2], outputs=["Z"]) # the output name is given because
+ # the user wants to choose the name
+ g.make_tensor_output("Z", np.float32, (None, None))
+
+ onx = g.to_onnx() # final conversion to onnx
+
+ print(onnx_simple_text_plot(onx))
+
+Light API
++++++++++
+
+The syntax is inspired from the
+`Reverse Polish Notation `_.
+This kind of API is easy to use to build new graphs,
+less easy to extend an existing graph. See :ref:`l-light-api`.
+
+.. runpython::
+ :showcode:
+
+ import numpy as np
+ from onnx_array_api.light_api import start
+ from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+ model = (
+ start()
+ .vin("X")
+ .vin("Y")
+ .bring("X", "Y")
+ .Sub()
+ .rename("dxy")
+ .cst(np.array([2], dtype=np.int64), "two")
+ .bring("dxy", "two")
+ .Pow()
+ .ReduceSum()
+ .rename("Z")
+ .vout()
+ .to_onnx()
+ )
+
+ print(onnx_simple_text_plot(model))
Numpy API
+++++++++
-Sources available on
-`github/onnx-array-api `_.
+Writing ONNX graphs requires to know ONNX syntax unless
+it is possible to reuse an existing syntax such as :epkg:`numpy`.
+This is what this API is doing.
+This kind of API is easy to use to build new graphs,
+almost impossible to use to extend new graphs as it usually requires
+to know onnx for that. See :ref:`l-numpy-api-onnx`.
.. runpython::
:showcode:
@@ -110,36 +184,10 @@ Sources available on
res = jitted_myloss(x, y)
print(to_dot(jitted_myloss.get_onnx()))
-Light API
-+++++++++
-
-.. runpython::
- :showcode:
-
- import numpy as np
- from onnx_array_api.light_api import start
- from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
-
- model = (
- start()
- .vin("X")
- .vin("Y")
- .bring("X", "Y")
- .Sub()
- .rename("dxy")
- .cst(np.array([2], dtype=np.int64), "two")
- .bring("dxy", "two")
- .Pow()
- .ReduceSum()
- .rename("Z")
- .vout()
- .to_onnx()
- )
-
- print(onnx_simple_text_plot(model))
-
-
Older versions
++++++++++++++
+* `0.3.0 <../v0.3.0/index.html>`_
+* `0.2.0 <../v0.2.0/index.html>`_
+* `0.1.3 <../v0.1.3/index.html>`_
* `0.1.2 <../v0.1.2/index.html>`_
diff --git a/_doc/long_outputs.rst b/_doc/long_outputs.rst
new file mode 100644
index 0000000..745382b
--- /dev/null
+++ b/_doc/long_outputs.rst
@@ -0,0 +1,47 @@
+:hide-toc:
+
+==========================
+Long outputs uneasy to see
+==========================
+
+onnx
+====
+
+.. _l-long-output-compare_onnx_execution:
+
+onnx_array_api.reference.compare_onnx_execution
++++++++++++++++++++++++++++++++++++++++++++++++
+
+From example :ref:`l-onnx-diff-example` for function
+:func:`onnx_array_api.reference.compare_onnx_execution`.
+See also `raw rendering `_.
+
+::
+
+ 1 = | INITIA float64 1 HAAA Ad_Addcst | INITIA float64 1 HAAA Ad_Addcst
+ 2 = | INITIA float64 4x4 ADZF Ge_Gemmcst | INITIA float64 4x4 ADZF Ge_Gemmcst
+ 3 = | INITIA float64 4 USEA Ge_Gemmcst1 | INITIA float64 4 USEA Ge_Gemmcst1
+ 4 = | INITIA float64 1 AAAA Mu_Mulcst | INITIA float64 1 AAAA Mu_Mulcst
+ 5 = | INITIA float64 1 DAAA Ad_Addcst1 | INITIA float64 1 DAAA Ad_Addcst1
+ 6 = | INITIA float64 1 AAAA Ad_Addcst2 | INITIA float64 1 AAAA Ad_Addcst2
+ 7 = | INPUT float64 1x4 AAAA X | INPUT float64 1x4 AAAA X
+ 8 = | RESULT float64 1x4 UTFC Gemm Ge_Y0 | RESULT float64 1x4 UTFC Gemm Ge_Y0
+ 9 + | | RESULT float64 1x4 TIEG Mul Mu_C01
+ 10 ~ | RESULT float64 1x1 NAAA ReduceSumS Re_reduced0 | RESULT float64 1x1 NAAA ReduceSum Re_reduced0
+ 11 = | RESULT float64 1x1 NAAA Concat Co_concat_re | RESULT float64 1x1 NAAA Concat Co_concat_re
+ 12 = | RESULT float64 1x1 UAAA Add Ad_C02 | RESULT float64 1x1 UAAA Add Ad_C02
+ 13 = | RESULT float64 1x1 DAAA Mul Mu_C0 | RESULT float64 1x1 DAAA Mul Mu_C0
+ 14 = | RESULT float64 1x1 GAAA Add Ad_C01 | RESULT float64 1x1 GAAA Add Ad_C01
+ 15 = | RESULT float64 1x1 GAAA Add Ad_C0 | RESULT float64 1x1 GAAA Add Ad_C0
+ 16 = | RESULT int64 1x1 AAAA ArgMax label | RESULT int64 1x1 AAAA ArgMax label
+ 17 + | | RESULT float64 1x1 GAAA ReduceMax Re_reduced03
+ 18 + | | RESULT float64 1x1 AAAA Sub Su_C01
+ 19 + | | RESULT float64 1x1 BAAA Exp Ex_output0
+ 20 + | | RESULT float64 1x1 BAAA ReduceSum Re_reduced02
+ 21 + | | RESULT float64 1x1 AAAA Log Lo_output0
+ 22 ~ | RESULT float64 1x1 GAAA ReduceLogS score_sample | RESULT float64 1x1 GAAA Add score_sample
+ 23 = | RESULT float64 1x1 AAAA Sub Su_C0 | RESULT float64 1x1 AAAA Sub Su_C0
+ 24 = | RESULT float64 1x1 BAAA Exp probabilitie | RESULT float64 1x1 BAAA Exp probabilitie
+ 25 = | OUTPUT int64 1x1 AAAA label | OUTPUT int64 1x1 AAAA label
+ 26 = | OUTPUT float64 1x1 BAAA probabilitie | OUTPUT float64 1x1 BAAA probabilitie
+ 27 = | OUTPUT float64 1x1 GAAA score_sample | OUTPUT float64 1x1 GAAA score_sample
diff --git a/_doc/tutorial/graph_api.rst b/_doc/tutorial/graph_api.rst
new file mode 100644
index 0000000..b373cc3
--- /dev/null
+++ b/_doc/tutorial/graph_api.rst
@@ -0,0 +1,59 @@
+.. _l-graph-api:
+
+=================================
+GraphBuilder: common API for ONNX
+=================================
+
+This is a very common way to build ONNX graph. There are some
+annoying steps while building an ONNX graph. The first one is to
+give unique names to every intermediate result in the graph. The second
+is the conversion from numpy arrays to onnx tensors. A *graph builder*,
+here implemented by class
+:class:`GraphBuilder `
+usually makes these two frequent tasks easier.
+
+.. runpython::
+ :showcode:
+
+ import numpy as np
+ from onnx_array_api.graph_api import GraphBuilder
+ from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+ g = GraphBuilder()
+ g.make_tensor_input("X", np.float32, (None, None))
+ g.make_tensor_input("Y", np.float32, (None, None))
+ r1 = g.make_node("Sub", ["X", "Y"]) # the name given to the output is given by the class,
+ # it ensures the name is unique
+ init = g.make_initializer(np.array([2], dtype=np.int64)) # the class automatically
+ # converts the array to a tensor
+ r2 = g.make_node("Pow", [r1, init])
+ g.make_node("ReduceSum", [r2], outputs=["Z"]) # the output name is given because
+ # the user wants to choose the name
+ g.make_tensor_output("Z", np.float32, (None, None))
+
+ onx = g.to_onnx() # final conversion to onnx
+
+ print(onnx_simple_text_plot(onx))
+
+A more simple versions of the same code to produce the same graph.
+
+.. runpython::
+ :showcode:
+
+ import numpy as np
+ from onnx_array_api.graph_api import GraphBuilder
+ from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+ g = GraphBuilder()
+ g.make_tensor_input("X", np.float32, (None, None))
+ g.make_tensor_input("Y", np.float32, (None, None))
+ r1 = g.op.Sub("X", "Y") # the method name indicates which operator to use,
+ # this can be used when there is no ambiguity about the
+ # number of outputs
+ r2 = g.op.Pow(r1, np.array([2], dtype=np.int64))
+ g.op.ReduceSum(r2, outputs=["Z"]) # the still wants the user to specify the name
+ g.make_tensor_output("Z", np.float32, (None, None))
+
+ onx = g.to_onnx()
+
+ print(onnx_simple_text_plot(onx))
diff --git a/_doc/tutorial/index.rst b/_doc/tutorial/index.rst
index e3ca8d7..9fcc557 100644
--- a/_doc/tutorial/index.rst
+++ b/_doc/tutorial/index.rst
@@ -7,6 +7,8 @@ Tutorial
:maxdepth: 1
onnx_api
+ graph_api
light_api
numpy_api
+ tools
benchmarks
diff --git a/_doc/tutorial/onnx_api.rst b/_doc/tutorial/onnx_api.rst
index f27eb05..2b673fb 100644
--- a/_doc/tutorial/onnx_api.rst
+++ b/_doc/tutorial/onnx_api.rst
@@ -71,7 +71,11 @@ the true implementation would be the following.
n2 = oh.make_node("Pow", ["dxy", "two"], ["dxy2"])
n3 = oh.make_node("ReduceSum", ["dxy2"], [output_name])
graph = oh.make_graph([n1, n2, n3], "euclidian", [X, Y], [Z], [two])
- model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", opset)])
+ model = oh.make_model(
+ graph,
+ opset_imports=[oh.make_opsetid("", opset)],
+ ir_version=9,
+ )
return model
@@ -584,37 +588,31 @@ The second part modifies it.
onnx.save(gs.export_onnx(graph), "modified.onnx")
-numpy API for onnx
-++++++++++++++++++
+Graph Builder API
++++++++++++++++++
-See :ref:`l-numpy-api-onnx`. This API was introduced to create graphs
-by using numpy API. If a function is defined only with numpy,
-it should be possible to use the exact same code to create the
-corresponding onnx graph. That's what this API tries to achieve.
-It works with the exception of control flow. In that case, the function
-produces different onnx graphs depending on the execution path.
+See :ref:`l-graph-api`. This API is very similar to what *skl2onnx* implements.
+It is still about adding nodes to a graph but some tasks are automated such as
+naming the results or converting constants to onnx classes.
.. runpython::
:showcode:
import numpy as np
- from onnx_array_api.npx import jit_onnx
+ from onnx_array_api.graph_api import GraphBuilder
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
- def l2_loss(x, y):
- return ((x - y) ** 2).sum(keepdims=1)
-
- jitted_myloss = jit_onnx(l2_loss)
- dummy = np.array([0], dtype=np.float32)
-
- # The function is executed. Only then a onnx graph is created.
- # One is created depending on the input type.
- jitted_myloss(dummy, dummy)
+ g = GraphBuilder()
+ g.make_tensor_input("X", np.float32, (None, None))
+ g.make_tensor_input("Y", np.float32, (None, None))
+ r1 = g.op.Sub("X", "Y")
+ r2 = g.op.Pow(r1, np.array([2], dtype=np.int64))
+ g.op.ReduceSum(r2, outputs=["Z"])
+ g.make_tensor_output("Z", np.float32, (None, None))
+
+ onx = g.to_onnx()
- # get_onnx only works if it was executed once or at least with
- # the same input type
- model = jitted_myloss.get_onnx()
- print(onnx_simple_text_plot(model))
+ print(onnx_simple_text_plot(onx))
Light API
+++++++++
@@ -647,3 +645,35 @@ There is no eager mode.
)
print(onnx_simple_text_plot(model))
+
+numpy API for onnx
+++++++++++++++++++
+
+See :ref:`l-numpy-api-onnx`. This API was introduced to create graphs
+by using numpy API. If a function is defined only with numpy,
+it should be possible to use the exact same code to create the
+corresponding onnx graph. That's what this API tries to achieve.
+It works with the exception of control flow. In that case, the function
+produces different onnx graphs depending on the execution path.
+
+.. runpython::
+ :showcode:
+
+ import numpy as np
+ from onnx_array_api.npx import jit_onnx
+ from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+ def l2_loss(x, y):
+ return ((x - y) ** 2).sum(keepdims=1)
+
+ jitted_myloss = jit_onnx(l2_loss)
+ dummy = np.array([0], dtype=np.float32)
+
+ # The function is executed. Only then a onnx graph is created.
+ # One is created depending on the input type.
+ jitted_myloss(dummy, dummy)
+
+ # get_onnx only works if it was executed once or at least with
+ # the same input type
+ model = jitted_myloss.get_onnx()
+ print(onnx_simple_text_plot(model))
diff --git a/_doc/tutorial/tools.rst b/_doc/tutorial/tools.rst
new file mode 100644
index 0000000..fe673f7
--- /dev/null
+++ b/_doc/tutorial/tools.rst
@@ -0,0 +1,20 @@
+=====
+Tools
+=====
+
+Some of useful tools.
+
+Text representation
+===================
+
+Plotting a graph is great but difficult to read when
+the graph is big and it is slow.
+:func:`onnx_array_api.plotting.text_plot.onnx_simple_text_plot`
+prints out a text representation.
+
+Differences between two models
+==============================
+
+How to understand the differences between two models
+assuming they are producing the same outputs?
+Example :ref:`l-onnx-diff-example` shows how to do it.
diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt
index bf91e86..5deb50e 100644
--- a/_unittests/onnx-numpy-skips.txt
+++ b/_unittests/onnx-numpy-skips.txt
@@ -1,9 +1,17 @@
# API failures
# see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt
# uses __setitem__
+array_api_tests/test_creation_functions.py::test_arange
array_api_tests/test_creation_functions.py::test_asarray_arrays
array_api_tests/test_creation_functions.py::test_empty
array_api_tests/test_creation_functions.py::test_empty_like
+array_api_tests/test_creation_functions.py::test_eye
+array_api_tests/test_creation_functions.py::test_full
+array_api_tests/test_creation_functions.py::test_full_like
+array_api_tests/test_creation_functions.py::test_ones
+array_api_tests/test_creation_functions.py::test_ones_like
+array_api_tests/test_creation_functions.py::test_zeros
+array_api_tests/test_creation_functions.py::test_zeros_like
# fails to precision issue
array_api_tests/test_creation_functions.py::test_linspace
array_api_tests/test_creation_functions.py::test_meshgrid
diff --git a/_unittests/ut_array_api/test_hypothesis_array_api.py b/_unittests/ut_array_api/test_hypothesis_array_api.py
index 95b1447..f55d230 100644
--- a/_unittests/ut_array_api/test_hypothesis_array_api.py
+++ b/_unittests/ut_array_api/test_hypothesis_array_api.py
@@ -1,7 +1,7 @@
import unittest
-import warnings
from os import getenv
from functools import reduce
+import packaging.version as pv
import numpy as np
from operator import mul
from hypothesis import given
@@ -44,9 +44,7 @@ class TestHypothesisArraysApis(ExtTestCase):
@classmethod
def setUpClass(cls):
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
- from numpy import array_api as xp
+ import array_api_strict as xp
api_version = getenv(
"ARRAY_API_TESTS_VERSION",
@@ -63,6 +61,9 @@ def test_strategies(self):
self.assertNotEmpty(self.xps)
self.assertNotEmpty(self.onxps)
+ @unittest.skipIf(
+ pv.Version(np.__version__) >= pv.Version("2.0"), reason="abandonned"
+ )
def test_scalar_strategies(self):
dtypes = dict(
integer_dtypes=self.xps.integer_dtypes(),
@@ -139,6 +140,9 @@ def fctonx(x, kw):
fctonx()
self.assertEqual(len(args_onxp), len(args_np))
+ @unittest.skipIf(
+ pv.Version(np.__version__) >= pv.Version("2.0"), reason="abandonned"
+ )
def test_square_sizes_strategies(self):
dtypes = dict(
integer_dtypes=self.xps.integer_dtypes(),
diff --git a/_unittests/ut_graph_api/data/debug_7951-CPUep.0.onnx b/_unittests/ut_graph_api/data/debug_7951-CPUep.0.onnx
new file mode 100644
index 0000000..77ba377
Binary files /dev/null and b/_unittests/ut_graph_api/data/debug_7951-CPUep.0.onnx differ
diff --git a/_unittests/ut_graph_api/test_graph_builder.py b/_unittests/ut_graph_api/test_graph_builder.py
new file mode 100644
index 0000000..9e6229b
--- /dev/null
+++ b/_unittests/ut_graph_api/test_graph_builder.py
@@ -0,0 +1,443 @@
+import contextlib
+import io
+import unittest
+import numpy as np
+import onnx
+from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_apple
+from onnx_array_api.graph_api.graph_builder import GraphBuilder, OptimizationOptions
+from onnx_array_api.reference import (
+ from_array_extended,
+ ExtendedReferenceEvaluator as ReferenceEvaluator,
+)
+
+
+class TestGraphBuilder(ExtTestCase):
+ def call_optimizer(self, onx):
+ gr = GraphBuilder(onx)
+ gr.remove_unused()
+ return gr.to_onnx()
+
+ def test_remove_unused_nodes(self):
+ model = onnx.parser.parse_model(
+ """
+
+ agraph (float[N] x) => (float[N] z) {
+ two = Constant ()
+ four = Add(two, two)
+ z = Mul(x, x)
+ }"""
+ )
+ onx = self.call_optimizer(model)
+ self.assertEqual(len(onx.graph.node), 1)
+ self.assertEqual(onx.graph.node[0].op_type, "Mul")
+
+ def test_initializers(self):
+ model = onnx.parser.parse_model(
+ """
+
+ agraph (float[N] x) => (float[N] z)
+ {
+ four = Add(two, two)
+ z = Mul(x, x)
+ }"""
+ )
+ self.assertEqual(len(model.graph.initializer), 1)
+ onx = self.call_optimizer(model)
+ self.assertEqual(len(onx.graph.node), 1)
+ self.assertEqual(onx.graph.node[0].op_type, "Mul")
+ self.assertEqual(len(onx.graph.initializer), 0)
+
+ def test_keep_unused_outputs(self):
+ model = onnx.parser.parse_model(
+ """
+
+ agraph (float[N] x) => (float[M] z) {
+ w1, w2, w3 = Split (x)
+ z = Mul(w3, w3)
+ }"""
+ )
+ onx = self.call_optimizer(model)
+ self.assertEqual(len(onx.graph.node), 2)
+ self.assertEqual(onx.graph.node[0].op_type, "Split")
+
+ def test_exc(self):
+ self.assertRaise(lambda: GraphBuilder([]), NotImplementedError)
+
+ def test_simple(self):
+ with contextlib.redirect_stdout(io.StringIO()):
+ g = GraphBuilder(verbose=10)
+
+ shape = (10, 4)
+ w = np.random.randn(*shape).astype(np.float32)
+
+ x = g.make_tensor_input("X", np.float32, shape)
+ weight = g.make_initializer(w)
+ one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+ transposed = g.make_node("Transpose", [weight], perm=[1, 0])
+ res = g.op.MatMul(x, transposed)
+ g.op.Reshape(res, one, outputs="y")
+ g.make_tensor_output("y", np.float32, (10, 1))
+ onx = g.to_onnx()
+ ref = ReferenceEvaluator(onx)
+ x = np.random.randn(*shape).astype(np.float32)
+ expected = (x @ w.T).reshape((-1, 1))
+ feeds = {"X": x}
+ got = ref.run(None, feeds)
+ self.assertEqualArray(expected, got[0])
+
+ def test_simple_big(self):
+ with contextlib.redirect_stdout(io.StringIO()):
+ g = GraphBuilder(verbose=10)
+
+ shape = (30, 40)
+ w = np.random.randn(*shape).astype(np.float32)
+
+ x = g.make_tensor_input("X", np.float32, shape)
+ weight = g.make_initializer(w)
+ one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+ transposed = g.make_node("Transpose", [weight], perm=[1, 0])
+ res = g.op.MatMul(x, transposed)
+ g.op.Reshape(res, one, outputs="y")
+ g.make_tensor_output("y", np.float32, (30, 1))
+ onx = g.to_onnx()
+ ref = ReferenceEvaluator(onx)
+ x = np.random.randn(*shape).astype(np.float32)
+ expected = (x @ w.T).reshape((-1, 1))
+ feeds = {"X": x}
+ got = ref.run(None, feeds)
+ self.assertEqualArray(expected, got[0])
+
+ @skipif_ci_apple("libomp is missing")
+ def test_constant_folding(self):
+ with contextlib.redirect_stdout(io.StringIO()):
+ g = GraphBuilder(verbose=10)
+
+ shape = (10, 4)
+ w = np.random.randn(*shape).astype(np.float32)
+ x = g.make_tensor_input("X", np.float32, shape)
+ weight = g.make_initializer(w)
+ one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+ transposed = g.make_node("Transpose", [weight], perm=[1, 0])
+ res = g.op.MatMul(x, transposed)
+ g.op.Reshape(res, one, outputs="y")
+ g.make_tensor_output("y", np.float32, (10, 1))
+
+ g.constant_folding()
+
+ onx = g.to_onnx()
+ node_types = [n.op_type for n in onx.graph.node]
+ self.assertNotIn("Transpose", node_types)
+ ref = ReferenceEvaluator(onx)
+ x = np.random.randn(*shape).astype(np.float32)
+ expected = (x @ w.T).reshape((-1, 1))
+ feeds = {"X": x}
+ got = ref.run(None, feeds)
+ self.assertEqualArray(expected, got[0])
+
+ @skipif_ci_apple("libomp is missing")
+ def test_constant_folding2(self):
+ g = GraphBuilder(
+ optimization_options=OptimizationOptions(constant_folding=True)
+ )
+
+ shape = (10, 4)
+ w = np.random.randn(*shape).astype(np.float32)
+ x = g.make_tensor_input("X", np.float32, shape)
+ weight = g.make_initializer(w)
+ cst = g.get_constant(weight)
+ self.assertEqualArray(w, cst)
+ one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+ transposed = g.make_node("Transpose", [weight], perm=[1, 0])
+ res = g.op.MatMul(x, transposed)
+ g.op.Reshape(res, one, outputs="y")
+ g.make_tensor_output("y", np.float32, (10, 1))
+
+ g.optimize()
+
+ onx = g.to_onnx()
+ node_types = [n.op_type for n in onx.graph.node]
+ self.assertNotIn("Transpose", node_types)
+ ref = ReferenceEvaluator(onx)
+ x = np.random.randn(*shape).astype(np.float32)
+ expected = (x @ w.T).reshape((-1, 1))
+ feeds = {"X": x}
+ got = ref.run(None, feeds)
+ self.assertEqualArray(expected, got[0])
+
+ def test_remove_identity(self):
+ with contextlib.redirect_stdout(io.StringIO()):
+ g = GraphBuilder(verbose=10)
+
+ shape = (10, 4)
+ w = np.random.randn(*shape).astype(np.float32)
+ x = g.make_tensor_input("X", np.float32, shape)
+ weight = g.make_initializer(w)
+ one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+ transposed = g.make_node("Transpose", [weight], perm=[1, 0])
+ res = g.op.Identity(g.op.MatMul(x, transposed))
+ g.op.Reshape(res, one, outputs="y")
+ g.make_tensor_output("y", np.float32, (10, 1))
+
+ g.remove_identity_nodes()
+
+ onx = g.to_onnx()
+ node_types = [n.op_type for n in onx.graph.node]
+ self.assertNotIn("Identity", node_types)
+ ref = ReferenceEvaluator(onx)
+ x = np.random.randn(*shape).astype(np.float32)
+ expected = (x @ w.T).reshape((-1, 1))
+ feeds = {"X": x}
+ got = ref.run(None, feeds)
+ self.assertEqualArray(expected, got[0])
+
+ def test_remove_identity_input(self):
+ with contextlib.redirect_stdout(io.StringIO()):
+ g = GraphBuilder(verbose=10)
+
+ shape = (10, 4)
+ w = np.random.randn(*shape).astype(np.float32)
+ x = g.make_tensor_input("X", np.float32, shape)
+ x = g.op.Identity(x)
+ weight = g.make_initializer(w)
+ one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+ transposed = g.make_node("Transpose", [weight], perm=[1, 0])
+ res = g.op.MatMul(x, transposed)
+ g.op.Reshape(res, one, outputs="y")
+ g.make_tensor_output("y", np.float32, (10, 1))
+
+ g.remove_identity_nodes()
+
+ onx = g.to_onnx()
+ node_types = [n.op_type for n in onx.graph.node]
+ self.assertNotIn("Identity", node_types)
+ ref = ReferenceEvaluator(onx)
+ x = np.random.randn(*shape).astype(np.float32)
+ expected = (x @ w.T).reshape((-1, 1))
+ feeds = {"X": x}
+ got = ref.run(None, feeds)
+ self.assertEqualArray(expected, got[0])
+
+ def test_remove_identity_output(self):
+ with contextlib.redirect_stdout(io.StringIO()):
+ g = GraphBuilder(verbose=10)
+
+ shape = (10, 4)
+ w = np.random.randn(*shape).astype(np.float32)
+ x = g.make_tensor_input("X", np.float32, shape)
+ weight = g.make_initializer(w)
+ one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+ transposed = g.make_node("Transpose", [weight], perm=[1, 0])
+ res = g.op.MatMul(x, transposed)
+ r = g.op.Reshape(res, one)
+ g.op.Identity(r, outputs=["y"])
+ g.make_tensor_output("y", np.float32, (10, 1))
+
+ g.remove_identity_nodes()
+
+ onx = g.to_onnx()
+ node_types = [n.op_type for n in onx.graph.node]
+ self.assertNotIn("Identity", node_types)
+ ref = ReferenceEvaluator(onx)
+ x = np.random.randn(*shape).astype(np.float32)
+ expected = (x @ w.T).reshape((-1, 1))
+ feeds = {"X": x}
+ got = ref.run(None, feeds)
+ self.assertEqualArray(expected, got[0])
+
+ def test_remove_unused_nodes_simple(self):
+ with contextlib.redirect_stdout(io.StringIO()):
+ g = GraphBuilder(verbose=10)
+
+ shape = (10, 4)
+ w = np.random.randn(*shape).astype(np.float32)
+ x = g.make_tensor_input("X", np.float32, shape)
+ weight = g.make_initializer(w)
+ cst = g.make_initializer(np.array([2], dtype=np.float32))
+ one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+ transposed = g.make_node("Transpose", [weight], perm=[1, 0])
+ res = g.op.MatMul(x, transposed)
+ g.op.Add(res, cst)
+ g.op.Reshape(res, one, outputs=["y"])
+ g.make_tensor_output("y", np.float32, (10, 1))
+
+ g.remove_identity_nodes()
+
+ onx = g.to_onnx()
+ node_types = [n.op_type for n in onx.graph.node]
+ self.assertNotIn("Add", node_types)
+ ref = ReferenceEvaluator(onx)
+ x = np.random.randn(*shape).astype(np.float32)
+ expected = (x @ w.T).reshape((-1, 1))
+ feeds = {"X": x}
+ got = ref.run(None, feeds)
+ self.assertEqualArray(expected, got[0])
+
+ @skipif_ci_apple("libomp is missing")
+ def test_constant_array(self):
+ with contextlib.redirect_stdout(io.StringIO()):
+ g = GraphBuilder(verbose=10)
+
+ shape = (10, 4)
+ w = np.random.randn(*shape).astype(np.float32)
+
+ x = g.make_tensor_input("X", np.float32, shape)
+ one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+ res = g.op.MatMul(x, w.T)
+ g.op.Reshape(res, one, outputs="y")
+ g.make_tensor_output("y", np.float32, (10, 1))
+ onx = g.to_onnx()
+ ref = ReferenceEvaluator(onx)
+ x = np.random.randn(*shape).astype(np.float32)
+ expected = (x @ w.T).reshape((-1, 1))
+ feeds = {"X": x}
+ got = ref.run(None, feeds)
+ self.assertEqualArray(expected, got[0])
+
+ @skipif_ci_apple("libomp is missing")
+ def test_constant_array_2(self):
+ with contextlib.redirect_stdout(io.StringIO()):
+ g = GraphBuilder(verbose=10)
+
+ shape = (10, 4)
+ w = np.random.randn(*shape).astype(np.float32)
+
+ x = g.make_tensor_input("X", np.float32, shape)
+ one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+ opc = g.op.Constant(value=from_array_extended(w.T))
+ res = g.op.MatMul(x, opc)
+ g.op.Reshape(res, one, outputs="y")
+ g.make_tensor_output("y", np.float32, (10, 1))
+ self.assertTrue(g.has_shape("X"))
+ self.assertTrue(g.has_type("X"))
+ self.assertEqual(g.get_type("X"), 1)
+ self.assertEqual(g.get_shape("X"), (10, 4))
+ self.assertEqual(g.rank("X"), 2)
+ onx = g.to_onnx()
+ ref = ReferenceEvaluator(onx)
+ x = np.random.randn(*shape).astype(np.float32)
+ expected = (x @ w.T).reshape((-1, 1))
+ feeds = {"X": x}
+ got = ref.run(None, feeds)
+ self.assertEqualArray(expected, got[0])
+
+ def test_get_type(self):
+ g = GraphBuilder()
+ self.assertEqual(g._get_type(np.float32), onnx.TensorProto.FLOAT)
+ self.assertEqual(g._get_type(np.int64), onnx.TensorProto.INT64)
+ self.assertEqual(g._get_type(None), onnx.TensorProto.UNDEFINED)
+
+ def test_make_nodes_prefix(self):
+ g1 = GraphBuilder()
+ g1.make_tensor_input("X", np.float32, shape=None)
+ g1.op.Add("X", np.array([1], dtype=np.float32), outputs=["y"])
+ g1.make_tensor_output("y", np.float32, shape=None)
+
+ g = GraphBuilder()
+
+ shape = (10, 4)
+ w = np.random.randn(*shape).astype(np.float32)
+
+ x = g.make_tensor_input("X", np.float32, shape)
+ weight = g.make_initializer(w)
+ one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+ transposed = g.make_node("Transpose", [weight], perm=[1, 0])
+ res = g.op.MatMul(x, transposed)
+ res2 = g.make_nodes(g1, [res], ["k"], prefix="J")
+ g.op.Reshape(res2, one, outputs="y")
+ g.make_tensor_output("y", np.float32, (10, 1))
+ onx = g.to_onnx()
+ ref = ReferenceEvaluator(onx)
+ x = np.random.randn(*shape).astype(np.float32)
+ expected = (x @ w.T).reshape((-1, 1)) + 1
+ feeds = {"X": x}
+ got = ref.run(None, feeds)
+ self.assertEqualArray(expected, got[0])
+
+ def test_make_nodes_noprefix(self):
+ g1 = GraphBuilder()
+ g1.make_tensor_input("X", np.float32, shape=None)
+ g1.op.Add("X", np.array([1], dtype=np.float32), outputs=["y"])
+ g1.make_tensor_output("y", np.float32, shape=None)
+
+ g = GraphBuilder()
+
+ shape = (10, 4)
+ w = np.random.randn(*shape).astype(np.float32)
+
+ x = g.make_tensor_input("X", np.float32, shape)
+ weight = g.make_initializer(w)
+ one = g.make_initializer(np.array([-1, 1], dtype=np.int64))
+ transposed = g.make_node("Transpose", [weight], perm=[1, 0])
+ res = g.op.MatMul(x, transposed)
+ res2 = g.make_nodes(g1, [res], ["k"])
+ g.op.Reshape(res2, one, outputs="y")
+ g.make_tensor_output("y", np.float32, (10, 1))
+ onx = g.to_onnx()
+ ref = ReferenceEvaluator(onx)
+ x = np.random.randn(*shape).astype(np.float32)
+ expected = (x @ w.T).reshape((-1, 1)) + 1
+ feeds = {"X": x}
+ got = ref.run(None, feeds)
+ self.assertEqualArray(expected, got[0])
+
+ def test_node_pattern(self):
+ model = onnx.parser.parse_model(
+ """
+
+ agraph (float[N] x) => (float[N] z) {
+ two = Constant ()
+ four = Add(two, two)
+ z = Mul(x, four)
+ }"""
+ )
+ gr = GraphBuilder(model)
+ p = gr.np(index=0)
+ r = repr(p)
+ self.assertEqual("NodePattern(index=0, op_type=None, name=None)", r)
+
+ def test_update_node_attribute(self):
+ model = onnx.parser.parse_model(
+ """
+
+ agraph (float[N] x) => (float[N] z) {
+ two = Constant ()
+ four = Add(two, two)
+ z = Mul(x, four)
+ }"""
+ )
+ gr = GraphBuilder(model)
+ self.assertEqual(len(gr.nodes), 3)
+ m = gr.update_attribute(gr.np(op_type="Constant"), value_float=float(1))
+ self.assertEqual(m, 1)
+ self.assertEqual(len(gr.nodes), 3)
+ onx = gr.to_onnx()
+ self.assertEqual(len(onx.graph.node), 3)
+ node = onx.graph.node[0]
+ self.assertIn("f: 1", str(node))
+
+ def test_delete_node_attribute(self):
+ model = onnx.parser.parse_model(
+ """
+
+ agraph (float[N] x) => (float[N] z) {
+ two = Constant ()
+ four = Add(two, two)
+ z = Mul(x, four)
+ }"""
+ )
+ gr = GraphBuilder(model)
+ self.assertEqual(len(gr.nodes), 3)
+ m = gr.update_attribute(
+ gr.np(op_type="Constant"), value_float=gr.DELETE, value_int=1
+ )
+ self.assertEqual(m, 1)
+ self.assertEqual(len(gr.nodes), 3)
+ onx = gr.to_onnx()
+ self.assertEqual(len(onx.graph.node), 3)
+ node = onx.graph.node[0]
+ self.assertNotIn('name: "value_float"', str(node))
+ self.assertIn("i: 1", str(node))
+
+
+if __name__ == "__main__":
+ unittest.main(verbosity=2)
diff --git a/_unittests/ut_graph_api/test_graph_builder_optim.py b/_unittests/ut_graph_api/test_graph_builder_optim.py
new file mode 100644
index 0000000..5ec827d
--- /dev/null
+++ b/_unittests/ut_graph_api/test_graph_builder_optim.py
@@ -0,0 +1,38 @@
+import os
+import unittest
+import onnx
+from onnx.inliner import inline_local_functions
+from onnx_array_api.ext_test_case import ExtTestCase
+from onnx_array_api.graph_api.graph_builder import GraphBuilder
+
+
+class TestGraphBuilderOptim(ExtTestCase):
+ def test_wcheck_afiles(self):
+ import onnxruntime
+
+ data = os.path.join(os.path.dirname(__file__), "data")
+ filename = [f for f in os.listdir(data) if f.endswith(".onnx")]
+ for f in filename:
+ with self.subTest(f=f):
+ onx = onnx.load(os.path.join(data, f))
+ sess = onnxruntime.InferenceSession(
+ os.path.join(data, f), providers=["CPUExecutionProvider"]
+ )
+ assert sess
+ onxi = inline_local_functions(onx)
+ sess = onnxruntime.InferenceSession(
+ onxi.SerializeToString(), providers=["CPUExecutionProvider"]
+ )
+ assert sess
+ g = GraphBuilder(onxi)
+ g.optimize(check_order=True)
+ g.check_order()
+ onx2 = g.to_onnx()
+ sess2 = onnxruntime.InferenceSession(
+ onx2.SerializeToString(), providers=["CPUExecutionProvider"]
+ )
+ assert sess2
+
+
+if __name__ == "__main__":
+ unittest.main(verbosity=2)
diff --git a/_unittests/ut_light_api/test_backend_export.py b/_unittests/ut_light_api/test_backend_export.py
index b0c1cbc..91f4dd4 100644
--- a/_unittests/ut_light_api/test_backend_export.py
+++ b/_unittests/ut_light_api/test_backend_export.py
@@ -1,9 +1,11 @@
+import sys
import unittest
from typing import Any, Dict, List, Optional
from difflib import unified_diff
import packaging.version as pv
import numpy
from numpy.testing import assert_allclose
+from onnx.defs import onnx_opset_version
import onnx.backend.base
import onnx.backend.test
import onnx.shape_inference
@@ -17,16 +19,19 @@
make_opsetid,
make_tensor_value_info,
)
+from onnx.reference.op_run import to_array_extended
from onnx.numpy_helper import from_array, to_array
from onnx.backend.base import Device, DeviceType
from onnx_array_api.reference import ExtendedReferenceEvaluator
-from onnx_array_api.light_api import translate
+from onnx_array_api.translate_api.make_helper import make_node_extended
+from onnx_array_api.translate_api import translate
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+verbosity = 10 if "-v" in sys.argv or "--verbose" in sys.argv else 0
+
class ReferenceImplementationError(RuntimeError):
"Fails, export cannot be compared."
- pass
class ExportWrapper:
@@ -34,7 +39,7 @@ class ExportWrapper:
def __init__(self, model):
self.model = model
- self.expected_sess = ExtendedReferenceEvaluator(self.model)
+ self.expected_sess = ExtendedReferenceEvaluator(self.model, verbose=verbosity)
@property
def input_names(self):
@@ -59,7 +64,8 @@ def run(
expected = self.expected_sess.run(names, feeds)
except (RuntimeError, AssertionError, TypeError, KeyError) as e:
raise ReferenceImplementationError(
- f"ReferenceImplementation fails with {onnx_simple_text_plot(self.model)}"
+ f"ReferenceImplementation fails with "
+ f"{onnx_simple_text_plot(self.model)}"
f"\n--RAW--\n{self.model}"
) from e
@@ -80,11 +86,12 @@ def run(
new_code = "\n".join(
[f"{i+1:04} {line}" for i, line in enumerate(code.split("\n"))]
)
- raise AssertionError(f"ERROR {e}\n{new_code}")
+ raise AssertionError(f"ERROR {e}\n{new_code}") # noqa: B904
locs = {
"np": numpy,
"to_array": to_array,
+ "to_array_extended": to_array_extended,
"from_array": from_array,
"TensorProto": TensorProto,
"make_function": make_function,
@@ -92,6 +99,7 @@ def run(
"make_model": make_model,
"make_graph": make_graph,
"make_node": make_node,
+ "make_node_extended": make_node_extended,
"make_tensor_value_info": make_tensor_value_info,
}
globs = locs.copy()
@@ -105,7 +113,7 @@ def run(
f"Unable to executed code for api {api!r}\n{new_code}"
) from e
export_model = locs["model"]
- ref = ExtendedReferenceEvaluator(export_model)
+ ref = ExtendedReferenceEvaluator(export_model, verbose=verbosity)
try:
got = ref.run(names, feeds)
except (TypeError, AttributeError) as e:
@@ -147,7 +155,8 @@ def run(
):
if a.tolist() != b.tolist():
raise AssertionError(
- f"Text discrepancies for api {api!r} with a.dtype={a.dtype} "
+ f"Text discrepancies for api {api!r} "
+ f"with a.dtype={a.dtype} "
f"and b.dtype={b.dtype}"
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
@@ -235,7 +244,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
# The following tests are too slow with the reference implementation (Conv).
backend_test.exclude(
- "(FLOAT8|BFLOAT16|_opt_|_3d_|_momentum_|_4d_"
+ "(FLOAT8|BFLOAT16|INT4|_opt_|_3d_|_momentum_|_4d_|int4"
"|test_adagrad"
"|test_adam"
"|test_ai_onnx_ml_"
@@ -263,9 +272,27 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
"|test_squeezenet"
"|test_vgg19"
"|test_zfnet512"
+ "|test_range_float_type_positive_delta_expanded"
+ "|test_range_int32_type_negative_delta_expanded"
")"
)
+if onnx_opset_version() < 22:
+ backend_test.exclude(
+ "("
+ "test_dft_inverse_cpu"
+ "|test_dft_inverse_opset19_cpu"
+ "|test_lppool_1d_default_cpu"
+ "|test_lppool_2d_default_cpu"
+ "|test_lppool_2d_dilations_cpu"
+ "|test_lppool_2d_pads_cpu"
+ "|test_lppool_2d_same_lower_cpu"
+ "|test_lppool_2d_same_upper_cpu"
+ "|test_lppool_2d_strides_cpu"
+ "|test_lppool_3d_default_cpu"
+ ")"
+ )
+
if pv.Version(onnx_version) < pv.Version("1.16.0"):
backend_test.exclude("(test_strnorm|test_range_)")
diff --git a/_unittests/ut_light_api/test_light_api.py b/_unittests/ut_light_api/test_light_api.py
index f6ae051..f936cc1 100644
--- a/_unittests/ut_light_api/test_light_api.py
+++ b/_unittests/ut_light_api/test_light_api.py
@@ -2,7 +2,7 @@
import unittest
from typing import Callable, Optional
import numpy as np
-from onnx import GraphProto, ModelProto
+from onnx import GraphProto, ModelProto, TensorProto
from onnx.defs import (
get_all_schemas_with_history,
onnx_opset_version,
@@ -211,7 +211,7 @@ def test_neg(self):
self.assertIsInstance(v, Var)
self.assertEqual(["X"], v.parent.input_names)
s = str(v)
- self.assertEqual("X:FLOAT", s)
+ self.assertEqual("X:FLOAT:[]", s)
onx = start().vin("X").Neg().rename("Y").vout().to_onnx()
self.assertIsInstance(onx, ModelProto)
ref = ReferenceEvaluator(onx)
@@ -484,7 +484,7 @@ def g(self):
def ah(self):
return True
- setattr(A, "h", ah)
+ setattr(A, "h", ah) # noqa: B010
self.assertTrue(A().h())
self.assertIn("(self)", str(inspect.signature(A.h)))
@@ -510,7 +510,63 @@ def ah(self):
expected = (a > 0).astype(int).astype(np.float32).reshape((-1, 1))
self.assertEqualArray(expected, got)
+ def test_input_shape(self):
+ kernel = (np.arange(9) + 1).reshape(3, 3).astype(np.float32)
+ model = (
+ start()
+ .vin("X", shape=[None, None])
+ .cst(kernel[np.newaxis, np.newaxis, ...])
+ .rename("W")
+ .bring("X", "W")
+ .Conv(pads=[1, 1, 1, 1])
+ .rename("Y")
+ .vout(shape=[])
+ .to_onnx()
+ )
+ i = str(model.graph.input[0]).replace("\n", "").replace(" ", "")
+ self.assertNotIn("shape{}", i)
+
+ def test_constant_of_shape(self):
+ onx = (
+ start()
+ .vin("X", TensorProto.INT64, shape=[None, None])
+ .ConstantOfShape()
+ .vout(shape=[])
+ .to_onnx()
+ )
+ ref = ReferenceEvaluator(onx)
+ got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
+ self.assertEqualArray(np.zeros((2, 3), dtype=np.float32), got)
+
+ def test_constant_of_shape_value(self):
+ onx = (
+ start()
+ .vin("X", TensorProto.INT64, shape=[None, None])
+ .ConstantOfShape(value=np.array([1], dtype=np.float32))
+ .vout(shape=[])
+ .to_onnx()
+ )
+ ref = ReferenceEvaluator(onx)
+ got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
+ self.assertEqualArray(np.ones((2, 3), dtype=np.float32), got)
+
+ def test_slice(self):
+ onx = (
+ start(opset=18, ir_version=9)
+ .cst(np.array([1], dtype=np.int64), name="one")
+ .cst(np.array([2], dtype=np.int64), name="two")
+ .vin("X", TensorProto.INT64, shape=[None, None])
+ .ConstantOfShape(value=np.array([1], dtype=np.float32))
+ .rename("CX")
+ .bring("CX", "one", "two", "one")
+ .Slice()
+ .vout(shape=[])
+ .to_onnx()
+ )
+ ref = ReferenceEvaluator(onx)
+ got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
+ self.assertEqualArray(np.ones((2, 1), dtype=np.float32), got)
+
if __name__ == "__main__":
- TestLightApi().test_domain()
unittest.main(verbosity=2)
diff --git a/_unittests/ut_npx/test_npx.py b/_unittests/ut_npx/test_npx.py
index 50e319a..873665d 100644
--- a/_unittests/ut_npx/test_npx.py
+++ b/_unittests/ut_npx/test_npx.py
@@ -208,7 +208,7 @@ def local1(
return x
def local2(
- x: TensorType[ElemType.floats, "T"]
+ x: TensorType[ElemType.floats, "T"],
) -> TensorType[ElemType.floats, "T"]:
return x
diff --git a/_unittests/ut_npx/test_sklearn_array_api.py b/_unittests/ut_npx/test_sklearn_array_api.py
index 083c009..9c0d56f 100644
--- a/_unittests/ut_npx/test_sklearn_array_api.py
+++ b/_unittests/ut_npx/test_sklearn_array_api.py
@@ -17,6 +17,7 @@ class TestSklearnArrayAPI(ExtTestCase):
reason="reshape ArrayAPI not followed",
)
@ignore_warnings(DeprecationWarning)
+ @unittest.skip("not maintained")
def test_sklearn_array_api_linear_discriminant(self):
X = np.array(
[[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], dtype=np.float64
@@ -39,6 +40,7 @@ def test_sklearn_array_api_linear_discriminant(self):
reason="reshape ArrayAPI not followed",
)
@ignore_warnings(DeprecationWarning)
+ @unittest.skip("not maintained")
def test_sklearn_array_api_linear_discriminant_float32(self):
X = np.array(
[[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], dtype=np.float32
diff --git a/_unittests/ut_ort/test_ort_profile.py b/_unittests/ut_ort/test_ort_profile.py
index e868860..6e139cb 100644
--- a/_unittests/ut_ort/test_ort_profile.py
+++ b/_unittests/ut_ort/test_ort_profile.py
@@ -57,8 +57,6 @@ def myloss(x, y):
prof = ort_profile(optimized, feeds)
events = {
"kernel_time",
- "fence_before",
- "fence_after",
"SequentialExecutor::Execute",
"model_run",
"model_loading_array",
diff --git a/_unittests/ut_ort/test_sklearn_array_api_ort.py b/_unittests/ut_ort/test_sklearn_array_api_ort.py
index 296a9b0..f50fce1 100644
--- a/_unittests/ut_ort/test_sklearn_array_api_ort.py
+++ b/_unittests/ut_ort/test_sklearn_array_api_ort.py
@@ -17,6 +17,7 @@ class TestSklearnArrayAPIOrt(ExtTestCase):
reason="reshape ArrayAPI not followed",
)
@skipif_ci_windows("Unstable on Windows.")
+ @unittest.skip("discontinued")
def test_sklearn_array_api_linear_discriminant_ort(self):
X = np.array(
[[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], dtype=np.float64
@@ -40,6 +41,7 @@ def test_sklearn_array_api_linear_discriminant_ort(self):
reason="reshape ArrayAPI not followed",
)
@skipif_ci_windows("Unstable on Windows.")
+ @unittest.skip("discontinued")
def test_sklearn_array_api_linear_discriminant_ort_float32(self):
X = np.array(
[[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], dtype=np.float32
diff --git a/_unittests/ut_plotting/test_dot_plot.py b/_unittests/ut_plotting/test_dot_plot.py
index 5c03746..4c8c4dd 100644
--- a/_unittests/ut_plotting/test_dot_plot.py
+++ b/_unittests/ut_plotting/test_dot_plot.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
import os
import unittest
diff --git a/_unittests/ut_plotting/test_graphviz.py b/_unittests/ut_plotting/test_graphviz.py
new file mode 100644
index 0000000..420779e
--- /dev/null
+++ b/_unittests/ut_plotting/test_graphviz.py
@@ -0,0 +1,51 @@
+import os
+import unittest
+import onnx.parser
+from onnx_array_api.ext_test_case import (
+ ExtTestCase,
+ skipif_ci_windows,
+ skipif_ci_apple,
+)
+from onnx_array_api.plotting.dot_plot import to_dot
+from onnx_array_api.plotting.graphviz_helper import draw_graph_graphviz, plot_dot
+
+
+class TestGraphviz(ExtTestCase):
+ @classmethod
+ def _get_graph(cls):
+ return onnx.parser.parse_model(
+ """
+
+ agraph (float[N] x) => (float[N] z) {
+ two = Constant ()
+ four = Add(two, two)
+ z = Mul(x, x)
+ }"""
+ )
+
+ @skipif_ci_windows("graphviz not installed")
+ @skipif_ci_apple("graphviz not installed")
+ def test_draw_graph_graphviz(self):
+ fout = "test_draw_graph_graphviz.png"
+ dot = to_dot(self._get_graph())
+ draw_graph_graphviz(dot, image=fout)
+ self.assertExists(os.path.exists(fout))
+
+ @skipif_ci_windows("graphviz not installed")
+ @skipif_ci_apple("graphviz not installed")
+ def test_draw_graph_graphviz_proto(self):
+ fout = "test_draw_graph_graphviz_proto.png"
+ dot = self._get_graph()
+ draw_graph_graphviz(dot, image=fout)
+ self.assertExists(os.path.exists(fout))
+
+ @skipif_ci_windows("graphviz not installed")
+ @skipif_ci_apple("graphviz not installed")
+ def test_plot_dot(self):
+ dot = to_dot(self._get_graph())
+ ax = plot_dot(dot)
+ ax.get_figure().savefig("test_plot_dot.png")
+
+
+if __name__ == "__main__":
+ unittest.main(verbosity=2)
diff --git a/_unittests/ut_plotting/test_text_plot.py b/_unittests/ut_plotting/test_text_plot.py
index 963b5cb..5844ff0 100644
--- a/_unittests/ut_plotting/test_text_plot.py
+++ b/_unittests/ut_plotting/test_text_plot.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
import os
import textwrap
import unittest
@@ -95,6 +94,7 @@ def test_onnx_text_plot_tree_cls_2(self):
+f 0:1 1:0 2:0
"""
).strip(" \n\r")
+ res = res.replace("np.float32(", "").replace(")", "")
self.assertEqual(expected, res.strip(" \n\r"))
@ignore_warnings((UserWarning, FutureWarning))
diff --git a/_unittests/ut_reference/test_backend_extended_reference_evaluator.py b/_unittests/ut_reference/test_backend_extended_reference_evaluator.py
index b35fb3c..fbf12b7 100644
--- a/_unittests/ut_reference/test_backend_extended_reference_evaluator.py
+++ b/_unittests/ut_reference/test_backend_extended_reference_evaluator.py
@@ -149,7 +149,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
"|test_scan_sum)"
)
-if onnx_opset_version() < 21:
+if onnx_opset_version() < 200:
# The following tests are using types not supported by NumPy.
# They could be if method to_array is extended to support custom
# types the same as the reference implementation does
@@ -164,8 +164,10 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
"|test_cast_no_saturate_"
"|_to_FLOAT8"
"|_FLOAT8"
+ "|INT4"
"|test_quantizelinear_e4m3fn"
"|test_quantizelinear_e5m2"
+ "|test_scatter_with"
")"
)
@@ -215,6 +217,25 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
# The following tests fail due to a type mismatch.
backend_test.exclude("(test_eyelike_without_dtype)")
+if onnx_opset_version() < 22:
+ backend_test.exclude(
+ "("
+ "test_adagrad_cpu"
+ "|test_adagrad_multiple_cpu"
+ "|test_dft_inverse_cpu"
+ "|test_dft_inverse_opset19_cpu"
+ "|test_lppool_1d_default_cpu"
+ "|test_lppool_2d_default_cpu"
+ "|test_lppool_2d_dilations_cpu"
+ "|test_lppool_2d_pads_cpu"
+ "|test_lppool_2d_same_lower_cpu"
+ "|test_lppool_2d_same_upper_cpu"
+ "|test_lppool_2d_strides_cpu"
+ "|test_lppool_3d_default_cpu"
+ ")"
+ )
+
+
# The following tests fail due to discrepancies (small but still higher than 1e-7).
backend_test.exclude("test_adam_multiple") # 1e-2
diff --git a/_unittests/ut_reference/test_evaluator_yield.py b/_unittests/ut_reference/test_evaluator_yield.py
new file mode 100644
index 0000000..605c1f8
--- /dev/null
+++ b/_unittests/ut_reference/test_evaluator_yield.py
@@ -0,0 +1,554 @@
+import unittest
+import numpy as np
+from onnx import TensorProto
+from onnx.checker import check_model
+from onnx.helper import (
+ make_function,
+ make_graph,
+ make_model,
+ make_node,
+ make_opsetid,
+ make_tensor_value_info,
+)
+from onnx.numpy_helper import from_array
+from onnx.parser import parse_model
+from onnx_array_api.ext_test_case import ExtTestCase
+from onnx_array_api.reference import (
+ YieldEvaluator,
+ ResultType,
+ DistanceExecution,
+ ResultExecution,
+ compare_onnx_execution,
+)
+from onnx_array_api.reference.evaluator_yield import make_summary
+
+
+class TestArrayTensor(ExtTestCase):
+ def test_make_summary(self):
+ a = np.arange(12).reshape(3, 4)
+ v = make_summary(a)
+ self.assertEqual(v, "DMVE")
+ a = np.arange(12)
+ v = make_summary(a)
+ self.assertEqual(v, "DMVE")
+ a = np.arange(12).astype(np.float32)
+ v = make_summary(a)
+ self.assertEqual(v, "DMVE")
+ a = np.arange(13)
+ a[-1] = 0
+ v = make_summary(a)
+ self.assertEqual(v, "GWMA")
+
+ def test_evaluator_yield(self):
+ new_domain = "custom_domain"
+ opset_imports = [make_opsetid("", 14), make_opsetid(new_domain, 1)]
+
+ node1 = make_node("MatMul", ["X", "A"], ["XA"])
+ node2 = make_node("Add", ["XA", "B"], ["Y"])
+
+ linear_regression = make_function(
+ new_domain,
+ "LinearRegression",
+ ["X", "A", "B"],
+ ["Y"],
+ [node1, node2],
+ opset_imports,
+ [],
+ )
+
+ X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
+ A = make_tensor_value_info("A", TensorProto.FLOAT, [None, None])
+ B = make_tensor_value_info("B", TensorProto.FLOAT, [None, None])
+ Y = make_tensor_value_info("Y", TensorProto.FLOAT, None)
+
+ graph = make_graph(
+ [
+ make_node(
+ "LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain
+ ),
+ make_node("Abs", ["Y1"], ["Y"]),
+ ],
+ "example",
+ [X, A, B],
+ [Y],
+ )
+
+ onnx_model = make_model(
+ graph, opset_imports=opset_imports, functions=[linear_regression]
+ )
+
+ cst = np.arange(4).reshape((-1, 2)).astype(np.float32)
+ yield_eval = YieldEvaluator(onnx_model)
+ results = list(
+ yield_eval.enumerate_results(None, {"A": cst, "B": cst, "X": cst})
+ )
+ expected = [
+ (
+ ResultType.INPUT,
+ "A",
+ np.array([[0.0, 1.0], [2.0, 3.0]], dtype=np.float32),
+ None,
+ ),
+ (
+ ResultType.INPUT,
+ "B",
+ np.array([[0.0, 1.0], [2.0, 3.0]], dtype=np.float32),
+ None,
+ ),
+ (
+ ResultType.INPUT,
+ "X",
+ np.array([[0.0, 1.0], [2.0, 3.0]], dtype=np.float32),
+ None,
+ ),
+ (
+ ResultType.RESULT,
+ "Y1",
+ np.array([[2.0, 4.0], [8.0, 14.0]], dtype=np.float32),
+ "LinearRegression",
+ ),
+ (
+ ResultType.RESULT,
+ "Y",
+ np.array([[2.0, 4.0], [8.0, 14.0]], dtype=np.float32),
+ "Abs",
+ ),
+ (
+ ResultType.OUTPUT,
+ "Y",
+ np.array([[2.0, 4.0], [8.0, 14.0]], dtype=np.float32),
+ None,
+ ),
+ ]
+ self.assertEqual(len(expected), len(results))
+ for a, b in zip(expected, results):
+ self.assertEqual(len(a), len(b))
+ self.assertEqual(a[0], b[0])
+ self.assertEqual(a[1], b[1])
+ self.assertEqual(a[2].tolist(), b[2].tolist())
+ self.assertEqual(a[3], b[3])
+
+ def test_evaluator_yield_summary(self):
+ new_domain = "custom_domain"
+ opset_imports = [make_opsetid("", 14), make_opsetid(new_domain, 1)]
+
+ node1 = make_node("MatMul", ["X", "A"], ["XA"])
+ node2 = make_node("Add", ["XA", "B"], ["Y"])
+
+ linear_regression = make_function(
+ new_domain,
+ "LinearRegression",
+ ["X", "A", "B"],
+ ["Y"],
+ [node1, node2],
+ opset_imports,
+ [],
+ )
+
+ X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
+ A = make_tensor_value_info("A", TensorProto.FLOAT, [None, None])
+ B = make_tensor_value_info("B", TensorProto.FLOAT, [None, None])
+ Y = make_tensor_value_info("Y", TensorProto.FLOAT, None)
+
+ graph = make_graph(
+ [
+ make_node(
+ "LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain
+ ),
+ make_node("Abs", ["Y1"], ["Y"]),
+ ],
+ "example",
+ [X, A, B],
+ [Y],
+ )
+
+ onnx_model = make_model(
+ graph, opset_imports=opset_imports, functions=[linear_regression]
+ )
+
+ cst = np.arange(4).reshape((-1, 2)).astype(np.float32)
+ yield_eval = YieldEvaluator(onnx_model)
+ results = list(
+ yield_eval.enumerate_summarized(None, {"A": cst, "B": cst, "X": cst})
+ )
+ expected = [
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"),
+ (
+ ResultType.RESULT,
+ np.dtype("float32"),
+ (2, 2),
+ "CEIO",
+ "LinearRegression",
+ "Y1",
+ ),
+ (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs", "Y"),
+ (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"),
+ ]
+ self.assertEqual(len(expected), len(results))
+ for a, b in zip(expected, results):
+ self.assertEqual(len(a), len(b))
+ self.assertEqual(a[0], b[0])
+ self.assertEqual(a[1], b[1])
+ self.assertEqual(a[2], b[2])
+ self.assertEqual(a[3], b[3])
+ self.assertEqual(a[4], b[4])
+ self.assertEqual(a[5], b[5])
+
+ def test_distance_pair(self):
+ el1 = (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None)
+ el2 = el1
+ dc = DistanceExecution()
+ self.assertEqual(dc.distance_pair(el1, el2), 0)
+ el2 = (ResultType.INPUT, np.dtype("float16"), (2, 2), "ABCD", None)
+ self.assertEqual(dc.distance_pair(el1, el2), 2)
+ el2 = (ResultType.OUTPUT, np.dtype("float16"), (2, 2, 4), "GBCD", "Abs")
+ self.assertEqual(dc.distance_pair(el1, el2), 1130)
+ el2 = (ResultType.OUTPUT, np.dtype("float16"), (2, 3), "GBCD", "Abs")
+ self.assertEqual(dc.distance_pair(el1, el2), 1021)
+
+ def test_distance_sequence_0(self):
+ expected = [
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"),
+ (
+ ResultType.RESULT,
+ np.dtype("float32"),
+ (2, 2),
+ "CEIO",
+ "LinearRegression",
+ "Y1",
+ ),
+ (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs", "Y"),
+ (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"),
+ ]
+
+ dc = DistanceExecution()
+ d, align = dc.distance_sequence(expected, expected)
+ self.assertEqual(d, 0)
+ self.assertEqual(align, [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)])
+
+ def test_distance_sequence_ins(self):
+ s1 = [
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"),
+ (
+ ResultType.RESULT,
+ np.dtype("float32"),
+ (2, 2),
+ "CEIO",
+ "LinearRegression",
+ "Y1",
+ ),
+ (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs", "Y"),
+ (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"),
+ ]
+ s2 = [
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"),
+ (
+ ResultType.RESULT,
+ np.dtype("float32"),
+ (2, 2),
+ "CEIO",
+ "LinearRegression",
+ "Y1",
+ ),
+ (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"),
+ ]
+
+ dc = DistanceExecution()
+ d, align = dc.distance_sequence(s1, s2)
+ self.assertEqual(d, dc.insert_cost)
+ self.assertEqual(align, [(0, 0), (1, 1), (2, 2), (3, 3), (4, 3), (5, 4)])
+ d, align = dc.distance_sequence(s2, s1)
+ self.assertEqual(d, dc.insert_cost)
+ self.assertEqual(align, [(0, 0), (1, 1), (2, 2), (3, 3), (3, 4), (4, 5)])
+
+ def test_distance_sequence_equal(self):
+ s1 = [
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"),
+ (
+ ResultType.RESULT,
+ np.dtype("float32"),
+ (2, 2),
+ "CEIO",
+ "LinearRegression",
+ "Y1",
+ ),
+ (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs", "Y"),
+ (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"),
+ ]
+ s2 = [
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"),
+ (
+ ResultType.RESULT,
+ np.dtype("float32"),
+ (2, 2),
+ "CEIO",
+ "LinearRegression",
+ "Y1",
+ ),
+ (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs", "Z"),
+ (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"),
+ ]
+
+ dc = DistanceExecution()
+ d, align = dc.distance_sequence(s1, s2)
+ self.assertEqual(d, 0)
+ self.assertEqual(align, [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)])
+
+ def test_distance_sequence_diff(self):
+ s1 = [
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"),
+ (
+ ResultType.RESULT,
+ np.dtype("float32"),
+ (2, 2),
+ "CEIO",
+ "LinearRegression",
+ "Y1",
+ ),
+ (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs", "Y"),
+ (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"),
+ ]
+ s2 = [
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"),
+ (
+ ResultType.RESULT,
+ np.dtype("float32"),
+ (2, 2),
+ "CEIO",
+ "LinearRegression",
+ "Y1",
+ ),
+ (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIP", "Abs", "Z"),
+ (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"),
+ ]
+
+ dc = DistanceExecution()
+ d, align = dc.distance_sequence(s1, s2)
+ self.assertEqual(d, 1)
+ self.assertEqual(align, [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)])
+
+ def test_distance_sequence_diff2(self):
+ s1 = [
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"),
+ (
+ ResultType.RESULT,
+ np.dtype("float32"),
+ (2, 2),
+ "CEIO",
+ "LinearRegression",
+ "Y1",
+ ),
+ (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs", "Y"),
+ (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"),
+ ]
+ s2 = [
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"),
+ (
+ ResultType.RESULT,
+ np.dtype("float32"),
+ (2, 2),
+ "CEIO",
+ "LinearRegression",
+ "Y1",
+ ),
+ (ResultType.RESULT, np.dtype("float32"), (2, 3), "CEIP", "Abs", "Z"),
+ (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIP", None, "Y"),
+ ]
+
+ dc = DistanceExecution()
+ d, align = dc.distance_sequence(s1, s2)
+ self.assertEqual(d, 5)
+ self.assertEqual(align, [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)])
+
+ def test_distance_sequence_str(self):
+ s1 = [
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 3), "ABCD", None, "X"),
+ (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Exp", "H"),
+ (
+ ResultType.RESULT,
+ np.dtype("float32"),
+ (2, 2),
+ "CEIO",
+ "LinearRegression",
+ "Y1",
+ ),
+ (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs", "Y"),
+ (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"),
+ ]
+ s2 = [
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"),
+ (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"),
+ (
+ ResultType.RESULT,
+ np.dtype("float32"),
+ (2, 2),
+ "CEIO",
+ "LinearRegression",
+ "Y1",
+ ),
+ (ResultType.RESULT, np.dtype("float32"), (2, 3), "CEIP", "Abs", "Z"),
+ (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIP", None, "Y"),
+ ]
+ s1 = [ResultExecution(*s) for s in s1]
+ s2 = [ResultExecution(*s) for s in s2]
+
+ dc = DistanceExecution()
+ d, align = dc.distance_sequence(s1, s2)
+ self.assertEqual(d, 1008)
+ self.assertEqual(
+ align, [(0, 0), (1, 1), (2, 2), (3, 2), (4, 3), (5, 4), (6, 5)]
+ )
+ text = dc.to_str(s1, s2, align)
+ self.assertIn("OUTPUT", text)
+ expected = """
+ 001=|INPUTfloat322:2x2ABCDA|INPUTfloat322:2x2ABCDA
+ 002=|INPUTfloat322:2x2ABCDB|INPUTfloat322:2x2ABCDB
+ 003~|INPUTfloat322:2x3ABCDX|INPUTfloat322:2x2ABCDX
+ 004-|RESULTfloat322:2x2CEIOExpH|
+ 005=|RESULTfloat322:2x2CEIOLinearRegressioY1|RESULTfloat322:2x2CEIOLinearRegressioY1
+ 006~|RESULTfloat322:2x2CEIOAbsY|RESULTfloat322:2x3CEIPAbsZ
+ 007~|OUTPUTfloat322:2x2CEIOY|OUTPUTfloat322:2x2CEIPY
+ """.replace(
+ " ", ""
+ ).strip(
+ "\n "
+ )
+ self.maxDiff = None
+ self.assertEqual(expected, text.replace(" ", "").strip("\n"))
+
+ def test_compare_execution(self):
+ m1 = parse_model(
+ """
+
+ agraph (float[N] x) => (float[N] z) {
+ two = Constant ()
+ four = Add(two, two)
+ z = Mul(x, x)
+ }"""
+ )
+ m2 = parse_model(
+ """
+
+ agraph (float[N] x) => (float[N] z) {
+ two = Constant ()
+ z = Mul(x, x)
+ }"""
+ )
+ res1, res2, align, dc = compare_onnx_execution(m1, m2)
+ text = dc.to_str(res1, res2, align)
+ self.assertIn("CAAA Constant", text)
+ self.assertEqual(len(align), 5)
+
+ def test_compare_execution_discrepancies(self):
+ m1 = parse_model(
+ """
+
+ agraph (float[N] x) => (float[N] z) {
+ two = Constant ()
+ four = Add(two, two)
+ z = Mul(x, x)
+ }"""
+ )
+ m2 = parse_model(
+ """
+
+ agraph (float[N] x) => (float[N] z) {
+ two = Constant ()
+ z = Mul(x, x)
+ }"""
+ )
+ res1, res2, align, dc = compare_onnx_execution(m1, m2, keep_tensor=True)
+ text = dc.to_str(res1, res2, align)
+ print(text)
+ self.assertIn("CAAA Constant", text)
+ self.assertIn("| a=", text)
+ self.assertIn(" r=", text)
+
+ def test_no_execution(self):
+ model = make_model(
+ make_graph(
+ [
+ make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
+ make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
+ make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
+ make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
+ make_node("Cast", ["xm2c"], ["xm2"], to=1),
+ make_node("MatMul", ["xm1", "xm2"], ["xm"]),
+ make_node("Reshape", ["xm", "shape3"], ["Z"]),
+ ],
+ "dummy",
+ [
+ make_tensor_value_info("X", TensorProto.FLOAT, [32, 128]),
+ make_tensor_value_info("Y", TensorProto.FLOAT, [3, 5, 128, 64]),
+ ],
+ [make_tensor_value_info("Z", TensorProto.FLOAT, [3, 5, 32, "N"])],
+ [
+ from_array(np.array([0], dtype=np.int64), name="zero"),
+ from_array(np.array([1], dtype=np.int64), name="un"),
+ from_array(np.array([1, 32, 128], dtype=np.int64), name="shape1"),
+ from_array(np.array([15, 128, 64], dtype=np.int64), name="shape2"),
+ from_array(np.array([3, 5, 32, 64], dtype=np.int64), name="shape3"),
+ ],
+ )
+ )
+ check_model(model)
+ res1, res2, align, dc = compare_onnx_execution(model, model, mode="nodes")
+ text = dc.to_str(res1, res2, align)
+ self.assertIn("012 = | NODE", text)
+
+ model2 = make_model(
+ make_graph(
+ [
+ make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
+ make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
+ make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
+ make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
+ make_node("MatMul", ["xm1", "xm2c"], ["xm"]),
+ make_node("Reshape", ["xm", "shape3"], ["Z"]),
+ ],
+ "dummy",
+ [
+ make_tensor_value_info("X", TensorProto.FLOAT, [32, 128]),
+ make_tensor_value_info("Y", TensorProto.FLOAT, [3, 5, 128, 64]),
+ ],
+ [make_tensor_value_info("Z", TensorProto.FLOAT, [3, 5, 32, "N"])],
+ [
+ from_array(np.array([0], dtype=np.int64), name="zero"),
+ from_array(np.array([1], dtype=np.int64), name="un"),
+ from_array(np.array([1, 32, 128], dtype=np.int64), name="shape1"),
+ from_array(np.array([15, 128, 64], dtype=np.int64), name="shape2"),
+ from_array(np.array([3, 5, 32, 64], dtype=np.int64), name="shape3"),
+ ],
+ )
+ )
+ check_model(model2)
+ res1, res2, align, dc = compare_onnx_execution(model, model2, mode="nodes")
+ text = dc.to_str(res1, res2, align)
+ self.assertIn("012 = | NODE", text)
+
+
+if __name__ == "__main__":
+ unittest.main(verbosity=2)
diff --git a/_unittests/ut_reference/test_reference_ops.py b/_unittests/ut_reference/test_reference_ops.py
new file mode 100644
index 0000000..9ae6fec
--- /dev/null
+++ b/_unittests/ut_reference/test_reference_ops.py
@@ -0,0 +1,146 @@
+import unittest
+import numpy as np
+from onnx import TensorProto
+from onnx.helper import (
+ make_graph,
+ make_model,
+ make_node,
+ make_tensor_value_info,
+ make_opsetid,
+)
+from onnx_array_api.ext_test_case import ExtTestCase
+from onnx_array_api.reference import ExtendedReferenceEvaluator
+
+
+class TestReferenceOps(ExtTestCase):
+
+ def test_fused_matmul(self):
+ model = make_model(
+ make_graph(
+ [make_node("FusedMatMul", ["X", "Y"], ["Z"], domain="com.microsoft")],
+ "name",
+ [
+ make_tensor_value_info("X", TensorProto.FLOAT, None),
+ make_tensor_value_info("Y", TensorProto.FLOAT, None),
+ ],
+ [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
+ ),
+ opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
+ )
+ ref = ExtendedReferenceEvaluator(model)
+ a = np.arange(4).reshape(-1, 2)
+ got = ref.run(None, {"X": a, "Y": a})
+ self.assertEqualArray(a @ a, got[0])
+
+ def test_fused_matmul11(self):
+ model = make_model(
+ make_graph(
+ [
+ make_node(
+ "FusedMatMul",
+ ["X", "Y"],
+ ["Z"],
+ transA=1,
+ transB=1,
+ domain="com.microsoft",
+ )
+ ],
+ "name",
+ [
+ make_tensor_value_info("X", TensorProto.FLOAT, None),
+ make_tensor_value_info("Y", TensorProto.FLOAT, None),
+ ],
+ [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
+ ),
+ opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
+ )
+ ref = ExtendedReferenceEvaluator(model)
+ a = np.arange(4).reshape(-1, 2)
+ got = ref.run(None, {"X": a, "Y": a})
+ self.assertEqualArray(a.T @ a.T, got[0])
+
+ def test_memcpy(self):
+ model = make_model(
+ make_graph(
+ [
+ make_node("MemcpyToHost", ["X"], ["Z"]),
+ make_node("MemcpyFromHost", ["X"], ["Z"]),
+ ],
+ "name",
+ [make_tensor_value_info("X", TensorProto.FLOAT, None)],
+ [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
+ ),
+ opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
+ ir_version=9,
+ )
+ a = np.arange(4).reshape(-1, 2).astype(np.float32)
+ ref = ExtendedReferenceEvaluator(model)
+ got = ref.run(None, {"X": a})
+ self.assertEqualArray(a, got[0])
+
+ def test_quick_gelu(self):
+ from onnxruntime import InferenceSession
+
+ for alpha in [0.0, 2.0]:
+ model = make_model(
+ make_graph(
+ [
+ make_node(
+ "QuickGelu",
+ ["X"],
+ ["Z"],
+ domain="com.microsoft",
+ alpha=alpha,
+ )
+ ],
+ "name",
+ [make_tensor_value_info("X", TensorProto.FLOAT, None)],
+ [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
+ ),
+ opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
+ ir_version=9,
+ )
+ sess = InferenceSession(
+ model.SerializeToString(), providers=["CPUExecutionProvider"]
+ )
+ a = np.arange(4).reshape(-1, 2).astype(np.float32)
+ expected = sess.run(None, {"X": a})
+ ref = ExtendedReferenceEvaluator(model)
+ got = ref.run(None, {"X": a})
+ self.assertEqualArray(expected[0], got[0])
+
+ def test_scatter_elements(self):
+ model = make_model(
+ make_graph(
+ [
+ make_node(
+ "ScatterElements",
+ ["data", "indices", "updates"],
+ ["Z"],
+ axis=3,
+ reduction="add",
+ )
+ ],
+ "name",
+ [
+ make_tensor_value_info("data", TensorProto.FLOAT, None),
+ make_tensor_value_info("indices", TensorProto.INT64, None),
+ make_tensor_value_info("updates", TensorProto.FLOAT, None),
+ ],
+ [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
+ ),
+ opset_imports=[make_opsetid("", 18)],
+ )
+ data = np.zeros(2**4, dtype=np.float32).reshape((2, 2, 2, 2))
+ indices = np.array([[[[0]]]], dtype=np.int64)
+ updates = np.array([[[[1]]]], dtype=np.float32)
+ y = np.array(
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32
+ ).reshape((2, 2, 2, 2))
+ ref = ExtendedReferenceEvaluator(model)
+ got = ref.run(None, {"data": data, "indices": indices, "updates": updates})
+ self.assertEqualArray(y, got[0])
+
+
+if __name__ == "__main__":
+ unittest.main(verbosity=2)
diff --git a/_unittests/ut_tools/test_replace_constants.py b/_unittests/ut_tools/test_replace_constants.py
new file mode 100644
index 0000000..5cad1c2
--- /dev/null
+++ b/_unittests/ut_tools/test_replace_constants.py
@@ -0,0 +1,160 @@
+import unittest
+import numpy as np
+import onnx
+import onnx.helper as oh
+import onnx.numpy_helper as onh
+from onnx import TensorProto
+from onnx_array_api.ext_test_case import ExtTestCase
+from onnx_array_api.reference import (
+ ExtendedReferenceEvaluator as ReferenceEvaluator,
+)
+from onnx_array_api.tools.replace_constants import (
+ replace_initializer_by_constant_of_shape,
+)
+
+
+class TestReplaceConstants(ExtTestCase):
+
+ def test_replace_initializer(self):
+ dtype = np.float32
+ value = np.random.randn(2, 100).astype(dtype)
+ A = onh.from_array(value, name="A")
+ value = np.array([1], dtype=dtype)
+ C = onh.from_array(value, name="C")
+
+ X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
+ Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
+ node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
+ node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
+ graph = oh.make_graph([node1, node2], "lr", [X], [Y], [A, C])
+ model_def = oh.make_model(graph)
+
+ x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
+ oinf1 = ReferenceEvaluator(model_def)
+ y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
+ repl = replace_initializer_by_constant_of_shape(model_def)
+ node_types = {n.op_type for n in repl.graph.node}
+ self.assertIn("ConstantOfShape", node_types)
+ oinf2 = ReferenceEvaluator(repl)
+ y1[:, :] = 3.5
+ y1[0, :] = 0.5
+ y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
+ self.assertEqualArray(y1, y2)
+
+ def test_replace_constant(self):
+ dtype = np.float32
+ value = np.random.randn(2, 10).astype(dtype)
+ A = onh.from_array(value, name="A")
+ value = np.array([1], dtype=dtype)
+ C = onh.from_array(value, name="C")
+
+ X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
+ Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
+ node0 = oh.make_node("Constant", [], ["A"], value=A)
+ node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
+ node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
+ graph = oh.make_graph([node0, node1, node2], "lr", [X], [Y], [C])
+ model_def = oh.make_model(graph)
+
+ x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
+ oinf1 = ReferenceEvaluator(model_def)
+ y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
+ repl = replace_initializer_by_constant_of_shape(model_def, threshold=0)
+ node_types = {n.op_type for n in repl.graph.node}
+ self.assertIn("ConstantOfShape", node_types)
+ oinf2 = ReferenceEvaluator(repl)
+ y1[:, :] = 4
+ y1[0, :] = 1
+ y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
+ self.assertEqualArray(y1, y2)
+
+ def test_replace_constant_function(self):
+ dtype = np.float32
+ value = np.random.randn(2, 100).astype(dtype)
+ A = onh.from_array(value, name="A")
+ value = np.array([1], dtype=dtype)
+ C = onh.from_array(value, name="C")
+
+ X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
+ Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
+ nodeC = oh.make_node("Constant", [], ["C"], value=C)
+ node0 = oh.make_node("Constant", [], ["A"], value=A)
+ node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
+ node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
+ opset_imports = [
+ oh.make_opsetid("", onnx.defs.onnx_opset_version()),
+ oh.make_opsetid("custom", 1),
+ ]
+ fct = oh.make_function(
+ "custom",
+ "unittest",
+ ["X"],
+ ["Y"],
+ [nodeC, node0, node1, node2],
+ opset_imports,
+ )
+
+ node = oh.make_node("unittest", ["X"], ["Y"], domain="custom")
+ graph = oh.make_graph([node], "lr", [X], [Y], [C])
+ model_def = oh.make_model(graph, functions=[fct], opset_imports=opset_imports)
+
+ x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
+ oinf1 = ReferenceEvaluator(model_def)
+ y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
+ repl = replace_initializer_by_constant_of_shape(model_def)
+ node_types = {n.op_type for n in repl.functions[0].node}
+ self.assertIn("ConstantOfShape", node_types)
+ oinf2 = ReferenceEvaluator(repl)
+ y1[:, :] = 3.5
+ y1[0, :] = 0.5
+ y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
+ self.assertEqualArray(y1, y2)
+
+ def test_replace_constant_graph(self):
+ value = np.array([0], dtype=np.float32)
+ zero = onh.from_array(value, name="zero")
+
+ X = oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, None])
+ Y = oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [None])
+
+ rsum = oh.make_node("ReduceSum", ["X"], ["rsum"])
+ cond = oh.make_node("Greater", ["rsum", "zero"], ["cond"])
+
+ then_out = oh.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, None)
+ then_cst = onh.from_array(np.array([1] * 129).astype(np.float32))
+
+ then_const_node = oh.make_node(
+ "Constant", inputs=[], outputs=["then_out"], value=then_cst, name="cst1"
+ )
+ then_body = oh.make_graph([then_const_node], "then_body", [], [then_out])
+
+ else_out = oh.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, None)
+ else_cst = onh.from_array(np.array([-1] * 129).astype(np.float32))
+ else_const_node = oh.make_node(
+ "Constant", inputs=[], outputs=["else_out"], value=else_cst, name="cst2"
+ )
+ else_body = oh.make_graph([else_const_node], "else_body", [], [else_out])
+
+ if_node = oh.make_node(
+ "If", ["cond"], ["Y"], then_branch=then_body, else_branch=else_body
+ )
+ graph = oh.make_graph([rsum, cond, if_node], "if", [X], [Y], [zero])
+ onnx_model = oh.make_model(
+ graph, opset_imports=[oh.make_opsetid("", onnx.defs.onnx_opset_version())]
+ )
+ self.assertNotIn("ConstantOfShape", str(onnx_model))
+
+ x = np.ones((3, 2), dtype=np.float32)
+ oinf1 = ReferenceEvaluator(onnx_model)
+ y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
+ repl = replace_initializer_by_constant_of_shape(onnx_model)
+ self.assertIn("ConstantOfShape", str(repl))
+ oinf2 = ReferenceEvaluator(repl)
+ y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
+ y1 = y1.copy()
+ y1[:] = 0.5
+ self.assertEqualArray(y1, y2)
+
+
+if __name__ == "__main__":
+ unittest.main(verbosity=2)
diff --git a/_unittests/ut_translate_api/_data/custom_ops_type_inference_fails_0.onnx b/_unittests/ut_translate_api/_data/custom_ops_type_inference_fails_0.onnx
new file mode 100644
index 0000000..8116ec3
Binary files /dev/null and b/_unittests/ut_translate_api/_data/custom_ops_type_inference_fails_0.onnx differ
diff --git a/_unittests/ut_light_api/_data/stft_inlined_batch_1.onnx b/_unittests/ut_translate_api/_data/stft_inlined_batch_1.onnx
similarity index 100%
rename from _unittests/ut_light_api/_data/stft_inlined_batch_1.onnx
rename to _unittests/ut_translate_api/_data/stft_inlined_batch_1.onnx
diff --git a/_unittests/ut_light_api/test_translate.py b/_unittests/ut_translate_api/test_translate.py
similarity index 93%
rename from _unittests/ut_light_api/test_translate.py
rename to _unittests/ut_translate_api/test_translate.py
index c2b2c70..98629d8 100644
--- a/_unittests/ut_light_api/test_translate.py
+++ b/_unittests/ut_translate_api/test_translate.py
@@ -5,8 +5,9 @@
from onnx.defs import onnx_opset_version
from onnx.reference import ReferenceEvaluator
from onnx_array_api.ext_test_case import ExtTestCase
-from onnx_array_api.light_api import start, translate, g
-from onnx_array_api.light_api.emitter import EventType
+from onnx_array_api.light_api import start, g
+from onnx_array_api.translate_api import translate
+from onnx_array_api.translate_api.base_emitter import EventType
OPSET_API = min(19, onnx_opset_version() - 1)
@@ -159,8 +160,14 @@ def test_export_if(self):
self.assertEqualArray(np.array([1], dtype=np.int64), got[0])
code = translate(onx)
- selse = "g().cst(np.array([0], dtype=np.int64)).rename('Z').bring('Z').vout(elem_type=TensorProto.FLOAT)"
- sthen = "g().cst(np.array([1], dtype=np.int64)).rename('Z').bring('Z').vout(elem_type=TensorProto.FLOAT)"
+ selse = (
+ "g().cst(np.array([0], dtype=np.int64)).rename('Z')."
+ "bring('Z').vout(elem_type=TensorProto.FLOAT)"
+ )
+ sthen = (
+ "g().cst(np.array([1], dtype=np.int64)).rename('Z')."
+ "bring('Z').vout(elem_type=TensorProto.FLOAT)"
+ )
expected = dedent(
f"""
(
@@ -220,5 +227,4 @@ def test_aionnxml(self):
if __name__ == "__main__":
- TestTranslate().test_export_if()
unittest.main(verbosity=2)
diff --git a/_unittests/ut_translate_api/test_translate_builder.py b/_unittests/ut_translate_api/test_translate_builder.py
new file mode 100644
index 0000000..b1ad394
--- /dev/null
+++ b/_unittests/ut_translate_api/test_translate_builder.py
@@ -0,0 +1,285 @@
+import unittest
+from textwrap import dedent
+import numpy as np
+import onnx.helper as oh
+from onnx import ModelProto, TensorProto
+from onnx.checker import check_model
+from onnx.defs import onnx_opset_version
+from onnx.reference import ReferenceEvaluator
+from onnx_array_api.ext_test_case import ExtTestCase
+from onnx_array_api.light_api import start
+from onnx_array_api.graph_api import GraphBuilder
+from onnx_array_api.translate_api import translate, Translater
+from onnx_array_api.translate_api.builder_emitter import BuilderEmitter
+
+
+OPSET_API = min(19, onnx_opset_version() - 1)
+
+
+class TestTranslateBuilder(ExtTestCase):
+ def setUp(self):
+ self.maxDiff = None
+
+ def test_exp(self):
+ onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx()
+ self.assertIsInstance(onx, ModelProto)
+ self.assertIn("Exp", str(onx))
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a})[0]
+ self.assertEqualArray(np.exp(a), got)
+
+ code = translate(onx, api="builder")
+ expected = (
+ dedent(
+ """
+ def light_api(
+ op: "GraphBuilder",
+ X: "FLOAT[]",
+ ):
+ Y = op.Exp(X, outputs=['Y'])
+ op.Identity(Y, outputs=["Y"])
+ return Y
+
+ g = GraphBuilder({'': 19}, ir_version=10)
+ g.make_tensor_input("X", TensorProto.FLOAT, ())
+ light_api(g.op, "X")
+ g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
+ model = g.to_onnx()
+ """
+ )
+ .strip("\n")
+ .replace("__SUFFIX__", ", is_dimension=False, indexed=False")
+ )
+ self.assertEqual(expected, code.strip("\n"))
+
+ def light_api(
+ op: "GraphBuilder",
+ X: "FLOAT[]", # noqa: F722
+ ):
+ Y = op.Exp(X, outputs=["Y"])
+ op.Identity(Y, outputs=["Y"])
+ return Y
+
+ g2 = GraphBuilder({"": 19})
+ g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
+ light_api(g2.op, "X")
+ g2.make_tensor_output(
+ "Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
+ )
+ onx2 = g2.to_onnx()
+
+ ref = ReferenceEvaluator(onx2)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a})[0]
+ self.assertEqualArray(np.exp(a), got)
+
+ def test_zdoc(self):
+ onx = (
+ start(opset=19, ir_version=10)
+ .vin("X")
+ .reshape((-1, 1))
+ .Transpose(perm=[1, 0])
+ .rename("Y")
+ .vout()
+ .to_onnx()
+ )
+ code = translate(onx, api="builder")
+ expected = (
+ dedent(
+ """
+ def light_api(
+ op: "GraphBuilder",
+ X: "FLOAT[]",
+ ):
+ r = np.array([-1, 1], dtype=np.int64)
+ r0_0 = op.Reshape(X, r, outputs=['r0_0'])
+ Y = op.Transpose(r0_0, perm=[1, 0], outputs=['Y'])
+ op.Identity(Y, outputs=["Y"])
+ return Y
+
+ g = GraphBuilder({'': 19}, ir_version=10)
+ g.make_tensor_input("X", TensorProto.FLOAT, ())
+ light_api(g.op, "X")
+ g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
+ model = g.to_onnx()
+ """
+ )
+ .strip("\n")
+ .replace("__SUFFIX__", ", is_dimension=False, indexed=False")
+ )
+ self.maxDiff = None
+ self.assertEqual(expected, code.strip("\n"))
+
+ def light_api(
+ op: "GraphBuilder",
+ X: "FLOAT[]", # noqa: F722
+ ):
+ r = np.array([-1, 1], dtype=np.int64)
+ r0_0 = op.Reshape(X, r)
+ Y = op.Transpose(r0_0, perm=[1, 0])
+ op.Identity(Y, outputs=["Y"])
+ return Y
+
+ g = GraphBuilder({"": 21})
+ X = g.make_tensor_input("X", TensorProto.FLOAT, ())
+ light_api(g.op, X)
+ g.make_tensor_output("Y", TensorProto.FLOAT, ())
+ model = g.to_onnx()
+ self.assertNotEmpty(model)
+ check_model(model)
+
+ def test_exp_f(self):
+ onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx()
+ self.assertIsInstance(onx, ModelProto)
+ self.assertIn("Exp", str(onx))
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a})[0]
+ self.assertEqualArray(np.exp(a), got)
+
+ tr = Translater(onx, emitter=BuilderEmitter("mm"))
+ code = tr.export(as_str=True)
+
+ expected = (
+ dedent(
+ """
+ def light_api(
+ op: "GraphBuilder",
+ X: "FLOAT[]",
+ ):
+ Y = op.Exp(X, outputs=['Y'])
+ op.Identity(Y, outputs=["Y"])
+ return Y
+
+
+ def mm() -> "ModelProto":
+ g = GraphBuilder({'': 19}, ir_version=10)
+ g.make_tensor_input("X", TensorProto.FLOAT, ())
+ light_api(g.op, "X")
+ g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
+ model = g.to_onnx()
+ return model
+
+
+ model = mm()
+ """
+ )
+ .strip("\n")
+ .replace("__SUFFIX__", ", is_dimension=False, indexed=False")
+ )
+ self.assertEqual(expected, code.strip("\n"))
+
+ def light_api(
+ op: "GraphBuilder",
+ X: "FLOAT[]", # noqa: F722
+ ):
+ Y = op.Exp(X)
+ op.Identity(Y, outputs=["Y"])
+ return Y
+
+ g2 = GraphBuilder({"": 19})
+ g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
+ light_api(g2.op, "X")
+ g2.make_tensor_output(
+ "Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
+ )
+ onx2 = g2.to_onnx()
+
+ ref = ReferenceEvaluator(onx2)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a})[0]
+ self.assertEqualArray(np.exp(a), got)
+
+ def test_local_function(self):
+ new_domain = "custom"
+
+ linear_regression = oh.make_function(
+ new_domain,
+ "LinearRegression",
+ ["x", "a", "b"],
+ ["y"],
+ [
+ oh.make_node("MatMul", ["x", "a"], ["xa"]),
+ oh.make_node("Add", ["xa", "b"], ["y"]),
+ ],
+ [oh.make_opsetid("", 14)],
+ [],
+ )
+
+ graph = oh.make_graph(
+ [
+ oh.make_node(
+ "LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain
+ ),
+ oh.make_node("Abs", ["Y1"], ["Y"]),
+ ],
+ "example",
+ [
+ oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None]),
+ oh.make_tensor_value_info("A", TensorProto.FLOAT, [None, None]),
+ oh.make_tensor_value_info("B", TensorProto.FLOAT, [None, None]),
+ ],
+ [oh.make_tensor_value_info("Y", TensorProto.FLOAT, None)],
+ )
+
+ onnx_model = oh.make_model(
+ graph,
+ opset_imports=[oh.make_opsetid("", 14), oh.make_opsetid(new_domain, 1)],
+ functions=[linear_regression],
+ ir_version=10,
+ )
+ tr = Translater(onnx_model, emitter=BuilderEmitter("mm"))
+ code = tr.export(as_str=True)
+
+ expected = (
+ dedent(
+ """
+ def example(
+ op: "GraphBuilder",
+ X: "FLOAT[, ]",
+ A: "FLOAT[, ]",
+ B: "FLOAT[, ]",
+ ):
+ Y1 = op.LinearRegression(X, A, B, domain='custom', outputs=['Y1'])
+ Y = op.Abs(Y1, outputs=['Y'])
+ op.Identity(Y, outputs=["Y"])
+ return Y
+
+
+ def make_custom_LinearRegression(g: "GraphBuilder"):
+ gr = GraphBuilder({'': 14}, as_function=True)
+ x = gr.make_tensor_input('x')
+ a = gr.make_tensor_input('a')
+ b = gr.make_tensor_input('b')
+ op = gr.op
+ xa = op.MatMul(x, a, outputs=['xa'])
+ y = op.Add(xa, b, outputs=['y'])
+ gr.make_tensor_output(y)
+ g.add_function(builder=gr)
+ return gr
+
+
+ def mm() -> "ModelProto":
+ g = GraphBuilder({'': 14, 'custom': 1}, ir_version=10)
+ g.make_tensor_input("X", TensorProto.FLOAT, ('', ''))
+ g.make_tensor_input("A", TensorProto.FLOAT, ('', ''))
+ g.make_tensor_input("B", TensorProto.FLOAT, ('', ''))
+ example(g.op, "X", "A", "B")
+ g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
+ make_custom_LinearRegression(g)
+ model = g.to_onnx()
+ return model
+
+
+ model = mm()
+ """
+ )
+ .strip("\n")
+ .replace("__SUFFIX__", ", is_dimension=False, indexed=False")
+ )
+ self.assertEqual(expected, code.strip("\n"))
+
+
+if __name__ == "__main__":
+ unittest.main(verbosity=2)
diff --git a/_unittests/ut_light_api/test_translate_classic.py b/_unittests/ut_translate_api/test_translate_classic.py
similarity index 60%
rename from _unittests/ut_light_api/test_translate_classic.py
rename to _unittests/ut_translate_api/test_translate_classic.py
index cb7d6a4..4f65b99 100644
--- a/_unittests/ut_light_api/test_translate_classic.py
+++ b/_unittests/ut_translate_api/test_translate_classic.py
@@ -5,6 +5,7 @@
from onnx import ModelProto, TensorProto, load
from onnx.defs import onnx_opset_version
from onnx.reference import ReferenceEvaluator
+from onnx.reference.op_run import OpRun
from onnx.helper import (
make_tensor_value_info,
make_node,
@@ -14,7 +15,8 @@
)
from onnx.checker import check_model
from onnx_array_api.ext_test_case import ExtTestCase
-from onnx_array_api.light_api import start, translate
+from onnx_array_api.light_api import start
+from onnx_array_api.translate_api import translate
OPSET_API = min(19, onnx_opset_version() - 1)
@@ -68,7 +70,7 @@ def test_exp(self):
functions = []
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
nodes.append(
- make_node(
+ make_node_extended(
'Exp',
['X'],
['Y']
@@ -144,14 +146,83 @@ def test_transpose(self):
)
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
nodes.append(
- make_node(
+ make_node_extended(
'Reshape',
['X', 'r'],
['r0_0']
)
)
nodes.append(
- make_node(
+ make_node_extended(
+ 'Transpose',
+ ['r0_0'],
+ ['Y'],
+ perm=[1, 0]
+ )
+ )
+ outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
+ graph = make_graph(
+ nodes,
+ 'light_api',
+ inputs,
+ outputs,
+ initializers,
+ sparse_initializer=sparse_initializers,
+ )
+ model = make_model(
+ graph,
+ functions=functions,
+ opset_imports=opset_imports
+ )"""
+ ).strip("\n")
+ self.maxDiff = None
+ self.assertEqual(expected, code)
+
+ def test_transpose_short(self):
+ onx = (
+ start(opset=19)
+ .vin("X")
+ .reshape((-1, 1))
+ .Transpose(perm=[1, 0])
+ .rename("Y")
+ .vout()
+ .to_onnx()
+ )
+ self.assertIsInstance(onx, ModelProto)
+ self.assertIn("Transpose", str(onx))
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a})[0]
+ self.assertEqualArray(a.reshape((-1, 1)).T, got)
+
+ code = translate(onx, api="onnx-short")
+ expected = dedent(
+ """
+ opset_imports = [
+ make_opsetid('', 19),
+ ]
+ inputs = []
+ outputs = []
+ nodes = []
+ initializers = []
+ sparse_initializers = []
+ functions = []
+ initializers.append(
+ from_array(
+ np.array([-1, 1], dtype=np.int64),
+ name='r'
+ )
+ )
+ inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
+ nodes.append(
+ make_node_extended(
+ 'Reshape',
+ ['X', 'r'],
+ ['r0_0']
+ )
+ )
+ nodes.append(
+ make_node_extended(
'Transpose',
['r0_0'],
['Y'],
@@ -210,7 +281,7 @@ def test_topk_reverse(self):
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
inputs.append(make_tensor_value_info('K', TensorProto.INT64, shape=[]))
nodes.append(
- make_node(
+ make_node_extended(
'TopK',
['X', 'K'],
['Values', 'Indices'],
@@ -250,7 +321,7 @@ def test_fft(self):
new_code = "\n".join(
[f"{i+1:04} {line}" for i, line in enumerate(code.split("\n"))]
)
- raise AssertionError(f"ERROR {e}\n{new_code}")
+ raise AssertionError(f"ERROR {e}\n{new_code}") # noqa: B904
def test_aionnxml(self):
onx = (
@@ -264,7 +335,6 @@ def test_aionnxml(self):
.to_onnx()
)
code = translate(onx, api="onnx")
- print(code)
expected = dedent(
"""
opset_imports = [
@@ -285,14 +355,14 @@ def test_aionnxml(self):
)
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
nodes.append(
- make_node(
+ make_node_extended(
'Reshape',
['X', 'r'],
['USE']
)
)
nodes.append(
- make_node(
+ make_node_extended(
'Normalizer',
['USE'],
['Y'],
@@ -318,7 +388,115 @@ def test_aionnxml(self):
self.maxDiff = None
self.assertEqual(expected, code)
+ @classmethod
+ def _code_line(cls, code):
+ lines = code.split("\n")
+ return "\n".join(f"{i+1:03d} {line}" for i, line in enumerate(lines))
+
+ @classmethod
+ def _run(cls, code):
+ try:
+ code_compiled = compile(code, "", mode="exec")
+ except Exception as e:
+ raise AssertionError(
+ f"Compilation failed due to {e}\n---\n{cls._code_line(code)}\n---\n{e}"
+ ) from e
+
+ import onnx
+ import onnx.helper
+ import onnx.numpy_helper
+ import onnx_array_api.translate_api.make_helper
+ import onnx.reference.custom_element_types
+
+ def from_array_extended(tensor, name=None):
+ dt = tensor.dtype
+ if (
+ dt == onnx.reference.custom_element_types.float8e4m3fn
+ and dt.descr[0][0] == "e4m3fn"
+ ):
+ to = TensorProto.FLOAT8E4M3FN
+ dt_to = np.uint8
+ elif (
+ dt == onnx.reference.custom_element_types.bfloat16
+ and dt.descr[0][0] == "bfloat16"
+ ):
+ to = TensorProto.BFLOAT16
+ dt_to = np.uint16
+ else:
+ return onnx.numpy_helper.from_array(tensor, name)
+
+ t = onnx.numpy_helper.from_array(tensor.astype(dt_to), name)
+ t.data_type = to
+ return t
+
+ globs = onnx.__dict__.copy()
+ globs.update(onnx.helper.__dict__)
+ globs.update(onnx.numpy_helper.__dict__)
+ globs.update(onnx_array_api.translate_api.make_helper.__dict__)
+ globs.update(onnx.reference.custom_element_types.__dict__)
+ globs["from_array_extended"] = from_array_extended
+ locs = {}
+ try:
+ exec(code_compiled, globs, locs)
+ except Exception as e:
+ raise AssertionError(
+ f"Execution failed due to {e}\n---\n{cls._code_line(code)}\n---\n{e}"
+ ) from e
+ return globs, locs
+
+ def test_remove_nodes(self):
+ path = os.path.join(
+ os.path.dirname(__file__), "_data", "custom_ops_type_inference_fails_0.onnx"
+ )
+ onx = load(path)
+ code = translate(onx, api="onnx")
+ _, locs = self._run(code)
+ self.assertIn("model", locs)
+ model = locs["model"]
+ x = np.arange(4).reshape((-1, 2)).astype(np.float32)
+ feeds = {"X": x}
+
+ class CustomGemmFloat8E4M3FN(OpRun):
+ op_domain = "onnx_extented.ortops.tutorial.cpu"
+
+ def _run(
+ self,
+ x,
+ y,
+ bias=None,
+ scale_x=None,
+ scale_y=None,
+ scale_z=None,
+ transA=False,
+ transB=False,
+ dtype=None,
+ rowMajor=None,
+ computeType=None,
+ ):
+ if scale_x is not None:
+ x = x * scale_x
+ if transA:
+ x = x.T
+ if scale_y is not None:
+ y = y * scale_y
+ if transB:
+ y = y.T
+ z = x @ y
+ if bias is not None:
+ z += bias
+ if scale_z is not None:
+ z = z / scale_z
+ return (z,)
+
+ ref = ReferenceEvaluator(onx, new_ops=[CustomGemmFloat8E4M3FN])
+ expected = ref.run(None, feeds)[0]
+ ref2 = ReferenceEvaluator(model, new_ops=[CustomGemmFloat8E4M3FN])
+ got = ref2.run(None, feeds)[0]
+ self.assertEqualArray(expected, got)
+
+ # with open("debug_test_remove_nodes.py", "w") as f:
+ # f.write(code)
+
if __name__ == "__main__":
- # TestLightApi().test_topk()
unittest.main(verbosity=2)
diff --git a/_unittests/ut_validation/test_f8.py b/_unittests/ut_validation/test_f8.py
index 80611b5..4c6517f 100644
--- a/_unittests/ut_validation/test_f8.py
+++ b/_unittests/ut_validation/test_f8.py
@@ -88,7 +88,7 @@ def test_fe5m2_to_float32_paper(self):
self.assertEqual(fe5m2_to_float32(int("11111100", 2)), -numpy.inf)
def test_fe4m3fn_to_float32_all(self):
- for i in range(0, 256):
+ for i in range(256):
a = fe4m3_to_float32_float(i)
b = fe4m3_to_float32(i)
if numpy.isnan(a):
@@ -97,7 +97,7 @@ def test_fe4m3fn_to_float32_all(self):
self.assertEqual(a, b)
def test_fe4m3fn_to_float32_all_ml_types(self):
- for i in range(0, 256):
+ for i in range(256):
a = fe4m3_to_float32_float(i)
b = fe4m3_to_float32(i)
c = new_cvt_float32_to_e4m3fn(b)
@@ -188,7 +188,7 @@ def test_search_float32_into_fe5m2_simple(self):
self.assertEqual(b1, b2)
def test_search_float32_into_fe4m3fn_equal(self):
- values = [(fe4m3_to_float32_float(i), i) for i in range(0, 256)]
+ values = [(fe4m3_to_float32_float(i), i) for i in range(256)]
values.sort()
for value, expected in values:
@@ -208,7 +208,7 @@ def test_search_float32_into_fe4m3fn_equal(self):
self.assertIn(nf, (0, 128))
def test_search_float32_into_fe5m2_equal(self):
- values = [(fe5m2_to_float32_float(i), i) for i in range(0, 256)]
+ values = [(fe5m2_to_float32_float(i), i) for i in range(256)]
values.sort()
for value, expected in values:
@@ -233,7 +233,7 @@ def test_search_float32_into_fe5m2_equal(self):
self.assertEqual(fe5m2_to_float32(nf), float(cf))
def test_search_float32_into_fe4m3fn(self):
- values = [(fe4m3_to_float32_float(i), i) for i in range(0, 256)]
+ values = [(fe4m3_to_float32_float(i), i) for i in range(256)]
values.sort()
obs = []
@@ -308,7 +308,7 @@ def test_search_float32_into_fe4m3fn(self):
)
def test_search_float32_into_fe5m2(self):
- values = [(fe5m2_to_float32_float(i), i) for i in range(0, 256)]
+ values = [(fe5m2_to_float32_float(i), i) for i in range(256)]
values.sort()
obs = []
@@ -651,7 +651,7 @@ def test_search_float32_into_fe5m2fnuz_simple(self):
self.assertEqual(expected, got)
def test_fe4m3fnuz_to_float32_all(self):
- for i in range(0, 256):
+ for i in range(256):
a = fe4m3_to_float32_float(i, uz=True)
b = fe4m3_to_float32(i, uz=True)
if numpy.isnan(a):
@@ -660,7 +660,7 @@ def test_fe4m3fnuz_to_float32_all(self):
self.assertEqual(a, b)
def test_fe5m2fnuz_to_float32_all(self):
- for i in range(0, 256):
+ for i in range(256):
a = fe5m2_to_float32_float(i, fn=True, uz=True)
b = fe5m2_to_float32(i, fn=True, uz=True)
if numpy.isnan(a):
@@ -669,7 +669,7 @@ def test_fe5m2fnuz_to_float32_all(self):
self.assertEqual(a, b)
def test_search_float32_into_fe4m3fnuz(self):
- values = [(fe4m3_to_float32_float(i, uz=True), i) for i in range(0, 256)]
+ values = [(fe4m3_to_float32_float(i, uz=True), i) for i in range(256)]
values.sort()
obs = []
@@ -715,9 +715,7 @@ def test_search_float32_into_fe4m3fnuz(self):
)
def test_search_float32_into_fe5m2fnuz(self):
- values = [
- (fe5m2_to_float32_float(i, fn=True, uz=True), i) for i in range(0, 256)
- ]
+ values = [(fe5m2_to_float32_float(i, fn=True, uz=True), i) for i in range(256)]
values.sort()
obs = []
@@ -1235,7 +1233,7 @@ def test_nan(self):
expected,
)
]
- for i in range(0, 23):
+ for i in range(23):
v = 0x7F800000 | (1 << i)
f = numpy.uint32(v).view(numpy.float32)
values.append((i, v, f, expected))
diff --git a/_unittests/ut_xrun_doc/test_command_lines1.py b/_unittests/ut_xrun_doc/test_command_lines1.py
index 8aa17ee..0503f55 100644
--- a/_unittests/ut_xrun_doc/test_command_lines1.py
+++ b/_unittests/ut_xrun_doc/test_command_lines1.py
@@ -14,7 +14,9 @@
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api._command_lines_parser import (
get_main_parser,
+ get_parser_compare,
get_parser_translate,
+ get_parser_replace,
main,
)
@@ -34,6 +36,13 @@ def test_parser_translate(self):
text = st.getvalue()
self.assertIn("model", text)
+ def test_parser_replace(self):
+ st = StringIO()
+ with redirect_stdout(st):
+ get_parser_replace().print_help()
+ text = st.getvalue()
+ self.assertIn("model", text)
+
def test_command_translate(self):
X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6])
@@ -70,6 +79,42 @@ def test_command_translate(self):
code = st.getvalue()
self.assertIn("start(opset=", code)
+ def test_parser_compare(self):
+ st = StringIO()
+ with redirect_stdout(st):
+ get_parser_compare().print_help()
+ text = st.getvalue()
+ self.assertIn("model1", text)
+
+ def test_command_compare(self):
+ X = make_tensor_value_info("X", TensorProto.FLOAT, [5, 6])
+ Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6])
+ Z = make_tensor_value_info("Z", TensorProto.FLOAT, [5, 6])
+ graph = make_graph(
+ [
+ make_node("Add", ["X", "Y"], ["res"]),
+ make_node("Cos", ["res"], ["Z"]),
+ ],
+ "g",
+ [X, Y],
+ [Z],
+ )
+ onnx_model = make_model(graph, opset_imports=[make_opsetid("", 18)])
+
+ with tempfile.TemporaryDirectory() as root:
+ model_file = os.path.join(root, "model.onnx")
+ with open(model_file, "wb") as f:
+ f.write(onnx_model.SerializeToString())
+
+ args = ["compare", "-m1", model_file, "-m2", model_file, "-v", "1"]
+ st = StringIO()
+ with redirect_stdout(st):
+ main(args)
+
+ code = st.getvalue()
+ self.assertIn("[compare_onnx_execution]", code)
+ self.assertIn("ADFF", code)
+
if __name__ == "__main__":
unittest.main(verbosity=2)
diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py
index 170e82b..6f6a5d1 100644
--- a/_unittests/ut_xrun_doc/test_documentation_examples.py
+++ b/_unittests/ut_xrun_doc/test_documentation_examples.py
@@ -49,7 +49,7 @@ def run_test(self, fold: str, name: str, verbose=0) -> int:
if verbose:
print(f"failed: {name!r} due to missing dot.")
return 0
- raise AssertionError(
+ raise AssertionError( # noqa: B904
"Example '{}' (cmd: {} - exec_prefix='{}') "
"failed due to\n{}"
"".format(name, cmds, sys.exec_prefix, st)
@@ -65,14 +65,15 @@ def add_test_methods(cls):
fold = os.path.normpath(os.path.join(this, "..", "..", "_doc", "examples"))
found = os.listdir(fold)
for name in found:
- if name.startswith("plot_") and name.endswith(".py"):
- short_name = os.path.split(os.path.splitext(name)[0])[-1]
+ if not name.startswith("plot_") or not name.endswith(".py"):
+ continue
+ short_name = os.path.split(os.path.splitext(name)[0])[-1]
- def _test_(self, name=name):
- res = self.run_test(fold, name, verbose=VERBOSE)
- self.assertTrue(res)
+ def _test_(self, name=name):
+ res = self.run_test(fold, name, verbose=VERBOSE)
+ self.assertTrue(res)
- setattr(cls, f"test_{short_name}", _test_)
+ setattr(cls, f"test_{short_name}", _test_)
TestDocumentationExamples.add_test_methods()
diff --git a/_unittests/ut_xrun_doc/test_profiling.py b/_unittests/ut_xrun_doc/test_profiling.py
index e6c7e69..a7d3ce1 100644
--- a/_unittests/ut_xrun_doc/test_profiling.py
+++ b/_unittests/ut_xrun_doc/test_profiling.py
@@ -1,6 +1,7 @@
"""
@brief test tree node (time=5s)
"""
+
import os
import sys
import time
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index 907bb9f..e9b3859 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -4,8 +4,8 @@ jobs:
vmImage: 'ubuntu-latest'
strategy:
matrix:
- Python311-Linux:
- python.version: '3.11'
+ Python312-Linux:
+ python.version: '3.12'
maxParallel: 3
steps:
@@ -24,7 +24,7 @@ jobs:
- script: pip install -r requirements-dev.txt
displayName: 'Install Requirements dev'
- script: |
- ruff .
+ ruff check .
displayName: 'Ruff'
- script: |
black --diff .
@@ -35,6 +35,9 @@ jobs:
- script: |
python -m pip install . -v -v -v
displayName: 'install wheel'
+ - script: |
+ python -m pip freeze
+ displayName: 'pip freeze'
- script: |
python -m pytest
displayName: 'Runs Unit Tests'
@@ -48,8 +51,8 @@ jobs:
vmImage: 'ubuntu-latest'
strategy:
matrix:
- Python311-Linux:
- python.version: '3.11'
+ Python312-Linux:
+ python.version: '3.12'
maxParallel: 3
steps:
@@ -78,11 +81,14 @@ jobs:
- script: pip install onnxmltools --no-deps
displayName: 'Install onnxmltools'
- script: |
- ruff .
+ ruff check .
displayName: 'Ruff'
- script: |
black --diff .
displayName: 'Black'
+ - script: |
+ python -m pip freeze
+ displayName: 'pip freeze'
- script: |
python -m pytest
displayName: 'Runs Unit Tests'
@@ -125,16 +131,19 @@ jobs:
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
cd array-api-tests
displayName: 'Set API'
+ - script: |
+ python -m pip freeze
+ displayName: 'pip freeze'
- script: |
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
cd array-api-tests
python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-numpy-skips.txt --hypothesis-explain
displayName: "numpy test_creation_functions.py"
- - script: |
- export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort
- cd array-api-tests
- python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-ort-skips.txt --hypothesis-explain
- displayName: "ort test_creation_functions.py"
+ # - script: |
+ # export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort
+ # cd array-api-tests
+ # python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-ort-skips.txt --hypothesis-explain
+ # displayName: "ort test_creation_functions.py"
#- script: |
# export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
# cd array-api-tests
@@ -146,8 +155,8 @@ jobs:
vmImage: 'ubuntu-latest'
strategy:
matrix:
- Python311-Linux:
- python.version: '3.11'
+ Python312-Linux:
+ python.version: '3.12'
maxParallel: 3
steps:
@@ -172,11 +181,14 @@ jobs:
- script: pip install onnxmltools --no-deps
displayName: 'Install onnxmltools'
- script: |
- ruff .
+ ruff check .
displayName: 'Ruff'
- script: |
black --diff .
displayName: 'Black'
+ - script: |
+ python -m pip freeze
+ displayName: 'pip freeze'
- script: |
python -m pytest --cov
displayName: 'Runs Unit Tests'
@@ -196,8 +208,8 @@ jobs:
vmImage: 'windows-latest'
strategy:
matrix:
- Python311-Windows:
- python.version: '3.11'
+ Python312-Windows:
+ python.version: '3.12'
maxParallel: 3
steps:
@@ -213,6 +225,9 @@ jobs:
displayName: 'Install Requirements dev'
- script: pip install onnxmltools --no-deps
displayName: 'Install onnxmltools'
+ - script: |
+ python -m pip freeze
+ displayName: 'pip freeze'
- script: |
python -m pytest -v
displayName: 'Runs Unit Tests'
@@ -223,47 +238,3 @@ jobs:
inputs:
artifactName: 'wheel-windows-$(python.version)'
targetPath: 'dist'
-
-- job: 'TestMac'
- pool:
- vmImage: 'macOS-latest'
- strategy:
- matrix:
- Python311-Mac:
- python.version: '3.11'
- maxParallel: 3
-
- steps:
- - task: UsePythonVersion@0
- inputs:
- versionSpec: '$(python.version)'
- architecture: 'x64'
- - script: gcc --version
- displayName: 'gcc version'
- #- script: brew upgrade
- # displayName: 'brew upgrade'
- #- script: brew update
- # displayName: 'brew update'
- - script: export
- displayName: 'export'
- - script: gcc --version
- displayName: 'gcc version'
- - script: python -m pip install --upgrade pip setuptools wheel
- displayName: 'Install tools'
- - script: pip install -r requirements.txt
- displayName: 'Install Requirements'
- - script: pip install -r requirements-dev.txt
- displayName: 'Install Requirements dev'
- - script: pip install onnxmltools --no-deps
- displayName: 'Install onnxmltools'
- - script: |
- python -m pytest
- displayName: 'Runs Unit Tests'
- - script: |
- python -u setup.py bdist_wheel
- displayName: 'Build Package'
- - task: PublishPipelineArtifact@0
- inputs:
- artifactName: 'wheel-mac-$(python.version)'
- targetPath: 'dist'
-
diff --git a/onnx_array_api/__init__.py b/onnx_array_api/__init__.py
index 09a2edd..98371ac 100644
--- a/onnx_array_api/__init__.py
+++ b/onnx_array_api/__init__.py
@@ -1,7 +1,6 @@
-# coding: utf-8
"""
APIs to create ONNX Graphs.
"""
-__version__ = "0.1.3"
+__version__ = "0.3.1"
__author__ = "Xavier Dupré"
diff --git a/onnx_array_api/_command_lines_parser.py b/onnx_array_api/_command_lines_parser.py
index 3860f18..d1eac62 100644
--- a/onnx_array_api/_command_lines_parser.py
+++ b/onnx_array_api/_command_lines_parser.py
@@ -14,12 +14,15 @@ def get_main_parser() -> ArgumentParser:
)
parser.add_argument(
"cmd",
- choices=["translate"],
+ choices=["translate", "compare", "replace"],
help=dedent(
"""
Selects a command.
-
- 'translate' exports an onnx graph into a piece of code replicating it.
+
+ 'translate' exports an onnx graph into a piece of code replicating it,
+ 'compare' compares the execution of two onnx models,
+ 'replace' replaces constant and initliazers by ConstantOfShape
+ to make the model lighter
"""
),
)
@@ -48,7 +51,7 @@ def get_parser_translate() -> ArgumentParser:
parser.add_argument(
"-a",
"--api",
- choices=["onnx", "light"],
+ choices=["onnx", "light", "onnx-short", "builder"],
default="onnx",
help="API to choose, API from onnx package or light API.",
)
@@ -56,7 +59,7 @@ def get_parser_translate() -> ArgumentParser:
def _cmd_translate(argv: List[Any]):
- from .light_api import translate
+ from .translate_api import translate
parser = get_parser_translate()
args = parser.parse_args(argv[1:])
@@ -65,8 +68,152 @@ def _cmd_translate(argv: List[Any]):
print(code)
+def get_parser_compare() -> ArgumentParser:
+ parser = ArgumentParser(
+ prog="compare",
+ description=dedent(
+ """
+ Compares the execution of two onnx models.
+ """
+ ),
+ epilog="This is used when two models are different but "
+ "should produce the same results.",
+ )
+ parser.add_argument(
+ "-m1",
+ "--model1",
+ type=str,
+ required=True,
+ help="first onnx model",
+ )
+ parser.add_argument(
+ "-m2",
+ "--model2",
+ type=str,
+ required=True,
+ help="second onnx model",
+ )
+ parser.add_argument(
+ "-m",
+ "--mode",
+ choices=["execute", "nodes"],
+ default="execute",
+ help="compare the execution ('execute') or the nodes only ('nodes')",
+ )
+ parser.add_argument(
+ "-v",
+ "--verbose",
+ default=0,
+ help="verbosity",
+ )
+ parser.add_argument(
+ "-c",
+ "--column-size",
+ default=60,
+ help="column size when displaying the results",
+ )
+ parser.add_argument(
+ "-d",
+ "--discrepancies",
+ default=0,
+ help="show precise discrepancies when mode is execution",
+ )
+ return parser
+
+
+def _cmd_compare(argv: List[Any]):
+ from .reference import compare_onnx_execution
+
+ parser = get_parser_compare()
+ args = parser.parse_args(argv[1:])
+ if args.verbose in ("1", 1, "True", True):
+ print(f"[compare] first model {args.model1!r}")
+ print(f"[compare] second model {args.model2!r}")
+ onx1 = onnx.load(args.model1)
+ onx2 = onnx.load(args.model2)
+ if args.verbose in ("1", 1, "True", True):
+ print(f"[compare] first model has {len(onx1.graph.node)} nodes")
+ print(f"[compare] second model has {len(onx2.graph.node)} nodes")
+ res1, res2, align, dc = compare_onnx_execution(
+ onx1,
+ onx2,
+ verbose=args.verbose,
+ mode=args.mode,
+ keep_tensor=args.discrepancies in (1, "1", "True", True),
+ )
+ text = dc.to_str(res1, res2, align, column_size=int(args.column_size))
+ print(text)
+
+
+def get_parser_replace() -> ArgumentParser:
+ parser = ArgumentParser(
+ prog="translate",
+ description=dedent(
+ """
+ Replaces constants and initializes by ConstOfShape or any other nodes
+ to make the model smaller.
+ """
+ ),
+ epilog="This is mostly used to write unit tests without adding "
+ "a big file to the repository.",
+ )
+ parser.add_argument(
+ "-m",
+ "--model",
+ type=str,
+ required=True,
+ help="onnx model to translate",
+ )
+ parser.add_argument(
+ "-o",
+ "--out",
+ type=str,
+ required=True,
+ help="output file",
+ )
+ parser.add_argument(
+ "-t",
+ "--threshold",
+ default=128,
+ help="Threshold above which every constant is replaced",
+ )
+ parser.add_argument(
+ "--type",
+ default="ConstontOfShape",
+ help="Inserts this operator type",
+ )
+ parser.add_argument(
+ "--domain",
+ default="",
+ help="Inserts this domain",
+ )
+ parser.add_argument(
+ "-v",
+ "--verbose",
+ default=0,
+ help="verbosity",
+ )
+ return parser
+
+
+def _cmd_replace(argv: List[Any]):
+ from .tools.replace_constants import replace_initializer_by_constant_of_shape
+
+ parser = get_parser_replace()
+ args = parser.parse_args(argv[1:])
+ if args.verbose in ("1", 1, "True", True):
+ print(f"[compare] load model {args.model!r}")
+ onx = onnx.load(args.model)
+ new_onx = replace_initializer_by_constant_of_shape(
+ onx, threshold=args.threshold, op_type=args.type, domain=args.domain
+ )
+ if args.verbose in ("1", 1, "True", True):
+ print(f"[compare] save model {args.out!r}")
+ onnx.save(new_onx, args.out)
+
+
def main(argv: Optional[List[Any]] = None):
- fcts = dict(translate=_cmd_translate)
+ fcts = dict(translate=_cmd_translate, compare=_cmd_compare, replace=_cmd_replace)
if argv is None:
argv = sys.argv[1:]
@@ -75,7 +222,11 @@ def main(argv: Optional[List[Any]] = None):
parser = get_main_parser()
parser.parse_args(argv)
else:
- parsers = dict(translate=get_parser_translate)
+ parsers = dict(
+ translate=get_parser_translate,
+ compare=get_parser_compare,
+ replace=get_parser_replace,
+ )
cmd = argv[0]
if cmd not in parsers:
raise ValueError(
diff --git a/onnx_array_api/_helpers.py b/onnx_array_api/_helpers.py
index f9808ca..9331098 100644
--- a/onnx_array_api/_helpers.py
+++ b/onnx_array_api/_helpers.py
@@ -9,7 +9,7 @@ def np_dtype_to_tensor_dtype(dtype: Any):
"""
try:
dt = helper.np_dtype_to_tensor_dtype(dtype)
- except KeyError:
+ except (KeyError, ValueError):
if dtype == np.float32:
dt = TensorProto.FLOAT
elif dtype == np.float64:
@@ -40,6 +40,10 @@ def np_dtype_to_tensor_dtype(dtype: Any):
dt = TensorProto.INT64
elif dtype is float:
dt = TensorProto.DOUBLE
+ elif dtype == np.complex64:
+ dt = TensorProto.COMPLEX64
+ elif dtype == np.complex128:
+ dt = TensorProto.COMPLEX128
else:
- raise KeyError(f"Unable to guess type for dtype={dtype}.")
+ raise KeyError(f"Unable to guess type for dtype={dtype}.") # noqa: B904
return dt
diff --git a/onnx_array_api/light_api/annotations.py b/onnx_array_api/annotations.py
similarity index 84%
rename from onnx_array_api/light_api/annotations.py
rename to onnx_array_api/annotations.py
index 3fe7973..c29102c 100644
--- a/onnx_array_api/light_api/annotations.py
+++ b/onnx_array_api/annotations.py
@@ -64,6 +64,8 @@ def wrapper(self, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
np.uint64: TensorProto.UINT64,
np.bool_: TensorProto.BOOL,
np.str_: TensorProto.STRING,
+ np.complex64: TensorProto.COMPLEX64,
+ np.complex128: TensorProto.COMPLEX128,
}
@@ -81,9 +83,17 @@ def elem_type_int(elem_type: ELEMENT_TYPE) -> int:
return np_dtype_to_tensor_dtype(elem_type)
-def make_shape(shape: TensorShapeProto) -> SHAPE_TYPE:
+def _pick_dim(d, empty_dim):
+ if d.dim_value:
+ return d.dim_value
+ if d.dim_param:
+ return d.dim_param
+ return empty_dim
+
+
+def make_shape(shape: TensorShapeProto, empty_dim: Optional[Any] = None) -> SHAPE_TYPE:
"Extracts a shape from a tensor type."
- if hasattr(shape, "dims"):
- res = [(d.dim_value if d.dim_value else d.dim_param) for d in shape.dims]
+ if hasattr(shape, "dim"):
+ res = [_pick_dim(d, empty_dim=empty_dim) for i, d in enumerate(shape.dim)]
return tuple(res)
return None
diff --git a/onnx_array_api/array_api/__init__.py b/onnx_array_api/array_api/__init__.py
index f4b3c4d..9b67b4b 100644
--- a/onnx_array_api/array_api/__init__.py
+++ b/onnx_array_api/array_api/__init__.py
@@ -47,12 +47,14 @@ def _finfo(dtype):
continue
if isinstance(v, (np.float32, np.float64, np.float16)):
d[k] = float(v)
+ elif isinstance(v, (np.complex128, np.complex64)):
+ d[k] = complex(v)
else:
d[k] = v
d["dtype"] = DType(np_dtype_to_tensor_dtype(dt))
nres = type("finfo", (res.__class__,), d)
- setattr(nres, "smallest_normal", float(res.smallest_normal))
- setattr(nres, "tiny", float(res.tiny))
+ setattr(nres, "smallest_normal", float(res.smallest_normal)) # noqa: B010
+ setattr(nres, "tiny", float(res.tiny)) # noqa: B010
return nres
@@ -84,8 +86,8 @@ def _iinfo(dtype):
d[k] = v
d["dtype"] = DType(np_dtype_to_tensor_dtype(dt))
nres = type("iinfo", (res.__class__,), d)
- setattr(nres, "min", int(res.min))
- setattr(nres, "max", int(res.max))
+ setattr(nres, "min", int(res.min)) # noqa: B010
+ setattr(nres, "max", int(res.max)) # noqa: B010
return nres
@@ -124,6 +126,8 @@ def _finalize_array_api(module, function_names, TEagerTensor):
module.float16 = DType(TensorProto.FLOAT16)
module.float32 = DType(TensorProto.FLOAT)
module.float64 = DType(TensorProto.DOUBLE)
+ module.complex64 = DType(TensorProto.COMPLEX64)
+ module.complex128 = DType(TensorProto.COMPLEX128)
module.int8 = DType(TensorProto.INT8)
module.int16 = DType(TensorProto.INT16)
module.int32 = DType(TensorProto.INT32)
@@ -133,10 +137,10 @@ def _finalize_array_api(module, function_names, TEagerTensor):
module.uint32 = DType(TensorProto.UINT32)
module.uint64 = DType(TensorProto.UINT64)
module.bfloat16 = DType(TensorProto.BFLOAT16)
- setattr(module, "bool", DType(TensorProto.BOOL))
- setattr(module, "str", DType(TensorProto.STRING))
- setattr(module, "finfo", _finfo)
- setattr(module, "iinfo", _iinfo)
+ setattr(module, "bool", DType(TensorProto.BOOL)) # noqa: B010
+ setattr(module, "str", DType(TensorProto.STRING)) # noqa: B010
+ setattr(module, "finfo", _finfo) # noqa: B010
+ setattr(module, "iinfo", _iinfo) # noqa: B010
if function_names is None:
function_names = supported_functions
@@ -146,7 +150,10 @@ def _finalize_array_api(module, function_names, TEagerTensor):
if f is None:
f2 = getattr(npx_functions, name, None)
if f2 is None:
- warnings.warn(f"Function {name!r} is not available in {module!r}.")
+ warnings.warn(
+ f"Function {name!r} is not available in {module!r}.",
+ stacklevel=0,
+ )
continue
f = lambda TEagerTensor, *args, _f=f2, **kwargs: _f( # noqa: E731
*args, **kwargs
diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py
index 6e8ee6d..7c486ce 100644
--- a/onnx_array_api/array_api/_onnx_common.py
+++ b/onnx_array_api/array_api/_onnx_common.py
@@ -1,11 +1,8 @@
from typing import Any, Optional
-import warnings
import numpy as np
from onnx import TensorProto
+import array_api_strict
-with warnings.catch_warnings():
- warnings.simplefilter("ignore")
- from numpy.array_api._array_object import Array
from ..npx.npx_types import (
DType,
ElemType,
@@ -30,6 +27,9 @@
)
+Array = type(array_api_strict.ones((1,)))
+
+
# These functions with no specific code do not have to be
# implemented. They are automatically added in
# :mod:`onnx_array_api.array_api`. It needs
@@ -46,14 +46,13 @@ def asarray(
dtype: Optional[DType] = None,
order: Optional[str] = None,
like: Any = None,
+ device: Optional[str] = None,
copy: bool = False,
) -> EagerTensor:
"""
Converts anything into an array.
"""
- """
- Converts anything into an array.
- """
+ assert device is None, f"asarray not implemented yet for device={device!r}"
if order not in ("C", None):
raise NotImplementedError(f"asarray is not implemented for order={order!r}.")
if like is not None:
@@ -88,18 +87,20 @@ def asarray(
v = TEagerTensor(va)
elif isinstance(a, float):
v = TEagerTensor(np.array(a, dtype=np.float64))
+ elif isinstance(a, complex):
+ v = TEagerTensor(np.array(a, dtype=np.complex128))
elif isinstance(a, bool):
v = TEagerTensor(np.array(a, dtype=np.bool_))
elif isinstance(a, str):
v = TEagerTensor(np.array(a, dtype=np.str_))
elif isinstance(a, list):
- if all(map(lambda x: isinstance(x, bool), a)):
+ if all(isinstance(x, bool) for x in a):
v = TEagerTensor(np.array(a, dtype=np.bool_))
- elif all(map(lambda x: isinstance(x, int), a)):
+ elif all(isinstance(x, int) for x in a):
try:
cvt = np.array(a, dtype=np.int64)
except OverflowError as e:
- if all(map(lambda x: x >= 0, a)):
+ if all(x >= 0 for x in a):
cvt = np.array(a, dtype=np.uint64)
else:
raise e
@@ -108,7 +109,7 @@ def asarray(
v = TEagerTensor(np.array(a))
elif isinstance(a, np.ndarray):
v = TEagerTensor(a)
- elif isinstance(a, Array):
+ elif Array and isinstance(a, Array):
v = TEagerTensor(np.asarray(a))
else:
raise RuntimeError(f"Unexpected type {type(a)} for the first input.")
@@ -128,9 +129,7 @@ def arange(
step: EagerTensor[OptTensorType[ElemType.int64, "I", (1,)]] = None,
dtype: OptParType[DType] = None,
) -> EagerTensor[TensorType[ElemType.numerics, "T"]]:
- use_float = any(
- map(lambda x: isinstance(x, float), [start_or_stop, stop_or_step, step])
- )
+ use_float = any(isinstance(x, float) for x in [start_or_stop, stop_or_step, step])
if isinstance(start_or_stop, int):
start_or_stop = TEagerTensor(
np.array([start_or_stop], dtype=np.float64 if use_float else np.int64)
@@ -208,7 +207,7 @@ def eye(
/,
*,
k: ParType[int] = 0,
- dtype: ParType[DType] = DType(TensorProto.DOUBLE),
+ dtype: ParType[DType] = DType(TensorProto.DOUBLE), # noqa: B008
):
if isinstance(n_rows, int):
n_rows = TEagerTensor(np.array(n_rows, dtype=np.int64))
@@ -246,7 +245,7 @@ def linspace(
dtype: OptParType[DType] = None,
endpoint: ParType[int] = 1,
) -> EagerTensor[TensorType[ElemType.numerics, "T"]]:
- use_float = any(map(lambda x: isinstance(x, float), [start, stop]))
+ use_float = any(isinstance(x, float) for x in [start, stop])
if isinstance(start, int):
start = TEagerTensor(
np.array(start, dtype=np.float64 if use_float else np.int64)
diff --git a/onnx_array_api/ext_test_case.py b/onnx_array_api/ext_test_case.py
index 1068bda..d91ba1a 100644
--- a/onnx_array_api/ext_test_case.py
+++ b/onnx_array_api/ext_test_case.py
@@ -19,6 +19,10 @@ def is_windows() -> bool:
return sys.platform == "win32"
+def is_apple() -> bool:
+ return sys.platform == "darwin"
+
+
def skipif_ci_windows(msg) -> Callable:
"""
Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.
@@ -29,6 +33,16 @@ def skipif_ci_windows(msg) -> Callable:
return lambda x: x
+def skipif_ci_apple(msg) -> Callable:
+ """
+ Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.
+ """
+ if is_apple() and is_azure():
+ msg = f"Test does not work on azure pipeline (Apple). {msg}"
+ return unittest.skip(msg)
+ return lambda x: x
+
+
def ignore_warnings(warns: List[Warning]) -> Callable:
"""
Catches warnings.
@@ -221,7 +235,7 @@ def assertRaise(self, fct: Callable, exc_type: Exception):
fct()
except exc_type as e:
if not isinstance(e, exc_type):
- raise AssertionError(f"Unexpected exception {type(e)!r}.")
+ raise AssertionError(f"Unexpected exception {type(e)!r}.") # noqa: B904
return
raise AssertionError("No exception was raised.")
@@ -230,6 +244,10 @@ def assertEmpty(self, value: Any):
return
raise AssertionError(f"value is not empty: {value!r}.")
+ def assertExists(self, name):
+ if not os.path.exists(name):
+ raise AssertionError(f"File or folder {name!r} does not exists.")
+
def assertHasAttr(self, cls: type, name: str):
if not hasattr(cls, name):
raise AssertionError(f"Class {cls} has no attribute {name!r}.")
@@ -248,7 +266,7 @@ def assertStartsWith(self, prefix: str, full: str):
@classmethod
def tearDownClass(cls):
for name, line, w in cls._warns:
- warnings.warn(f"\n{name}:{line}: {type(w)}\n {str(w)}")
+ warnings.warn(f"\n{name}:{line}: {type(w)}\n {str(w)}", stacklevel=0)
def capture(self, fct: Callable):
"""
@@ -259,9 +277,8 @@ def capture(self, fct: Callable):
"""
sout = StringIO()
serr = StringIO()
- with redirect_stdout(sout):
- with redirect_stderr(serr):
- res = fct()
+ with redirect_stdout(sout), redirect_stderr(serr):
+ res = fct()
return res, sout.getvalue(), serr.getvalue()
def relative_path(self, filename: str, *names: List[str]) -> str:
diff --git a/onnx_array_api/graph_api/__init__.py b/onnx_array_api/graph_api/__init__.py
new file mode 100644
index 0000000..15e274e
--- /dev/null
+++ b/onnx_array_api/graph_api/__init__.py
@@ -0,0 +1 @@
+from .graph_builder import GraphBuilder, NodePattern
diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py
new file mode 100644
index 0000000..5e414ed
--- /dev/null
+++ b/onnx_array_api/graph_api/graph_builder.py
@@ -0,0 +1,1024 @@
+import sys
+from functools import partial
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
+import numpy as np
+from onnx.defs import onnx_opset_version
+import onnx.helper as oh
+import onnx.numpy_helper as onh
+from onnx import (
+ AttributeProto,
+ FunctionProto,
+ GraphProto,
+ ModelProto,
+ NodeProto,
+ TensorProto,
+)
+from onnx.reference import ReferenceEvaluator
+
+T = "TENSOR"
+
+
+class OptimizationOptions:
+ def __init__(
+ self,
+ remove_unused: bool = True,
+ constant_folding: bool = False,
+ constant_size: int = 1024,
+ ):
+ self.remove_unused = remove_unused
+ self.constant_folding = constant_folding
+ self.constant_size = constant_size
+
+
+class NodePattern:
+ """
+ Class defining a matching pattern able to find nodes in a set of nodes.
+ """
+
+ def __init__(
+ self,
+ index: Optional[int] = None,
+ op_type: Optional[str] = None,
+ name: Optional[None] = None,
+ ):
+ self.index = index
+ self.op_type = op_type
+ self.name = name
+
+ def __repr__(self):
+ "usual"
+ args = ["index", "op_type", "name"]
+ sargs = []
+ for a in args:
+ if a:
+ sargs.append(f"{a}={getattr(self, a)!r}")
+ return f"{self.__class__.__name__}({', '.join(sargs)})"
+
+ def find(self, graph: "GraphBuilder") -> Iterator:
+ """
+ Iterates on nodes matching the pattern.
+ """
+ for index, node in enumerate(graph.nodes):
+ if self.match(index, node):
+ yield node
+
+ def match(self, index, node: NodeProto) -> bool:
+ """
+ Tells if a node is matching this pattern.
+ """
+ if self.index is not None and self.index != index:
+ return False
+ if self.op_type is not None and self.op_type != node.op_type:
+ return False
+ if self.name is not None and self.name != node.name:
+ return False
+ return True
+
+
+class Opset:
+ # defined for opset >= 18
+ # name: number of expected outputs
+ _implemented = {
+ "Add": 1,
+ "And": 1,
+ "Cast": 1,
+ "Concat": 1,
+ "Constant": 1,
+ "Div": 1,
+ "Exp": 1,
+ "Expand": 1,
+ "GatherElements": 1,
+ "Gemm": 1,
+ "Identity": 1,
+ "MatMul": 1,
+ "MaxPool": 2,
+ "Mul": 1,
+ "Log": 1,
+ "Or": 1,
+ "Pow": 1,
+ "Relu": 1,
+ "ReduceSum": 1,
+ "Reshape": 1,
+ "Shape": 1,
+ "Slice": 1,
+ "Squeeze": 1,
+ "Sub": 1,
+ "Transpose": 1,
+ "Unsqueeze": 1,
+ }
+
+ def __init__(self, builder: "GraphBuilder", opset: int):
+ self.opset = opset
+ self.builder = builder
+
+ def __getattr__(self, name):
+ if name in self._implemented:
+ return partial(self.make_node, name)
+ try:
+ return super().__getattr__(name)
+ except AttributeError as e:
+ raise AttributeError(f"Unable to access attribute {name!r}.") from e
+
+ def Initializer(
+ self, init: Union[TensorProto, np.ndarray], name: Optional[str] = None
+ ) -> str:
+ """
+ Creates an initializer.
+
+ :param init: value
+ :param name: name if value is not a TensorProto
+ :return: its name
+ """
+ return self.builder.make_initializer(init, name=name, exists=True)
+
+ def make_node(
+ self,
+ op_type: str,
+ *inputs: Optional[Union[str, List[str]]],
+ outputs: Optional[Union[int, List[str], str]] = None,
+ domain: str = "",
+ **kwargs,
+ ):
+ if outputs is None:
+ outputs = self._implemented[op_type]
+ if inputs is None:
+ inputs = []
+ new_inputs = []
+ for i in inputs:
+ if not isinstance(i, str):
+ name = self.builder.unique_name("cst")
+ self.builder.make_initializer(i, name=name, exists=True)
+ new_inputs.append(name)
+ else:
+ new_inputs.append(i)
+
+ return self.builder.make_node(
+ op_type, new_inputs, outputs=outputs, domain=domain, **kwargs
+ )
+
+
+class GraphBuilder:
+ def __init__(
+ self,
+ target_opset_or_existing_proto: Optional[
+ Union[int, Dict[str, int], ModelProto, FunctionProto]
+ ] = None,
+ input_names: Optional[Sequence[str]] = None,
+ as_function: bool = False,
+ optimization_options: Optional[OptimizationOptions] = None,
+ args: Optional[List[Any]] = None,
+ verbose: int = 0,
+ ir_version: Optional[int] = None,
+ ):
+ self.optimization_options = optimization_options or OptimizationOptions()
+ self.as_function = as_function
+ self.input_args = args
+ self.verbose = verbose
+
+ if target_opset_or_existing_proto is None:
+ target_opset_or_existing_proto = onnx_opset_version() - 1
+ if isinstance(target_opset_or_existing_proto, (int, dict)):
+ self.opsets = (
+ {"": target_opset_or_existing_proto}
+ if isinstance(target_opset_or_existing_proto, int)
+ else target_opset_or_existing_proto
+ )
+ self.ir_version = ir_version
+ self.nodes = []
+ self.initializers_dict = {}
+ self.inputs = []
+ self.outputs = []
+ self._unique_names = set()
+ self.input_names = input_names or []
+ self.current_input = 0
+ self._known_shapes = {}
+ self._known_types = {}
+ self.constants_ = {}
+ self.functions_ = {}
+ elif isinstance(target_opset_or_existing_proto, ModelProto):
+ assert (
+ not input_names
+ ), "input_names must be empty if the input is an existing model."
+ proto = target_opset_or_existing_proto
+ self.opsets = {d.domain: d.version for d in proto.opset_import}
+ self.ir_version = ir_version or target_opset_or_existing_proto.ir_version
+ self.nodes = list(proto.graph.node)
+ self.initializers_dict = {i.name: i for i in proto.graph.initializer}
+ self.initializers_dict.update(
+ {i.name: i for i in proto.graph.sparse_initializer}
+ )
+ self.inputs = list(proto.graph.input)
+ self.outputs = list(proto.graph.output)
+ self.input_names = [i.name for i in proto.graph.input]
+ self.current_input = len(self.inputs)
+ # This should be improve.
+ self._known_shapes = {}
+ self._known_types = {}
+ self.constants_ = {}
+ for k, v in self.initializers_dict.items():
+ self.constants_[k] = None
+ self.set_shape(k, self._get_tensor_shape(v))
+ self.set_type(k, self._get_tensor_type(v))
+ for node in self.nodes:
+ if node.op_type == "Constant":
+ self.constants_[node.output[0]] = node
+ self.set_shape(node.output[0], self._get_tensor_shape(node))
+ self.set_type(node.output[0], self._get_tensor_type(node))
+ for f in proto.functions:
+ self.add_function(f)
+ else:
+ raise NotImplementedError(
+ f"{type(target_opset_or_existing_proto)} is not supported."
+ )
+
+ self.op = Opset(self, self.opsets[""]) if "" in self.opsets else None
+ self._cache_array = []
+
+ def add_local_function(self, domain: str, name: str, gr: "GraphBuilder"):
+ "Adds a local function."
+ assert (
+ domain,
+ name,
+ ) not in self.functions_, f"Function {(domain, name)} was already added."
+ self.functions_[domain, name] = gr
+
+ def _get_tensor_shape(
+ self, proto: Union[NodeProto, TensorProto]
+ ) -> Tuple[int, ...]:
+ if isinstance(proto, TensorProto):
+ return tuple(proto.dims)
+ if isinstance(proto, NodeProto):
+ for att in proto.attribute:
+ if att.name == "value_float":
+ return tuple()
+ if att.name == "value_int":
+ return tuple()
+ if att.name == "value_floats":
+ return tuple(att.floats)
+ if att.name == "value_ints":
+ return (len(att.ints),)
+ if att.name == "value":
+ t = onh.to_array(att.t)
+ return t.shape
+ raise TypeError(
+ f"Unexpected or unsupported scenario type {type(proto)}: {proto}."
+ )
+
+ def _get_tensor_type(self, proto: Union[NodeProto, TensorProto]) -> int:
+ if isinstance(proto, TensorProto):
+ return proto.data_type
+ if isinstance(proto, NodeProto):
+ for att in proto.attribute:
+ if att.name == "value_float":
+ return TensorProto.FLOAT
+ if att.name == "value_int":
+ return TensorProto.INT64
+ if att.name == "value_floats":
+ return TensorProto.FLOAT
+ if att.name == "value_ints":
+ return TensorProto.INT64
+ if att.name == "value":
+ t = onh.to_array(att.t)
+ return oh.np_dtype_to_tensor_dtype(t.dtype)
+ raise ValueError(f"Unexpected type or value {type(proto)}: {proto}.")
+
+ def is_constant(self, name: str) -> bool:
+ """Tells if a result is a constant."""
+ return name in self.constants_
+
+ def get_constant(self, name: str) -> np.ndarray:
+ assert self.is_constant(name), f"Result {name!r} is not a constant."
+ assert (
+ name in self.initializers_dict
+ ), f"Result {name!r} was never evaluated within method 'constant_folding'."
+ value = self.initializers_dict[name]
+ if isinstance(value, np.ndarray):
+ return value
+
+ raise TypeError(f"Unable to convert type {type(value)} into numpy array.")
+
+ def set_shape(self, name: str, shape: Tuple[int, ...]):
+ assert isinstance(
+ name, str
+ ), f"Unexpected type {type(name)} for name, it should be a string."
+ if name in self._known_shapes:
+ assert shape == self._known_shapes[name], (
+ f"Name {name!r} already exists and it is different "
+ f"{self._known_shapes[name]} != {shape}"
+ )
+ return
+ assert isinstance(
+ shape, tuple
+ ), f"Unexpected shape type {type(shape)}, it should be a tuple."
+ self._known_shapes[name] = shape
+
+ def set_type(self, name: str, dtype: int):
+ assert isinstance(name, str), f"Unexpected type {type(name)} for name."
+ int_type = dtype if isinstance(dtype, int) else self._get_type(dtype)
+ if name in self._known_types:
+ assert int_type == self._known_types[name], (
+ f"Name {name!r} already exists and it is different "
+ f"{self._known_types[name]} != {int_type}."
+ )
+ self._known_types[name] = int_type
+
+ def rank(self, name: str) -> int:
+ return len(self.get_shape(name))
+
+ def has_shape(self, name: str) -> bool:
+ return name in self._known_shapes
+
+ def get_shape(self, name: str) -> int:
+ assert name in self._known_shapes, (
+ f"Shape is unknown for result {name!r}, "
+ f"known_shapes={self._known_shapes}."
+ )
+ return self._known_shapes[name]
+
+ def has_type(self, name: str) -> bool:
+ return name in self._known_types
+
+ def get_type(self, name: str) -> int:
+ assert (
+ name in self._known_types
+ ), f"Type is unknown for result {name!r}, known_types={self._known_types}."
+ return self._known_types[name]
+
+ def unique_name(self, prefix: str) -> str:
+ if prefix in self._unique_names:
+ i = 2
+ sug = f"{prefix}2"
+ while sug in self._unique_names:
+ i += 1
+ sug = f"{prefix}{i}"
+ self._unique_names.add(sug)
+ return sug
+ self._unique_names.add(prefix)
+ return prefix
+
+ def _prepare_inputs(self, schema: Optional[Any], *inputs: List[Any]) -> List[str]:
+ input_names = []
+ for i in inputs:
+ self.make_input(i.name, i.dtype, i.shape)
+ input_names.append(i.name)
+ return input_names
+
+ def _get_type(self, elem_type: Any, exc: bool = True) -> int:
+ if not isinstance(elem_type, int):
+ st = str(elem_type)
+ if "float32" in st:
+ elem_type = TensorProto.FLOAT
+ elif "int64" in st:
+ elem_type = TensorProto.INT64
+ elif elem_type is None:
+ elem_type = TensorProto.UNDEFINED
+ elif exc:
+ raise ValueError(f"Unable to interpret elem_type {elem_type!r}.")
+ return elem_type
+
+ def make_initializer(
+ self, value: Any, name: str = "", external: bool = False, exists: bool = False
+ ) -> str:
+ if external:
+ raise NotImplementedError("External initializers are not implemented yet.")
+ if name == "":
+ if exists:
+ raise ValueError("Undefined name cannot exist.")
+ name = self.unique_name("cst")
+ elif not exists:
+ if name in self._unique_names:
+ raise ValueError(f"{name!r} is already assigned.")
+ self._unique_names.add(name)
+ self.set_shape(name, value.shape)
+ self.set_type(name, self._get_type(value.dtype))
+ self.initializers_dict[name] = value
+ self.constants_[name] = None
+ if self.verbose and np.prod(value.shape) > 100:
+ print(
+ f"[GraphBuilder] make_initializer:{name}[{value.dtype}:{value.shape}]"
+ )
+ return name
+
+ def make_tensor_input(
+ self, name: str, elem_type: Any, shape: Tuple[int, ...]
+ ) -> str:
+ if self.current_input < len(self.input_names):
+ # The input needs to be renamed, an identity node is added.
+ input_name = self.input_names[self.current_input]
+ self.make_node("Identity", [input_name], [name])
+ else:
+ self.input_names.append(name)
+ input_name = name
+ if name in self._unique_names:
+ raise ValueError(f"{name!r} is already assigned.")
+ self._unique_names.add(name)
+ self.current_input += 1
+ elem_type = self._get_type(elem_type)
+ self.inputs.append(oh.make_tensor_value_info(input_name, elem_type, shape))
+ if self.verbose:
+ print(f"[GraphBuilder] make_tensor_input:{name}[{elem_type}:{shape}]")
+ if shape:
+ self.set_shape(name, shape)
+ if elem_type:
+ self.set_type(name, elem_type)
+ return name
+
+ def make_tensor_output(
+ self,
+ name: Union[str, List[str]],
+ elem_type: Optional[int] = None,
+ shape: Optional[Tuple[int, ...]] = None,
+ is_dimension: bool = False,
+ indexed: bool = False,
+ ) -> Union[str, List[str]]:
+ if isinstance(name, list):
+ res = []
+ for n in name:
+ res.append(self.make_tensor_output(n, elem_type, shape))
+ return res
+
+ elem_type = self._get_type(elem_type, False)
+ assert (
+ self.as_function or elem_type != 0
+ ), f"Undefined element type for {name!r}."
+ self.outputs.append(oh.make_tensor_value_info(name, elem_type, shape))
+ if self.verbose:
+ print(f"[GraphBuilder] make_tensor_output:{name}[{elem_type}:{shape}]")
+ if shape:
+ self.set_shape(name, shape)
+ if elem_type:
+ self.set_type(name, elem_type)
+ return name
+
+ def make_node(
+ self,
+ op_type: str,
+ inputs: Union[str, List[str]],
+ outputs: Union[int, List[str], str] = 1,
+ domain: str = "",
+ attributes: Optional[List[AttributeProto]] = None,
+ **kwargs,
+ ) -> Union[str, List[str]]:
+ assert (
+ not kwargs or not attributes
+ ), f"Only attributes or kwargs can be filled for node {op_type!r}."
+ if isinstance(inputs, tuple):
+ inputs = list(inputs)
+ if isinstance(outputs, int):
+ assert outputs > 0, f"outputs={outputs} must be > 0."
+ lower = op_type.lower()
+ output_names = [
+ self.unique_name(f"_onx_{lower}{i}") for i in range(outputs)
+ ]
+ elif isinstance(outputs, str):
+ output_names = [outputs]
+ else:
+ output_names = outputs
+ if isinstance(inputs, str):
+ inputs = [inputs]
+
+ # next
+ try:
+ node = oh.make_node(op_type, inputs, output_names, domain=domain, **kwargs)
+ except TypeError as e:
+ raise TypeError(
+ f"A node {op_type!r} cannot be created with "
+ f"inputs={inputs} (types={[type(i) for i in inputs]}), "
+ f"outputs={outputs} "
+ f"(types={[type(o) for o in outputs] if isinstance(outputs, (tuple, list)) else outputs}), " # noqa: E501
+ f"domain={domain!r}, kwargs={kwargs}."
+ ) from e
+ if attributes:
+ node.attribute.extend(attributes)
+
+ # constant handling, shape, type
+ if node.op_type == "Constant":
+ size = len(node.SerializeToString())
+ assert size < self.optimization_options.constant_size, (
+ f"A node Constant holds a tensor bigger than "
+ f"the constant: {size} >= {self.constant_size}."
+ )
+ k = node.output[0]
+ self.constants_[k] = node
+ shape = self._get_tensor_shape(node)
+ dtype = self._get_tensor_type(node)
+ self.set_shape(k, shape)
+ self.set_type(k, dtype)
+ if self.verbose and np.prod(shape) > 100:
+ print(f"[GraphBuilder] make_constant:{k}[{dtype}:{shape}]")
+ elif node.op_type == "Identity":
+ if node.input[0] in self._known_shapes:
+ self.set_shape(node.output[0], self._known_shapes[node.input[0]])
+ if node.input[0] in self._known_types:
+ self.set_type(node.output[0], self._known_types[node.input[0]])
+ if self.is_constant(node.input[0]):
+ self.constants_[node.output[0]] = node
+ else:
+ if all(map(self.is_constant, node.input)):
+ for o in node.output:
+ self.constants_[o] = node
+
+ # add the node
+ self.nodes.append(node)
+ if len(output_names) == 1:
+ return output_names[0]
+ return output_names
+
+ def make_nodes(
+ self,
+ builder: "GraphBuilder",
+ input_names: List[str],
+ output_names: List[str],
+ prefix: str = "",
+ ) -> Union[str, List[str]]:
+ """
+ Appends all nodes and initializers from another builder.
+ Handles the renaming of results.
+ The content stored in 'builder' is modified inplace to avoid copying.
+
+ :param builder: other builder
+ :param input_names: input names
+ :param output_names: output names
+ :param prefix: prefix all name from this builder
+ :return: output names
+ """
+ renaming = {}
+ for init, value in builder.initializers_dict.items():
+ name = self.unique_name(f"{prefix}{init}")
+ renaming[init] = name
+ if isinstance(value, TensorProto):
+ value.name = name
+ self.initializers_dict[name] = value
+ self.constants_[name] = None
+ self.set_shape(name, builder._known_shapes[init])
+ self.set_type(name, builder._known_types[init])
+
+ assert len(input_names) == len(builder.inputs), (
+ f"Inconsistency between input_names={input_names} "
+ f"and the other builder inputs={builder.inputs}."
+ )
+
+ for name, inp in zip(input_names, builder.inputs):
+ new_name = self.unique_name(f"{prefix}{inp.name}")
+ renaming[inp.name] = new_name
+ if builder.has_shape(inp.name):
+ self.set_shape(new_name, builder.get_shape(inp.name))
+ if builder.has_type(inp.name):
+ self.set_type(new_name, builder.get_type(inp.name))
+ self.make_node("Identity", [name], [new_name])
+
+ for node in builder.nodes:
+ new_inputs = [renaming[i] for i in node.input]
+ new_outputs = [self.unique_name(f"{prefix}{o}") for o in node.output]
+ for o, no in zip(node.output, new_outputs):
+ renaming[o] = no
+ self.make_node(
+ node.op_type,
+ new_inputs,
+ new_outputs,
+ domain=node.domain,
+ attributes=node.attribute,
+ )
+ for o, no in zip(node.output, new_outputs):
+ if builder.has_shape(o):
+ self.set_shape(no, builder.get_shape(o))
+ if builder.has_type(o):
+ self.set_type(no, builder.get_type(o))
+
+ assert len(output_names) == len(builder.outputs), (
+ f"Inconsistency between output_names={output_names} and "
+ f"outputs={builder.outputs}, renaming={renaming}."
+ )
+ for name, out in zip(output_names, builder.outputs):
+ self.make_node("Identity", [renaming[out.name]], [name])
+
+ # opsets and domains
+ for o, v in builder.opsets.items():
+ if o in self.opsets:
+ assert self.opsets[o] == builder.opsets[o], (
+ f"Opset mismatch for domain {o!r}, "
+ f"{self.opsets[o]} != {builder.opsets[o]}."
+ )
+ continue
+ self.opsets[o] = v
+
+ if len(output_names) == 1:
+ return output_names[0]
+ return output_names
+
+ def from_array(
+ self, arr: T, name: Optional[str] = None
+ ) -> TensorProto: # noqa: F821
+ if isinstance(arr, np.ndarray):
+ return self.from_np_array(arr, name)
+ raise NotImplementedError(
+ f"{type(arr)} is not supported yet but initializer {name or ''!r} is."
+ )
+
+ def from_np_array(self, arr: np.ndarray, name: Optional[str] = None) -> TensorProto:
+ arr_cpu = np.ascontiguousarray(arr) if not arr.flags["C_CONTIGUOUS"] else arr
+ if arr_cpu.ctypes.data == arr.ctypes.data:
+ if sys.byteorder == "big":
+ arr_cpu = arr_cpu.copy()
+ np.byteswap(
+ np.frombuffer(arr_cpu.ctypes.data, dtype=arr_cpu.dtype),
+ inplace=True,
+ )
+ else:
+ if sys.byteorder == "big":
+ np.byteswap(
+ np.frombuffer(arr_cpu.ctypes.data, dtype=arr_cpu.dtype),
+ inplace=True,
+ )
+ # let's the tensor until the builder is released
+ # so the pointer does not disappear
+ self._cache_array.append(arr_cpu)
+
+ tensor = TensorProto()
+ tensor.dims.extend(arr_cpu.shape)
+ tensor.name = name
+ tensor.data_type = self._get_type(arr_cpu.dtype)
+ # this does not work...
+ # tensor.raw_data = arr_cpu.ctypes.data
+ tensor.raw_data = arr_cpu.tobytes()
+ if self.verbose and np.prod(arr_cpu.shape) > 100:
+ print(
+ f"[GraphBuilder] from_array:{tensor.data_type}[{arr_cpu.shape}]:"
+ f"{'swapped' if sys.byteorder == 'big' else ''}"
+ )
+ return tensor
+
+ def _build_initializers(self) -> List[TensorProto]:
+ res = []
+ for k, v in sorted(self.initializers_dict.items()):
+ if isinstance(v, np.ndarray):
+ if np.prod(v.shape) > 100:
+ if self.verbose:
+ print(f"[GraphBuilder] from_array:{k}:{v.dtype}[{v.shape}]")
+ t = self.from_array(v, name=k)
+ else:
+ t = onh.from_array(v, name=k)
+ res.append(t)
+ continue
+ if isinstance(v, TensorProto):
+ res.append(v)
+ continue
+ raise TypeError(
+ f"Unable to convert initializer {k!r} with type "
+ f"{type(v)} into a TensorProto."
+ )
+ return res
+
+ def process(
+ self,
+ graph_module: Any,
+ interpreter: "Interpreter", # noqa: F821
+ ):
+ for node in graph_module.graph.nodes:
+ interpreter.run_node(node)
+
+ def to_onnx(
+ self, as_function: bool = False, optimize: bool = True
+ ) -> Union[FunctionProto, ModelProto]:
+ if optimize:
+ self.optimize()
+ if as_function:
+ raise NotImplementedError("Export as FunctionProto is not implemented yet.")
+ dense = self._build_initializers()
+ opsets = [oh.make_opsetid(*o) for o in self.opsets.items()]
+ if as_function:
+ return oh.make_function(
+ self.nodes,
+ self.name,
+ [i.name for i in self.inputs],
+ [o.name for o in self.outputs],
+ domain=self.domain,
+ )
+
+ if self.verbose:
+ print("[GraphBuilder] onh.make_graph")
+ graph = oh.make_graph(
+ self.nodes, "experiment", self.inputs, self.outputs, dense
+ )
+ if self.verbose:
+ print("[GraphBuilder] onh.make_model")
+ model = oh.make_model(graph, opset_imports=opsets)
+ if self.ir_version:
+ model.ir_version = self.ir_version
+ return model
+
+ def _check_order_node(self, ind: int, node: NodeProto, existing: Set[str]):
+ for i in node.input:
+ assert i in existing, (
+ f"Unknown input {i!r} from node {ind}:{node.op_type}:{node.name}. "
+ f"Known: {existing}."
+ )
+ for att in node.attribute:
+ if att.type == AttributeProto.GRAPH and att.g:
+ g_existing = existing.copy()
+ for i in att.g.input:
+ g_existing.add(i.name)
+ for ind2, node2 in enumerate(att.g.node):
+ self._check_order_node((ind, ind2), node2, g_existing)
+ for o in att.g.output:
+ assert (
+ o.name in g_existing
+ ), f"Unknown output {o.name!r}. Known: {g_existing}."
+ for o in node.output:
+ existing.add(o)
+
+ def check_order(self):
+ existing = set(self.initializers_dict)
+ for i in self.inputs:
+ existing.add(i.name)
+ for ind, node in enumerate(self.nodes):
+ self._check_order_node(ind, node, existing)
+ for o in self.outputs:
+ assert o.name in existing, f"Unknown output {o.name!r}. Known: {existing}."
+
+ def optimize(self, check_order: bool = False):
+ if check_order:
+ self.check_order()
+ self.remove_identity_nodes()
+ if check_order:
+ self.check_order()
+ if self.optimization_options.remove_unused:
+ self.remove_unused()
+ if check_order:
+ self.check_order()
+ if self.optimization_options.constant_folding:
+ self.constant_folding()
+ if check_order:
+ self.check_order()
+ if self.optimization_options.remove_unused:
+ self.remove_unused()
+ if check_order:
+ self.check_order()
+
+ def hidden_inputs_graph(self, graph: GraphProto) -> Set[str]:
+ hidden = set()
+ memo = set(i.name for i in graph.initializer)
+ memo |= set(i.name for i in graph.sparse_initializer)
+ for node in graph.node:
+ for i in node.input:
+ if i not in memo:
+ hidden.add(i)
+ for att in node.attribute:
+ if att.type == AttributeProto.GRAPH and att.g:
+ hid = self.hidden_inputs_graph(att.g)
+ less = set(h for h in hid if h not in memo)
+ hidden |= less
+ memo |= set(node.output)
+ return hidden
+
+ def remove_unused(self):
+ """
+ Simple function to remove unused nodes.
+ It does not look into subgraphs and assumes there is none.
+ Everything is done in one pass.
+ """
+
+ # mark outputs
+ marked = {o.name: set() for o in self.outputs}
+ for node in reversed(self.nodes):
+ used = False
+ for o in node.output:
+ if o in marked:
+ for i in node.input:
+ marked[o].add(i)
+ used = True
+ for att in node.attribute:
+ if att.type == AttributeProto.GRAPH and att.g:
+ hidden_inputs = self.hidden_inputs_graph(att.g)
+ for i in hidden_inputs:
+ marked[i] = set()
+ if used:
+ for i in node.input:
+ marked[i] = set()
+
+ # removed nodes
+ removed = set()
+ marked_set = set(marked)
+ for ind, node in enumerate(self.nodes):
+ if not (set(node.output) & marked_set):
+ removed.add(ind)
+
+ if self.verbose:
+ for k, v in self.initializers_dict.items():
+ if k not in marked:
+ v = self.initializers_dict[k]
+ print(f"[GraphBuilder] remove_initializer:{k}:{v.dtype}[{v.shape}]")
+ self.initializers_dict = {
+ k: v for k, v in self.initializers_dict.items() if k in marked
+ }
+ self.constants_ = {k: v for k, v in self.constants_.items() if k in marked}
+ self.nodes = [node for i, node in enumerate(self.nodes) if i not in removed]
+
+ def _apply_transpose(self, node: NodeProto, feeds: Dict[str, T]) -> T: # noqa: F821
+ perm = None
+ for att in node.attribute:
+ if att.name == "perm":
+ perm = tuple(att.ints)
+ break
+ assert perm, f"perm not here in node {node}"
+ return [np.transpose(feeds[node.input[0]], perm)]
+
+ def constant_folding(self):
+ """
+ Folds all constants. Constants are marked during the creation of the graph.
+ There is no need to propagate this information.
+ """
+ updates = {}
+ node_to_remove = set()
+ for _k, v in self.constants_.items():
+ if v is None:
+ # this is an initiliazer
+ continue
+ # a node
+ if all(map(self.is_constant, v.output)):
+ node_to_remove.add(tuple(v.output))
+ # node evaluation
+ if v.op_type == "Transpose":
+ # bypassing onnx.numpy_helper.from_array, too slow
+ feeds = {i: self.initializers_dict[i] for i in v.input}
+ output = self._apply_transpose(v, feeds)
+ else:
+ ref = ReferenceEvaluator(v)
+ feeds = {i: self.get_constant(i) for i in v.input}
+ output = ref.run(None, feeds)
+ for name, value in zip(v.output, output):
+ updates[name] = None
+ self.initializers_dict[name] = value
+ if self.verbose:
+ print(
+ f"[GraphBuilder] fold_constant:"
+ f"{v.op_type}:{name}[{value.dtype}:"
+ f"{value.shape}]:from:{','.join(sorted(feeds))}"
+ )
+
+ self.constants_.update(updates)
+ new_nodes = []
+ for node in self.nodes:
+ if tuple(node.output) in node_to_remove:
+ continue
+ new_nodes.append(node)
+ self.nodes = new_nodes
+
+ def remove_identity_nodes(self):
+ """
+ Removes identity nodes.
+ """
+ # first pass: detect replacements
+ new_nodes = []
+ input_names = set(i.name for i in self.inputs)
+ output_names = set(i.name for i in self.outputs)
+ replacements = {}
+ replacements_rev = {}
+ for node in self.nodes:
+ if node.op_type != "Identity":
+ new_nodes.append(node)
+ continue
+
+ if node.output[0] not in output_names:
+ old_name, new_name = node.output[0], node.input[0]
+ elif (
+ node.input[0] not in input_names
+ and node.input[0] not in output_names
+ and node.input[0] not in replacements
+ ):
+ old_name, new_name = node.input[0], node.output[0]
+ else:
+ new_nodes.append(node)
+ continue
+
+ # the new name can be set for replacements as well
+ if new_name in replacements:
+ new_name = replacements[new_name]
+ assert new_name not in replacements, (
+ f"Name {old_name!r} still in {replacements}, "
+ f"node.op_type={node.op_type!r}, "
+ f"node.input={node.input}, node.output={node.output}, "
+ f"input_names={input_names}, output_names={output_names}"
+ )
+ if old_name in replacements_rev:
+ old_old_name = replacements_rev[old_name]
+ replacements[old_old_name] = new_name
+ replacements_rev[new_name] = old_old_name
+ if old_name in replacements:
+ replacements[replacements[old_name]] = new_name
+ assert new_name not in replacements, (
+ f"Name {old_name!r} still in {replacements}, "
+ f"node.op_type={node.op_type!r}, "
+ f"node.input={node.input}, node.output={node.output}, "
+ f"input_names={input_names}, output_names={output_names}"
+ )
+ replacements[old_name] = new_name
+ replacements_rev[new_name] = old_name
+
+ # verification
+ for k, v in replacements.items():
+ assert v not in replacements, (
+ f"replacement {k}->{v} is not possible because of "
+ f"{v}->{replacements[v]}, old_name={old_name!r}, "
+ f"new_name={new_name!r}"
+ )
+
+ # second pass: replacements in initializer
+ for k, v in replacements.items():
+ if k in self.initializers_dict:
+ self.initializers_dict[v] = self.initializers_dict[k]
+ del self.initializers_dict[k]
+ assert self.constants_[v]
+ self.constants_[v] = self.constants_[k]
+ del self.constants_[k]
+
+ # third pass: replacements in node
+ self.nodes = []
+ for node in new_nodes:
+ repo = {o for o in node.output if o in replacements}
+ repi = {o for o in node.input if o in replacements}
+ if repi or repo:
+ new_inputs = [replacements.get(i, i) for i in node.input]
+ new_outputs = [replacements.get(i, i) for i in node.output]
+ new_node = oh.make_node(
+ node.op_type,
+ new_inputs,
+ new_outputs,
+ domain=node.domain,
+ name=node.name,
+ )
+ new_node.attribute.extend(node.attribute)
+ self.nodes.append(new_node)
+ else:
+ self.nodes.append(node)
+
+ def np(
+ self,
+ index: Optional[int] = None,
+ op_type: Optional[str] = None,
+ name: Optional[str] = None,
+ ) -> NodePattern:
+ """
+ Returns an instance of :class:`NodePattern
+ `.
+ """
+ return NodePattern(index=index, op_type=op_type, name=name)
+
+ def update_attribute(
+ self,
+ pat: NodePattern,
+ recursive: bool = False,
+ **kwargs: Dict[str, Any],
+ ) -> int:
+ """
+ Udates attributes for nodes matching the
+
+ :param pat: returned by method :meth:`GraphBuilder.np`
+ :param recursive: walk through subgraph
+ :param kwargs: attributes to modify
+ :return: number of modified nodes
+ """
+ assert not recursive, "recursive=True is not implemented."
+ modified = 0
+ for node in pat.find(self):
+ up = self.update_node(node, **kwargs)
+ if up:
+ modified += 1
+ return modified
+
+ DELETE = object()
+
+ def update_node(self, node: NodeProto, **kwargs) -> bool:
+ """
+ Updates attributes of a node proto.
+ Returns True if the node was updated.
+ """
+ processed = set()
+ modified = True
+ atts = []
+ for att in node.attribute:
+ if att.name in kwargs:
+ processed.add(att.name)
+ if kwargs[att.name] is GraphBuilder.DELETE:
+ continue
+ new_att = oh.make_attribute(att.name, kwargs[att.name])
+ assert new_att.type == att.type, (
+ f"Mismatch value for attribute {att.name!r} has type "
+ f"{att.type} but the new value leads to "
+ f"type={new_att.type}."
+ )
+ atts.append(new_att)
+ modified = True
+ continue
+ atts.append(att)
+ for k, v in kwargs.items():
+ if k in processed or v is GraphBuilder.DELETE:
+ continue
+ modified = True
+ new_att = oh.make_attribute(k, v)
+ atts.append(new_att)
+
+ if modified:
+ del node.attribute[:]
+ node.attribute.extend(atts)
+ return modified
diff --git a/onnx_array_api/light_api/__init__.py b/onnx_array_api/light_api/__init__.py
index be6e9dd..83e8878 100644
--- a/onnx_array_api/light_api/__init__.py
+++ b/onnx_array_api/light_api/__init__.py
@@ -1,21 +1,21 @@
from typing import Dict, Optional
from onnx import ModelProto
-from .annotations import domain
+from ..annotations import domain
from .model import OnnxGraph, ProtoType
-from .translate import Translater
from .var import Var, Vars
-from .inner_emitter import InnerEmitter
def start(
opset: Optional[int] = None,
opsets: Optional[Dict[str, int]] = None,
+ ir_version: Optional[int] = None,
) -> OnnxGraph:
"""
Starts an onnx model.
:param opset: main opset version
:param opsets: others opsets as a dictionary
+ :param ir_version: specify the ir_version as well
:return: an instance of :class:`onnx_array_api.light_api.OnnxGraph`
A very simple model:
@@ -47,7 +47,7 @@ def start(
)
print(onx)
"""
- return OnnxGraph(opset=opset, opsets=opsets)
+ return OnnxGraph(opset=opset, opsets=opsets, ir_version=ir_version)
def g() -> OnnxGraph:
@@ -56,62 +56,3 @@ def g() -> OnnxGraph:
:return: an instance of :class:`onnx_array_api.light_api.OnnxGraph`
"""
return OnnxGraph(proto_type=ProtoType.GRAPH)
-
-
-def translate(proto: ModelProto, single_line: bool = False, api: str = "light") -> str:
- """
- Translates an ONNX proto into a code using :ref:`l-light-api`
- to describe the ONNX graph.
-
- :param proto: model to translate
- :param single_line: as a single line or not
- :param api: API to export into,
- default is `"light"` and this is handle by class
- :class:`onnx_array_api.light_api.emitter.Emitter`,
- another value is `"onnx"` which is the inner API implemented
- in onnx package.
- :return: code
-
- .. runpython::
- :showcode:
-
- from onnx_array_api.light_api import start, translate
-
- onx = (
- start()
- .vin("X")
- .reshape((-1, 1))
- .Transpose(perm=[1, 0])
- .rename("Y")
- .vout()
- .to_onnx()
- )
- code = translate(onx)
- print(code)
-
- The inner API from onnx packahe is also available.
-
- .. runpython::
- :showcode:
-
- from onnx_array_api.light_api import start, translate
-
- onx = (
- start()
- .vin("X")
- .reshape((-1, 1))
- .Transpose(perm=[1, 0])
- .rename("Y")
- .vout()
- .to_onnx()
- )
- code = translate(onx, api="onnx")
- print(code)
- """
- if api == "light":
- tr = Translater(proto)
- return tr.export(single_line=single_line, as_str=True)
- if api == "onnx":
- tr = Translater(proto, emitter=InnerEmitter())
- return tr.export(as_str=True)
- raise ValueError(f"Unexpected value {api!r} for api.")
diff --git a/onnx_array_api/light_api/_op_var.py b/onnx_array_api/light_api/_op_var.py
index 8a995b3..1291594 100644
--- a/onnx_array_api/light_api/_op_var.py
+++ b/onnx_array_api/light_api/_op_var.py
@@ -1,5 +1,7 @@
from typing import List, Optional, Union
-from .annotations import AI_ONNX_ML, domain
+import numpy as np
+from ..reference import from_array_extended
+from ..annotations import AI_ONNX_ML, domain
class OpsVar:
@@ -69,6 +71,11 @@ def Cast(self, saturate: int = 1, to: int = 0) -> "Var":
def Celu(self, alpha: float = 1.0) -> "Var":
return self.make_node("Celu", self, alpha=alpha)
+ def ConstantOfShape(self, value: Optional[np.array] = None) -> "Var":
+ if value is None:
+ return self.make_node("ConstantOfShape", self)
+ return self.make_node("ConstantOfShape", self, value=from_array_extended(value))
+
def DepthToSpace(self, blocksize: int = 0, mode: str = "DCR") -> "Var":
return self.make_node("DepthToSpace", self, blocksize=blocksize, mode=mode)
@@ -307,6 +314,13 @@ def Selu(
def Shrink(self, bias: float = 0.0, lambd: float = 0.5) -> "Var":
return self.make_node("Shrink", self, bias=bias, lambd=lambd)
+ def Slice(
+ self, starts: "Var", ends: "Var", axes: "Var", steps: Optional["Var"] = None
+ ) -> "Var":
+ if steps is None:
+ return self.make_node("Slice", self, starts, ends, axes)
+ return self.make_node("Slice", self, starts, ends, axes, steps)
+
def Softmax(self, axis: int = -1) -> "Var":
return self.make_node("Softmax", self, axis=axis)
diff --git a/onnx_array_api/light_api/_op_vars.py b/onnx_array_api/light_api/_op_vars.py
index f4dee1c..4f30dbe 100644
--- a/onnx_array_api/light_api/_op_vars.py
+++ b/onnx_array_api/light_api/_op_vars.py
@@ -10,8 +10,10 @@ def BitShift(self, direction: str = "") -> "Var":
return self.make_node("BitShift", *self.vars_, direction=direction)
def CenterCropPad(self, axes: Optional[List[int]] = None) -> "Var":
- axes = axes or []
- return self.make_node("CenterCropPad", *self.vars_, axes=axes)
+ kwargs = {}
+ if axes is not None:
+ kwargs["axes"] = axes
+ return self.make_node("CenterCropPad", *self.vars_, **kwargs)
def Clip(
self,
@@ -27,12 +29,14 @@ def Col2Im(
pads: Optional[List[int]] = None,
strides: Optional[List[int]] = None,
) -> "Var":
- dilations = dilations or []
- pads = pads or []
- strides = strides or []
- return self.make_node(
- "Col2Im", *self.vars_, dilations=dilations, pads=pads, strides=strides
- )
+ kwargs = {}
+ if dilations is not None:
+ kwargs["dilations"] = dilations
+ if pads is not None:
+ kwargs["pads"] = pads
+ if strides is not None:
+ kwargs["strides"] = strides
+ return self.make_node("Col2Im", *self.vars_, **kwargs)
def Compress(self, axis: int = 0) -> "Var":
return self.make_node("Compress", *self.vars_, axis=axis)
@@ -49,19 +53,17 @@ def Conv(
pads: Optional[List[int]] = None,
strides: Optional[List[int]] = None,
) -> "Var":
- dilations = dilations or []
- kernel_shape = kernel_shape or []
- pads = pads or []
- strides = strides or []
+ kwargs = {}
+ if dilations is not None:
+ kwargs["dilations"] = dilations
+ if kernel_shape is not None:
+ kwargs["kernel_shape"] = kernel_shape
+ if pads is not None:
+ kwargs["pads"] = pads
+ if strides is not None:
+ kwargs["strides"] = strides
return self.make_node(
- "Conv",
- *self.vars_,
- auto_pad=auto_pad,
- dilations=dilations,
- group=group,
- kernel_shape=kernel_shape,
- pads=pads,
- strides=strides,
+ "Conv", *self.vars_, auto_pad=auto_pad, group=group, **kwargs
)
def ConvInteger(
@@ -73,19 +75,17 @@ def ConvInteger(
pads: Optional[List[int]] = None,
strides: Optional[List[int]] = None,
) -> "Var":
- dilations = dilations or []
- kernel_shape = kernel_shape or []
- pads = pads or []
- strides = strides or []
+ kwargs = {}
+ if dilations is not None:
+ kwargs["dilations"] = dilations
+ if kernel_shape is not None:
+ kwargs["kernel_shape"] = kernel_shape
+ if pads is not None:
+ kwargs["pads"] = pads
+ if strides is not None:
+ kwargs["strides"] = strides
return self.make_node(
- "ConvInteger",
- *self.vars_,
- auto_pad=auto_pad,
- dilations=dilations,
- group=group,
- kernel_shape=kernel_shape,
- pads=pads,
- strides=strides,
+ "ConvInteger", *self.vars_, auto_pad=auto_pad, group=group, **kwargs
)
def ConvTranspose(
@@ -99,23 +99,21 @@ def ConvTranspose(
pads: Optional[List[int]] = None,
strides: Optional[List[int]] = None,
) -> "Var":
- dilations = dilations or []
- kernel_shape = kernel_shape or []
- output_padding = output_padding or []
- output_shape = output_shape or []
- pads = pads or []
- strides = strides or []
- return self.make_node(
- "ConvTranspose",
- *self.vars_,
- auto_pad=auto_pad,
- dilations=dilations,
- group=group,
- kernel_shape=kernel_shape,
- output_padding=output_padding,
- output_shape=output_shape,
- pads=pads,
- strides=strides,
+ kwargs = {}
+ if dilations is not None:
+ kwargs["dilations"] = dilations
+ if kernel_shape is not None:
+ kwargs["kernel_shape"] = kernel_shape
+ if pads is not None:
+ kwargs["pads"] = pads
+ if strides is not None:
+ kwargs["strides"] = strides
+ if output_padding is not None:
+ kwargs["output_padding"] = output_padding
+ if output_shape is not None:
+ kwargs["output_shape"] = output_shape
+ return self.make_node(
+ "ConvTranspose", *self.vars_, auto_pad=auto_pad, group=group, **kwargs
)
def CumSum(self, exclusive: int = 0, reverse: int = 0) -> "Var":
@@ -137,19 +135,17 @@ def DeformConv(
pads: Optional[List[int]] = None,
strides: Optional[List[int]] = None,
) -> "Var":
- dilations = dilations or []
- kernel_shape = kernel_shape or []
- pads = pads or []
- strides = strides or []
+ kwargs = {}
+ if dilations is not None:
+ kwargs["dilations"] = dilations
+ if kernel_shape is not None:
+ kwargs["kernel_shape"] = kernel_shape
+ if pads is not None:
+ kwargs["pads"] = pads
+ if strides is not None:
+ kwargs["strides"] = strides
return self.make_node(
- "DeformConv",
- *self.vars_,
- dilations=dilations,
- group=group,
- kernel_shape=kernel_shape,
- offset_group=offset_group,
- pads=pads,
- strides=strides,
+ "DeformConv", *self.vars_, group=group, offset_group=offset_group, **kwargs
)
def DequantizeLinear(self, axis: int = 1) -> "Var":
@@ -206,12 +202,11 @@ def MatMulInteger(
def MaxRoiPool(
self, pooled_shape: Optional[List[int]] = None, spatial_scale: float = 1.0
) -> "Var":
- pooled_shape = pooled_shape or []
+ kwargs = {}
+ if pooled_shape is not None:
+ kwargs["pooled_shape"] = pooled_shape
return self.make_node(
- "MaxRoiPool",
- *self.vars_,
- pooled_shape=pooled_shape,
- spatial_scale=spatial_scale,
+ "MaxRoiPool", *self.vars_, spatial_scale=spatial_scale, **kwargs
)
def MaxUnpool(
@@ -220,16 +215,14 @@ def MaxUnpool(
pads: Optional[List[int]] = None,
strides: Optional[List[int]] = None,
) -> "Var":
- kernel_shape = kernel_shape or []
- pads = pads or []
- strides = strides or []
- return self.make_node(
- "MaxUnpool",
- *self.vars_,
- kernel_shape=kernel_shape,
- pads=pads,
- strides=strides,
- )
+ kwargs = {}
+ if kernel_shape is not None:
+ kwargs["kernel_shape"] = kernel_shape
+ if pads is not None:
+ kwargs["pads"] = pads
+ if strides is not None:
+ kwargs["strides"] = strides
+ return self.make_node("MaxUnpool", *self.vars_, **kwargs)
def MelWeightMatrix(self, output_datatype: int = 1) -> "Var":
return self.make_node(
@@ -269,19 +262,17 @@ def QLinearConv(
pads: Optional[List[int]] = None,
strides: Optional[List[int]] = None,
) -> "Var":
- dilations = dilations or []
- kernel_shape = kernel_shape or []
- pads = pads or []
- strides = strides or []
+ kwargs = {}
+ if kernel_shape is not None:
+ kwargs["kernel_shape"] = kernel_shape
+ if pads is not None:
+ kwargs["pads"] = pads
+ if strides is not None:
+ kwargs["strides"] = strides
+ if dilations is not None:
+ kwargs["dilations"] = dilations
return self.make_node(
- "QLinearConv",
- *self.vars_,
- auto_pad=auto_pad,
- dilations=dilations,
- group=group,
- kernel_shape=kernel_shape,
- pads=pads,
- strides=strides,
+ "QLinearConv", *self.vars_, auto_pad=auto_pad, group=group, **kwargs
)
def QLinearMatMul(
@@ -305,7 +296,9 @@ def RandomNormal(
seed: float = 0.0,
shape: Optional[List[int]] = None,
) -> "Var":
- shape = shape or []
+ kwargs = {}
+ if shape is not None:
+ kwargs["shape"] = shape
return self.make_node(
"RandomNormal",
*self.vars_,
@@ -313,7 +306,7 @@ def RandomNormal(
mean=mean,
scale=scale,
seed=seed,
- shape=shape,
+ **kwargs,
)
def RandomUniform(
@@ -324,7 +317,9 @@ def RandomUniform(
seed: float = 0.0,
shape: Optional[List[int]] = None,
) -> "Var":
- shape = shape or []
+ kwargs = {}
+ if shape is not None:
+ kwargs["shape"] = shape
return self.make_node(
"RandomUniform",
*self.vars_,
@@ -332,7 +327,7 @@ def RandomUniform(
high=high,
low=low,
seed=seed,
- shape=shape,
+ **kwargs,
)
def Range(
@@ -439,12 +434,13 @@ def Resize(
mode: str = "nearest",
nearest_mode: str = "round_prefer_floor",
) -> "Var":
- axes = axes or []
+ kwargs = {}
+ if axes is not None:
+ kwargs["axes"] = axes
return self.make_node(
"Resize",
*self.vars_,
antialias=antialias,
- axes=axes,
coordinate_transformation_mode=coordinate_transformation_mode,
cubic_coeff_a=cubic_coeff_a,
exclude_outside=exclude_outside,
@@ -452,6 +448,7 @@ def Resize(
keep_aspect_ratio_policy=keep_aspect_ratio_policy,
mode=mode,
nearest_mode=nearest_mode,
+ **kwargs,
)
def RoiAlign(
diff --git a/onnx_array_api/light_api/inner_emitter.py b/onnx_array_api/light_api/inner_emitter.py
deleted file mode 100644
index f5d5e4d..0000000
--- a/onnx_array_api/light_api/inner_emitter.py
+++ /dev/null
@@ -1,142 +0,0 @@
-from typing import Any, Dict, List, Tuple
-from onnx import AttributeProto
-from .annotations import ELEMENT_TYPE_NAME
-from .emitter import BaseEmitter
-from .translate import Translater
-
-
-class InnerEmitter(BaseEmitter):
- """
- Converts event into proper code.
- """
-
- def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
- """
- Renders an attribute value into a string.
-
- :param value: value to converter
- :return: rows to append before, actual value
- """
- if value[0].type == AttributeProto.GRAPH:
- tr = Translater(value[0].g, emitter=self)
- rows = tr.export(as_str=False, single_line=False)
- new_rows = [f"def _make_local_graph_{value[0].name}():"]
- for line in rows:
- if "make_model" in line:
- break
- new_rows.append(" " + line)
- new_rows.append(" return graph")
- new_rows.append(f"{value[0].name} = _make_local_graph_{value[0].name}()")
- return new_rows, value[0].name
-
- return super().render_attribute_value(value)
-
- def join(self, rows: List[str], single_line: bool = False) -> str:
- "Returns the separators. `single_line` is unused."
- return "\n".join(rows)
-
- def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
- lines = ["opset_imports = ["]
- opsets = kwargs.get("opsets", {})
- for k, v in opsets.items():
- lines.append(f" make_opsetid({k!r}, {v!r}),")
- lines.append("]")
- return lines
-
- def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]:
- lines = [
- "model = make_model(",
- " graph,",
- " functions=functions,",
- " opset_imports=opset_imports",
- ")",
- ]
- return lines
-
- def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
- lines = [
- "inputs = []",
- "outputs = []",
- "nodes = []",
- "initializers = []",
- "sparse_initializers = []",
- "functions = []",
- ]
- return lines
-
- def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
- name = kwargs.get("name", "noname")
- lines = [
- "graph = make_graph(",
- " nodes,",
- f" {name!r},",
- " inputs,",
- " outputs,",
- " initializers,",
- " sparse_initializer=sparse_initializers,",
- ")",
- ]
- return lines
-
- def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
- name = kwargs["name"]
- value = kwargs["value"]
- repl = {"bool": "bool_", "object": "object_", "str": "str_"}
- sdtype = repl.get(str(value.dtype), str(str(value.dtype)))
- return [
- "initializers.append(",
- " from_array(",
- f" np.array({value.tolist()}, dtype=np.{sdtype}),",
- f" name={name!r}",
- " )",
- ")",
- ]
-
- def _emit_io(self, container: str, **kwargs: Dict[str, Any]) -> List[str]:
- name = kwargs["name"]
- elem_type = kwargs.get("elem_type", None)
- shape = kwargs.get("shape", None)
- if elem_type and shape:
- return [
- f"{container}.append(make_tensor_value_info({name!r}, TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r}))"
- ]
- if elem_type:
- return [
- f"{container}.append(make_tensor_value_info({name!r}, TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape=[]))"
- ]
- return [
- f"{container}.append(make_tensor_value_info({name!r}, TensorProto.UNDEFINED, []))"
- ]
-
- def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
- return self._emit_io("inputs", **kwargs)
-
- def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
- return self._emit_io("outputs", **kwargs)
-
- def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
- op_type = kwargs["op_type"]
- inputs = kwargs["inputs"]
- outputs = kwargs["outputs"]
- if kwargs.get("domain", "") != "":
- domain = kwargs["domain"]
-
- before_lines = []
- lines = [
- "nodes.append(",
- " make_node(",
- f" {op_type!r},",
- f" {inputs},",
- f" {outputs},",
- ]
- domain = kwargs.get("domain", "")
- if domain:
- lines.append(f" domain={domain!r},")
- atts = kwargs.get("atts", {})
- for k, v in atts.items():
- before, value = self.render_attribute_value(v)
- before_lines.extend(before)
- lines.append(f" {k}={value},")
- lines[-1] = lines[-1][:-1]
- lines.extend([" )", ")"])
- return before_lines + lines
diff --git a/onnx_array_api/light_api/model.py b/onnx_array_api/light_api/model.py
index 67fc18e..f6770eb 100644
--- a/onnx_array_api/light_api/model.py
+++ b/onnx_array_api/light_api/model.py
@@ -14,7 +14,7 @@
)
from onnx.numpy_helper import from_array
from ..ext_test_case import is_azure, is_windows
-from .annotations import (
+from ..annotations import (
elem_type_int,
make_shape,
GRAPH_PROTO,
@@ -42,6 +42,7 @@ class OnnxGraph:
:param opset: main opset version
:param opsets: other opsets as a dictionary
+ :param ir_version: to specify an ir_version
:param is_function: a :class:`onnx.ModelProto` or a :class:`onnx.FunctionProto`
"""
@@ -49,6 +50,7 @@ def __init__(
self,
opset: Optional[int] = None,
opsets: Optional[Dict[str, int]] = None,
+ ir_version: Optional[int] = None,
proto_type: ProtoType = ProtoType.MODEL,
):
if opsets is not None and "" in opsets:
@@ -65,6 +67,7 @@ def __init__(
self.proto_type = proto_type
self.opsets = opsets
self.opset = opset
+ self.ir_version = ir_version
self.nodes: List[Union[NodeProto, TensorProto]] = []
self.inputs: List[ValueInfoProto] = []
self.outputs: List[ValueInfoProto] = []
@@ -180,6 +183,8 @@ def make_output(
:param elem_type: element type (the input is assumed to be a tensor)
:param shape: shape
:return: an instance of ValueInfoProto
+
+ If the checker fails, try `shape=[]`.
"""
if not self.has_name(name):
raise ValueError(f"Name {name!r} does not exist.")
@@ -314,7 +319,8 @@ def rename(self, old_name: str, new_name: str):
value = self.unique_names_[old_name]
if isinstance(value, int):
raise TypeError(
- f"Unexpected type {type(value)} for value {old_name!r} renamed into {new_name!r}."
+ f"Unexpected type {type(value)} for value {old_name!r} "
+ f"renamed into {new_name!r}."
)
self.unique_names_[new_name] = value
self.renames_[old_name] = new_name
@@ -332,7 +338,7 @@ def _fix_name_tensor_input(
) -> Union[TensorProto, SparseTensorProto, ValueInfoProto]:
obj = self._fix_name_tensor(obj)
shape = make_shape(obj.type.tensor_type.shape)
- if shape is None:
+ if not shape:
tensor_type_proto = make_tensor_type_proto(
obj.type.tensor_type.elem_type, []
)
@@ -344,7 +350,7 @@ def _fix_name_tensor_output(
) -> Union[TensorProto, SparseTensorProto, ValueInfoProto]:
obj = self._fix_name_tensor(obj)
shape = make_shape(obj.type.tensor_type.shape)
- if shape is None:
+ if not shape:
tensor_type_proto = make_tensor_type_proto(
obj.type.tensor_type.elem_type, []
)
@@ -400,6 +406,8 @@ def to_onnx(self) -> GRAPH_PROTO:
# If no opsets, it a subgraph, not a model.
return graph
model = make_model(graph, opset_imports=opsets)
+ if self.ir_version:
+ model.ir_version = self.ir_version
if not is_windows() or not is_azure():
# check_model fails sometimes on Windows
check_model(model)
diff --git a/onnx_array_api/light_api/var.py b/onnx_array_api/light_api/var.py
index 882dcb7..72a9533 100644
--- a/onnx_array_api/light_api/var.py
+++ b/onnx_array_api/light_api/var.py
@@ -3,7 +3,7 @@
import numpy as np
from onnx import TensorProto
from onnx.defs import get_schema
-from .annotations import (
+from ..annotations import (
elem_type_int,
make_shape,
ELEMENT_TYPE,
@@ -193,7 +193,7 @@ def make_node(
)
if len(names) == 1:
return Var(self.parent, names[0])
- return Vars(self.parent, *list(map(lambda v: Var(self.parent, v), names)))
+ return Vars(self.parent, *[Var(self.parent, v) for v in names])
def vin(
self,
@@ -318,6 +318,8 @@ def vout(
:param elem_type: element_type
:param shape: shape
:return: instance of :class:`onnx_array_api.light_api.Var`
+
+ If the checker fails, try `shape=[]`.
"""
output = self.parent.make_output(self.name, elem_type=elem_type, shape=shape)
return Var(
@@ -443,7 +445,8 @@ def rename(self, *new_names: List[str]) -> "Vars":
"Renames variables."
if len(new_names) != len(self):
raise ValueError(
- f"Vars has {len(self)} elements but the method received {len(new_names)} names."
+ f"Vars has {len(self)} elements but the method "
+ f"received {len(new_names)} names."
)
new_vars = []
for var, name in zip(self.vars_, new_names):
@@ -461,6 +464,8 @@ def vout(
:param elem_type_shape: list of tuple(element_type, shape)
:return: instance of :class:`onnx_array_api.light_api.Vars`
+
+ If the checker fails, try `shape=[]`.
"""
vars = []
for i, v in enumerate(self.vars_):
diff --git a/onnx_array_api/npx/npx_array_api.py b/onnx_array_api/npx/npx_array_api.py
index 142a892..a9fb3d6 100644
--- a/onnx_array_api/npx/npx_array_api.py
+++ b/onnx_array_api/npx/npx_array_api.py
@@ -10,8 +10,6 @@ class ArrayApiError(RuntimeError):
Raised when a function is not supported by the :epkg:`Array API`.
"""
- pass
-
class BaseArrayApi:
"""
diff --git a/onnx_array_api/npx/npx_core_api.py b/onnx_array_api/npx/npx_core_api.py
index d6688cf..a09280a 100644
--- a/onnx_array_api/npx/npx_core_api.py
+++ b/onnx_array_api/npx/npx_core_api.py
@@ -15,7 +15,7 @@
class args_tuple(tuple):
"""Overwrites a tuple to make the distinction later in the code."""
- pass
+ __slots__ = ()
def cst(*args, **kwargs):
@@ -140,7 +140,7 @@ def _xapi(fn: Callable, inline: bool):
# It has the same signature
def wrapper(*inputs, **kwargs):
- if any(map(lambda x: isinstance(x, EagerTensor), inputs)):
+ if any(isinstance(x, EagerTensor) for x in inputs):
tensor_class = None
for x in inputs:
if isinstance(x, EagerTensor):
diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py
index db29ca2..c6319f2 100644
--- a/onnx_array_api/npx/npx_functions.py
+++ b/onnx_array_api/npx/npx_functions.py
@@ -1,5 +1,4 @@
from typing import Tuple, Union
-import array_api_compat.numpy as np_array_api
import numpy as np
from onnx import FunctionProto, ModelProto, NodeProto, TensorProto
from onnx.helper import make_tensor, tensor_dtype_to_np_dtype
@@ -282,7 +281,8 @@ def astype(
to = DType(TensorProto.STRING)
else:
raise TypeError(f"dtype must of type DType, not {type(dtype)}-{dtype}.")
- return var(a, op="Cast", to=to.code)
+ return var(a, op="Cast", to=to.code)
+ return var(a, op="Cast", to=dtype.code)
@npxapi_inline
@@ -480,7 +480,7 @@ def eye(
/,
*,
k: ParType[int] = 0,
- dtype: ParType[DType] = DType(TensorProto.DOUBLE),
+ dtype: ParType[DType] = DType(TensorProto.DOUBLE), # noqa: B008
):
"See :func:`numpy.eye`."
shape = cst(np.array([-1], dtype=np.int64))
@@ -624,6 +624,8 @@ def isdtype(
See :epkg:`BaseArrayAPI:isdtype`.
This function is not converted into an onnx graph.
"""
+ import array_api_compat.numpy as np_array_api
+
if isinstance(dtype, DType):
dti = tensor_dtype_to_np_dtype(dtype.code)
return np_array_api.isdtype(dti, kind)
diff --git a/onnx_array_api/npx/npx_functions_test.py b/onnx_array_api/npx/npx_functions_test.py
index 4d442dd..3d03def 100644
--- a/onnx_array_api/npx/npx_functions_test.py
+++ b/onnx_array_api/npx/npx_functions_test.py
@@ -22,21 +22,21 @@
@npxapi_function
def _min_max(
- x: TensorType[ElemType.numerics, "T"]
+ x: TensorType[ElemType.numerics, "T"],
) -> TupleType[TensorType[ElemType.numerics, "T"], TensorType[ElemType.numerics, "T"]]:
return tuple_var(var(x, op="ReduceMin"), var(x, op="ReduceMax"))
@npxapi_inline
def _min_max_inline(
- x: TensorType[ElemType.numerics, "T"]
+ x: TensorType[ElemType.numerics, "T"],
) -> TupleType[TensorType[ElemType.numerics, "T"], TensorType[ElemType.numerics, "T"]]:
return tuple_var(var(x, op="ReduceMin"), var(x, op="ReduceMax"))
@npxapi_function
def absolute(
- x: TensorType[ElemType.numerics, "T"]
+ x: TensorType[ElemType.numerics, "T"],
) -> TensorType[ElemType.numerics, "T"]:
"See :func:`numpy.absolute`."
return var(x, op="Abs")
@@ -90,7 +90,7 @@ def log1p(x: TensorType[ElemType.floats, "T"]) -> TensorType[ElemType.floats, "T
@npxapi_function
def negative(
- x: TensorType[ElemType.numerics, "T"]
+ x: TensorType[ElemType.numerics, "T"],
) -> TensorType[ElemType.numerics, "T"]:
"See :func:`numpy.negative`."
return var(x, op="Neg")
diff --git a/onnx_array_api/npx/npx_graph_builder.py b/onnx_array_api/npx/npx_graph_builder.py
index 4496d79..91034f7 100644
--- a/onnx_array_api/npx/npx_graph_builder.py
+++ b/onnx_array_api/npx/npx_graph_builder.py
@@ -450,7 +450,7 @@ def _make_onnx(self):
name = inp.name
if name is None:
raise RuntimeError(
- f"Input {i} is None for function " f"{self.function_name!r}."
+ f"Input {i} is None for function {self.function_name!r}."
)
inputs.append(name)
@@ -473,7 +473,7 @@ def _make_onnx(self):
model = make_model(
graph,
opset_imports=opset_imports,
- functions=list(f[0] for f in self.functions_.values()),
+ functions=[f[0] for f in self.functions_.values()],
ir_version=self.ir_version,
)
if not is_windows() or not is_azure():
@@ -512,12 +512,7 @@ def _function_to_onnx(self, fct: Callable, n_inputs: int, n_outputs: int):
there is an undefined number of inputs
"""
sig = signature(fct)
- if any(
- map(
- lambda t: issubclass(t.annotation, SequenceType),
- sig.parameters.values(),
- )
- ):
+ if any(issubclass(t.annotation, SequenceType) for t in sig.parameters.values()):
# onnx does not allow undefined number of inputs
key = fct.__module__, fct.__name__, n_inputs
else:
@@ -852,7 +847,7 @@ def to_onnx(
node_inputs.append(input_name)
continue
- if isinstance(i, tuple) and all(map(lambda x: isinstance(x, int), i)):
+ if isinstance(i, tuple) and all(isinstance(x, int) for x in i):
ai = np.array(list(i), dtype=np.int64)
c = Cst(ai)
input_name = self._unique(var._prefix)
diff --git a/onnx_array_api/npx/npx_helper.py b/onnx_array_api/npx/npx_helper.py
index 34d9af3..b2c6b48 100644
--- a/onnx_array_api/npx/npx_helper.py
+++ b/onnx_array_api/npx/npx_helper.py
@@ -130,8 +130,7 @@ def iter_nodes(nodes: Sequence[NodeProto]) -> Iterator[NodeProto]:
and hasattr(att, "g")
and att.g is not None
):
- for n in iter_nodes(att.g.node):
- yield n
+ yield from iter_nodes(att.g.node)
def onnx_model_to_function(
diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py
index 20becbd..267eda5 100644
--- a/onnx_array_api/npx/npx_jit_eager.py
+++ b/onnx_array_api/npx/npx_jit_eager.py
@@ -167,7 +167,7 @@ def make_key(self, *values: List[Any], **kwargs: Dict[str, Any]) -> Tuple[Any, .
f"to the attribute list, v={v}."
)
res.append(v.key)
- elif isinstance(v, (int, float, bool, DType)):
+ elif isinstance(v, (int, float, bool, complex, DType)):
if iv in self.kwargs_to_input_:
res.append(self.kwargs_to_input_[iv])
res.append(type(v))
@@ -204,7 +204,7 @@ def make_key(self, *values: List[Any], **kwargs: Dict[str, Any]) -> Tuple[Any, .
if k in self.kwargs_to_input_:
res.append(type(v))
res.append(v)
- elif isinstance(v, (int, float, str, type, bool, DType)):
+ elif isinstance(v, (int, float, str, type, bool, complex, DType)):
res.append(k)
res.append(type(v))
res.append(v)
@@ -563,7 +563,7 @@ class JitOnnx(JitEager):
def __init__(
self,
f: Callable,
- tensor_class: type = None,
+ tensor_class: Optional[type] = None,
target_opsets: Optional[Dict[str, int]] = None,
output_types: Optional[Dict[Any, TensorType]] = None,
ir_version: Optional[int] = None,
@@ -636,7 +636,7 @@ class EagerOnnx(JitEager):
def __init__(
self,
f: Callable,
- tensor_class: type = None,
+ tensor_class: Optional[type] = None,
target_opsets: Optional[Dict[str, int]] = None,
output_types: Optional[Dict[Any, TensorType]] = None,
ir_version: Optional[int] = None,
@@ -671,12 +671,12 @@ def _preprocess_constants(self, *args):
new_args.append(self.tensor_class(n.inputs[0]))
modified = True
elif isinstance(n, tuple):
- if all(map(lambda x: isinstance(x, int), n)):
+ if all(isinstance(x, int) for x in n):
new_args.append(
self.tensor_class(np.array(list(n), dtype=np.int64))
)
modified = True
- elif any(map(lambda t: isinstance(t, Var), n)):
+ elif any(isinstance(t, Var) for t in n):
raise TypeError(
f"Unexpected types in tuple "
f"({[type(t) for t in n]}) for input {i}, "
@@ -727,14 +727,14 @@ def __call__(self, *args, already_eager=False, **kwargs):
)
if already_eager:
if any(
- map(
- lambda t: t is not None
+ (
+ t is not None
and not isinstance(
t,
EagerOnnx.allowed_input_types,
- ),
- args,
+ )
)
+ for t in args
):
raise TypeError(
f"One of the input is not an EagerTensor or a constant, "
@@ -759,8 +759,8 @@ def __call__(self, *args, already_eager=False, **kwargs):
try:
res = self.f(*values, **kwargs)
except (AttributeError, TypeError) as e:
- inp1 = ", ".join(map(str, map(lambda a: type(a).__name__, args)))
- inp2 = ", ".join(map(str, map(lambda a: type(a).__name__, values)))
+ inp1 = ", ".join(map(str, [type(a).__name__ for a in args]))
+ inp2 = ", ".join(map(str, [type(a).__name__ for a in values]))
raise TypeError(
f"Unexpected types, input types are args=[{inp1}], "
f"values=[{inp2}], kwargs={kwargs}. "
@@ -778,7 +778,7 @@ def __call__(self, *args, already_eager=False, **kwargs):
f"from module {self.f.__module__!r}, "
f"type of first input is {type(args[0])}."
)
- elif isinstance(res, Var) or any(map(lambda x: isinstance(x, Var), res)):
+ elif isinstance(res, Var) or any(isinstance(x, Var) for x in res):
# The function returns instance of type Var.
# It does not support eager mode and needs
# to be converted into onnx.
diff --git a/onnx_array_api/npx/npx_numpy_tensors.py b/onnx_array_api/npx/npx_numpy_tensors.py
index 68a4da7..9579455 100644
--- a/onnx_array_api/npx/npx_numpy_tensors.py
+++ b/onnx_array_api/npx/npx_numpy_tensors.py
@@ -223,7 +223,8 @@ def __bool__(self):
if self.shape:
warnings.warn(
f"Conversion to bool only works for scalar, not for {self!r}, "
- f"bool(...)={bool(self._tensor)}."
+ f"bool(...)={bool(self._tensor)}.",
+ stacklevel=0,
)
try:
return bool(self._tensor)
@@ -264,6 +265,8 @@ def __float__(self):
DType(TensorProto.DOUBLE),
DType(TensorProto.FLOAT16),
DType(TensorProto.BFLOAT16),
+ DType(TensorProto.COMPLEX64),
+ DType(TensorProto.COMPLEX128),
}:
raise TypeError(
f"Conversion to float only works for float scalar, "
@@ -271,6 +274,26 @@ def __float__(self):
)
return float(self._tensor)
+ def __complex__(self):
+ "Implicit conversion to complex."
+ if self.shape:
+ raise ValueError(
+ f"Conversion to bool only works for scalar, not for {self!r}."
+ )
+ if self.dtype not in {
+ DType(TensorProto.FLOAT),
+ DType(TensorProto.DOUBLE),
+ DType(TensorProto.FLOAT16),
+ DType(TensorProto.BFLOAT16),
+ DType(TensorProto.COMPLEX64),
+ DType(TensorProto.COMPLEX128),
+ }:
+ raise TypeError(
+ f"Conversion to float only works for float scalar, "
+ f"not for dtype={self.dtype}."
+ )
+ return complex(self._tensor)
+
def __iter__(self):
"""
The :epkg:`Array API` does not define this function (2022/12).
@@ -279,7 +302,8 @@ def __iter__(self):
warnings.warn(
f"Iterators are not implemented in the generic case. "
f"Every function using them cannot be converted into ONNX "
- f"(tensors - {type(self)})."
+ f"(tensors - {type(self)}).",
+ stacklevel=0,
)
for row in self._tensor:
yield self.__class__(row)
@@ -289,5 +313,3 @@ class JitNumpyTensor(NumpyTensor, JitTensor):
"""
Defines a value for a specific backend.
"""
-
- pass
diff --git a/onnx_array_api/npx/npx_tensors.py b/onnx_array_api/npx/npx_tensors.py
index 3e4faa7..40ebc12 100644
--- a/onnx_array_api/npx/npx_tensors.py
+++ b/onnx_array_api/npx/npx_tensors.py
@@ -10,8 +10,6 @@ class JitTensor:
Defines a value for a specific jit mode
"""
- pass
-
class EagerTensor(BaseArrayApi):
"""
@@ -93,7 +91,7 @@ def _astype_impl(
if not isinstance(x, Var):
raise TypeError(f"Input 0 must be a Var not {type(x)}.")
- meth = getattr(Var, "astype")
+ meth = getattr(Var, "astype") # noqa: B009
return meth(x, dtype)
@staticmethod
diff --git a/onnx_array_api/npx/npx_types.py b/onnx_array_api/npx/npx_types.py
index 8284765..2f2a6a6 100644
--- a/onnx_array_api/npx/npx_types.py
+++ b/onnx_array_api/npx/npx_types.py
@@ -11,7 +11,7 @@ class WrapperType:
WrapperType.
"""
- pass
+ __slots__ = ()
class DType(WrapperType):
@@ -78,8 +78,8 @@ def __eq__(self, dt: "DType") -> bool:
return self.code_ == dt.dtype.code_
try:
dti = np_dtype_to_tensor_dtype(dt)
- except KeyError:
- raise TypeError(f"dt must be DType not {type(dt)} - {dt!r}.")
+ except KeyError as e:
+ raise TypeError(f"dt must be DType not {type(dt)} - {dt!r}.") from e
return self.code_ == dti
def __lt__(self, dt: "DType") -> bool:
@@ -90,8 +90,8 @@ def __lt__(self, dt: "DType") -> bool:
raise TypeError(f"dt must be DType not {type(dt)}.")
try:
dti = np_dtype_to_tensor_dtype(dt)
- except KeyError:
- raise TypeError(f"dt must be DType not {type(dt)} - {dt}.")
+ except KeyError as e:
+ raise TypeError(f"dt must be DType not {type(dt)} - {dt}.") from e
return self.code_ < dti
@classmethod
@@ -102,12 +102,10 @@ def type_name(cls) -> str:
class _DType2(DType):
"Wraps a type into a different type."
- pass
class _DTypes(DType):
"Wraps a type into a different type."
- pass
class ElemTypeCstInner(WrapperType):
@@ -367,7 +365,7 @@ def onnx_type(cls):
if cls.dtype == str:
return AttributeProto.STRING
raise RuntimeError(
- f"Unsupported attribute type {cls.dtype!r} " f"for parameter {cls!r}."
+ f"Unsupported attribute type {cls.dtype!r} for parameter {cls!r}."
)
@@ -403,9 +401,11 @@ class ShapeType(Tuple[int, ...]):
Defines a shape type.
"""
+ __slots__ = ()
+
@classmethod
def __class_getitem__(cls, *args):
- if any(map(lambda t: t is not None and not isinstance(t, (int, str)), args)):
+ if any((t is not None and not isinstance(t, (int, str))) for t in args):
raise TypeError(
f"Unexpected value for args={args}, every element should int or str."
)
@@ -504,7 +504,7 @@ def __class_getitem__(cls, *args):
if name:
msg.append(name)
if dtypes is not None:
- msg.append("_".join(map(lambda t: str(t.dtype), dtypes)))
+ msg.append("_".join(str(t.dtype) for t in dtypes))
if shape is not None:
msg.append("_".join(map(str, shape)))
final = "__".join(msg)
@@ -561,11 +561,11 @@ def _name_set(self):
s += 1 << dt.dtype
try:
return ElemType.set_names[s]
- except KeyError:
+ except KeyError as e:
raise RuntimeError(
f"Unable to guess element type name for {s}: "
f"{repr(self)} in {ElemType.set_names}."
- )
+ ) from e
@classmethod
def issuperset(cls, tensor_type: type) -> bool:
@@ -686,7 +686,7 @@ def len(cls):
@classmethod
def type_name(cls) -> str:
"Returns its full name."
- dts = ", ".join(map(lambda s: s.type_name(), cls.elem_types))
+ dts = ", ".join(s.type_name() for s in cls.elem_types)
if cls.name:
newt = f"TupleType[{dts}, {cls.name!r}]"
else:
diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py
index ca8af0d..0e71070 100644
--- a/onnx_array_api/npx/npx_var.py
+++ b/onnx_array_api/npx/npx_var.py
@@ -33,7 +33,7 @@ def __init__(
):
if not issubclass(dtype, ParType):
raise TypeError(
- f"dtype for parameter {name!r} must be of " f"ParType not {dtype}."
+ f"dtype for parameter {name!r} must be of ParType not {dtype}."
)
if parent_op is None:
raise ValueError(f"parent_op must be filled for paramenter {name!r}.")
@@ -453,7 +453,7 @@ def _get_vars(self):
deleted.append(var)
continue
raise TypeError(
- f"Unexpected type {type(applied)} as output of " f"function {fct}."
+ f"Unexpected type {type(applied)} as output of function {fct}."
)
vs.append(var)
for i in reversed(var.inputs):
@@ -469,11 +469,11 @@ def _get_vars(self):
replacement_cst[id(i)] = cst(np.array(i))
continue
if isinstance(i, tuple):
- if all(map(lambda x: isinstance(x, int), i)):
+ if all(isinstance(x, int) for x in i):
cst = Var.get_cst_var()[0]
replacement_cst[id(i)] = cst(np.array(list(i), dtype=np.int64))
continue
- if any(map(lambda t: isinstance(t, Var), i)):
+ if any(isinstance(t, Var) for t in i):
raise TypeError(
f"Unexpected types in tuple "
f"({[type(t) for t in i]}), "
@@ -1138,7 +1138,7 @@ class Input(Var):
:param annotation: annotation if any is available
"""
- def __init__(self, name: str = None, annotation: Optional[type] = None):
+ def __init__(self, name: Optional[str] = None, annotation: Optional[type] = None):
Var.__init__(self)
self.name = name
self._prefix = name or "I"
@@ -1171,16 +1171,20 @@ def __init__(self, cst: Any):
Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity")
elif isinstance(cst, float):
Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity")
+ elif isinstance(cst, complex):
+ Var.__init__(self, np.array(cst, dtype=np.complex128), op="Identity")
elif isinstance(cst, list):
- if all(map(lambda t: isinstance(t, bool), cst)):
+ if all(isinstance(t, bool) for t in cst):
Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity")
- elif all(map(lambda t: isinstance(t, (int, bool)), cst)):
+ elif all(isinstance(t, (int, bool)) for t in cst):
Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity")
- elif all(map(lambda t: isinstance(t, (float, int, bool)), cst)):
+ elif all(isinstance(t, (float, int, bool)) for t in cst):
Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity")
+ elif all(isinstance(t, (float, int, bool, complex)) for t in cst):
+ Var.__init__(self, np.array(cst, dtype=np.complex128), op="Identity")
else:
raise ValueError(
- f"Unable to convert cst (type={type(cst)}), " f"value={cst}."
+ f"Unable to convert cst (type={type(cst)}), value={cst}."
)
else:
raise NotImplementedError(
diff --git a/onnx_array_api/ort/ort_profile.py b/onnx_array_api/ort/ort_profile.py
index b61df67..ebccaba 100644
--- a/onnx_array_api/ort/ort_profile.py
+++ b/onnx_array_api/ort/ort_profile.py
@@ -52,7 +52,7 @@ def sep_event(s):
for c in agg_cols:
df[c] = df[c].fillna("")
df["dur"] = df["dur"].fillna(0)
- agg = df[agg_cols + ["dur"]].groupby(agg_cols).sum()
+ agg = df[[*agg_cols, "dur"]].groupby(agg_cols).sum()
return agg
@@ -101,14 +101,16 @@ def ort_profile(
if providers is None:
providers = ["CPUExecutionProvider"]
sess = InferenceSession(obj, sess_options, providers=providers, **kwargs)
- first = list(feeds.values())[0]
+ for v in feeds.values():
+ first = v
+ break
if isinstance(first, numpy.ndarray):
- for i in range(repeat):
+ for _i in range(repeat):
sess.run(None, feeds)
else:
out_names = [o.name for o in sess.get_outputs()]
- for i in range(repeat):
+ for _i in range(repeat):
sess._sess.run_with_ort_values(feeds, out_names, None)
prof = sess.end_profiling()
@@ -177,7 +179,7 @@ def _idx(row):
df[c] = df[c].apply(str)
df = df.copy()
df["count"] = 1
- gr = df[groupkey + ["dur", "count"]].groupby(groupkey)
+ gr = df[[*groupkey, "dur", "count"]].groupby(groupkey)
return gr.sum()
@@ -187,7 +189,9 @@ def _process_shape(s: Tuple[int, ...], keys: Dict[str, str]) -> str:
for v in value:
if len(v) != 1:
raise NotImplementedError(f"Unexpected value {v} in {s!r}.")
- k, v = list(v.items())[0]
+ for _k, _v in v.items():
+ k, v = _k, _v
+ break
n = "-".join([keys[k], "x".join(map(str, v))])
ns.append(n)
return ",".join(ns)
diff --git a/onnx_array_api/ort/ort_tensors.py b/onnx_array_api/ort/ort_tensors.py
index 2117e3f..4f53e6e 100644
--- a/onnx_array_api/ort/ort_tensors.py
+++ b/onnx_array_api/ort/ort_tensors.py
@@ -86,7 +86,7 @@ def __init__(
tensor_class: type,
input_names: List[str],
onx: ModelProto,
- f: Callable = None,
+ f: Optional[Callable] = None,
):
try:
self.ref = InferenceSession(
@@ -282,5 +282,3 @@ class JitOrtTensor(OrtTensor, OrtCommon, JitTensor):
"""
Defines a value for :epkg:`onnxruntime` as a backend.
"""
-
- pass
diff --git a/onnx_array_api/plotting/_helper.py b/onnx_array_api/plotting/_helper.py
index 3131177..5c5d881 100644
--- a/onnx_array_api/plotting/_helper.py
+++ b/onnx_array_api/plotting/_helper.py
@@ -94,7 +94,7 @@ def _extract_attribute_value(
f"Unable to convert attribute {att.name!r} type {att.type!r}."
)
raise AttributeError( # pragma: no cover
- f"Unable to convert default value for {ref_att.name!r} " f"type {att.type!r}."
+ f"Unable to convert default value for {ref_att.name!r} type {att.type!r}."
)
@@ -120,7 +120,7 @@ def get_tensor_shape(obj):
for d in obj.tensor_type.shape.dim:
v = d.dim_value if d.dim_value > 0 else d.dim_param
shape.append(v)
- shape = None if not shape else list(None if s == 0 else s for s in shape)
+ shape = None if not shape else [None if s == 0 else s for s in shape]
return shape
@@ -183,7 +183,7 @@ def _get_shape(obj):
arr = to_array(obj)
return arr.shape
raise RuntimeError( # pragma: no cover
- f"Unable to guess type from {obj0!r}, " f"data_type is {obj.data_type!r}."
+ f"Unable to guess type from {obj0!r}, data_type is {obj.data_type!r}."
)
if hasattr(obj, "type"):
obj = obj.type
diff --git a/onnx_array_api/plotting/dot_plot.py b/onnx_array_api/plotting/dot_plot.py
index cff93f5..af8ad22 100644
--- a/onnx_array_api/plotting/dot_plot.py
+++ b/onnx_array_api/plotting/dot_plot.py
@@ -116,7 +116,12 @@ def myloss(x, y):
clean_label_reg2 = re.compile("\\\\p\\{[0-9P]{1,6}\\}")
def dot_name(text):
- return text.replace("/", "_").replace(":", "__").replace(".", "_")
+ return (
+ text.replace("/", "_")
+ .replace(":", "__")
+ .replace(".", "_")
+ .replace("-", "_")
+ )
def dot_label(text):
if text is None:
@@ -305,7 +310,7 @@ def dot_label(text):
exp.append(f' label="{node.op_type}\\n({dot_name(field)}){satts}";')
exp.append(f" fontsize={fontsize};")
exp.append(" color=black;")
- exp.append("\n".join(map(lambda s: " " + s, subgraph.split("\n"))))
+ exp.append("\n".join(f" {s}" for s in subgraph.split("\n")))
node0 = body.node[0]
connects.append(
diff --git a/onnx_array_api/plotting/graphviz_helper.py b/onnx_array_api/plotting/graphviz_helper.py
new file mode 100644
index 0000000..4aec5e4
--- /dev/null
+++ b/onnx_array_api/plotting/graphviz_helper.py
@@ -0,0 +1,244 @@
+import os
+import subprocess
+import sys
+import tempfile
+from typing import List, Optional, Tuple, Union
+import numpy as np
+from onnx import ModelProto
+
+
+def _find_in_PATH(prog: str) -> Optional[str]:
+ """
+ Looks into every path mentioned in ``%PATH%`` a specific file,
+ it raises an exception if not found.
+
+ :param prog: program to look for
+ :return: path
+ """
+ sep = ";" if sys.platform.startswith("win") else ":"
+ path = os.environ["PATH"]
+ for p in path.split(sep):
+ f = os.path.join(p, prog)
+ if os.path.exists(f):
+ return p
+ return None
+
+
+def _find_graphviz_dot(exc: bool = True) -> str:
+ """
+ Determines the path to graphviz (on Windows),
+ the function tests the existence of versions 34 to 45
+ assuming it was installed in a standard folder:
+ ``C:\\Program Files\\MiKTeX 2.9\\miktex\\bin\\x64``.
+
+ :param exc: raise exception of be silent
+ :return: path to dot
+ :raises FileNotFoundError: if graphviz not found
+ """
+ if sys.platform.startswith("win"):
+ version = list(range(34, 60))
+ version.extend([f"{v}.1" for v in version])
+ for v in version:
+ graphviz_dot = f"C:\\Program Files (x86)\\Graphviz2.{v}\\bin\\dot.exe"
+ if os.path.exists(graphviz_dot):
+ return graphviz_dot
+ extra = ["build/update_modules/Graphviz/bin"]
+ for ext in extra:
+ graphviz_dot = os.path.join(ext, "dot.exe")
+ if os.path.exists(graphviz_dot):
+ return graphviz_dot
+ p = _find_in_PATH("dot.exe")
+ if p is None:
+ if exc:
+ raise FileNotFoundError(
+ f"Unable to find graphviz, look into paths such as {graphviz_dot}."
+ )
+ return None
+ return os.path.join(p, "dot.exe")
+ # linux
+ return "dot"
+
+
+def _run_subprocess(
+ args: List[str],
+ cwd: Optional[str] = None,
+):
+ assert not isinstance(
+ args, str
+ ), "args should be a sequence of strings, not a string."
+
+ p = subprocess.Popen(
+ args,
+ cwd=cwd,
+ shell=False,
+ env=os.environ,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ raise_exception = False
+ output = ""
+ while True:
+ output = p.stdout.readline().decode(errors="ignore")
+ if output == "" and p.poll() is not None:
+ break
+ if output:
+ if (
+ "fatal error" in output
+ or "CMake Error" in output
+ or "gmake: ***" in output
+ or "): error C" in output
+ or ": error: " in output
+ ):
+ raise_exception = True
+ p.poll()
+ error = p.stderr.readline().decode(errors="ignore")
+ p.stdout.close()
+ if error and raise_exception:
+ raise RuntimeError(
+ f"An error was found in the output. The build is stopped."
+ f"\n{output}\n---\n{error}"
+ )
+ return output + "\n" + error
+
+
+def _run_graphviz(filename: str, image: str, engine: str = "dot") -> str:
+ """
+ Run :epkg:`Graphviz`.
+
+ :param filename: filename which contains the graph definition
+ :param image: output image
+ :param engine: *dot* or *neato*
+ :return: output of graphviz
+ """
+ ext = os.path.splitext(image)[-1]
+ assert ext in {
+ ".png",
+ ".bmp",
+ ".fig",
+ ".gif",
+ ".ico",
+ ".jpg",
+ ".jpeg",
+ ".pdf",
+ ".ps",
+ ".svg",
+ ".vrml",
+ ".tif",
+ ".tiff",
+ ".wbmp",
+ }, f"Unexpected extension {ext!r} for {image!r}."
+ if sys.platform.startswith("win"):
+ bin_ = os.path.dirname(_find_graphviz_dot())
+ # if bin not in os.environ["PATH"]:
+ # os.environ["PATH"] = os.environ["PATH"] + ";" + bin
+ exe = os.path.join(bin_, engine)
+ else:
+ exe = engine
+ if os.path.exists(image):
+ os.remove(image)
+ cmd = [exe, f"-T{ext[1:]}", filename, "-o", image]
+ output = _run_subprocess(cmd)
+ assert os.path.exists(image), (
+ f"Unable to find {image!r}, command line is "
+ f"{' '.join(cmd)!r}, Graphviz failed due to\n{output}"
+ )
+ return output
+
+
+def draw_graph_graphviz(
+ dot: Union[str, ModelProto],
+ image: str,
+ engine: str = "dot",
+) -> str:
+ """
+ Draws a graph using :epkg:`Graphviz`.
+
+ :param dot: dot graph or ModelProto
+ :param image: output image, None, just returns the output
+ :param engine: *dot* or *neato*
+ :return: :epkg:`Graphviz` output or
+ the dot text if *image* is None
+
+ The function creates a temporary file to store the dot file if *image* is not None.
+ """
+ if isinstance(dot, ModelProto):
+ from .dot_plot import to_dot
+
+ sdot = to_dot(dot)
+ else:
+ sdot = dot
+ with tempfile.NamedTemporaryFile(delete=False) as fp:
+ fp.write(sdot.encode("utf-8"))
+ fp.close()
+
+ filename = fp.name
+ assert os.path.exists(
+ filename
+ ), f"File {filename!r} cannot be created to store the graph."
+ out = _run_graphviz(filename, image, engine=engine)
+ assert os.path.exists(
+ image
+ ), f"Graphviz failed with no reason, {image!r} not found, output is {out}."
+ os.remove(filename)
+ return out
+
+
+def plot_dot(
+ dot: Union[str, ModelProto],
+ ax: Optional["matplotlib.axis.Axis"] = None, # noqa: F821
+ engine: str = "dot",
+ figsize: Optional[Tuple[int, int]] = None,
+) -> "matplotlib.axis.Axis": # noqa: F821
+ """
+ Draws a dot graph into a matplotlib graph.
+
+ :param dot: dot graph or ModelProto
+ :param image: output image, None, just returns the output
+ :param engine: *dot* or *neato*
+ :param figsize: figsize of ax is None
+ :return: :epkg:`Graphviz` output or, the dot text if *image* is None
+
+ .. plot::
+
+ import matplotlib.pyplot as plt
+ import onnx.parser
+ from onnx_array_api.plotting.graphviz_helper import plot_dot
+
+ model = onnx.parser.parse_model(
+ '''
+
+ agraph (float[N] x) => (float[N] z) {
+ two = Constant ()
+ four = Add(two, two)
+ z = Mul(four, four)
+ }
+ ''')
+
+ ax = plot_dot(model)
+ ax.set_title("Dummy graph")
+ plt.show()
+ """
+ if ax is None:
+ import matplotlib.pyplot as plt
+
+ _, ax = plt.subplots(1, 1, figsize=figsize)
+ clean = True
+ else:
+ clean = False
+
+ from PIL import Image
+
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as fp:
+ fp.close()
+
+ draw_graph_graphviz(dot, fp.name, engine=engine)
+ img = np.asarray(Image.open(fp.name))
+ os.remove(fp.name)
+
+ ax.imshow(img)
+
+ if clean:
+ ax.get_xaxis().set_visible(False)
+ ax.get_yaxis().set_visible(False)
+ ax.get_figure().tight_layout()
+ return ax
diff --git a/onnx_array_api/plotting/text_plot.py b/onnx_array_api/plotting/text_plot.py
index 36f9feb..0b4d30a 100644
--- a/onnx_array_api/plotting/text_plot.py
+++ b/onnx_array_api/plotting/text_plot.py
@@ -85,10 +85,8 @@ def process_node(self):
)
else:
ts = " ".join(
- map(
- lambda t: f"{t['target_id']}:{_number2str(t['weight'])}",
- self.targets,
- )
+ f"{t['target_id']}:{_number2str(t['weight'])}"
+ for t in self.targets
)
text = f"{self.true_false}f {ts}"
else:
@@ -184,9 +182,7 @@ def iterate(nodes, node, depth=0, true_false=""):
rows.extend(r)
return "\n".join(rows)
- raise NotImplementedError( # pragma: no cover
- f"Type {node.op_type!r} cannot be displayed."
- )
+ raise NotImplementedError(f"Type {node.op_type!r} cannot be displayed.")
def _append_succ_pred(
@@ -353,7 +349,7 @@ def __init__(self, nodes):
def _find_sequence(node_name, known, done):
inputs = dnodes[node_name].input
- if any(map(lambda i: i not in known, inputs)):
+ if any((i not in known) for i in inputs):
return []
res = [node_name]
@@ -364,7 +360,7 @@ def _find_sequence(node_name, known, done):
if len(next_names) == 1:
next_name = next_names.pop()
inputs = dnodes[next_name].input
- if any(map(lambda i: i not in known, inputs)):
+ if any((i not in known) for i in inputs):
break
res.extend(next_name)
else:
@@ -392,7 +388,7 @@ def _find_sequence(node_name, known, done):
possibles[k] = v
sequences = OrderedDict()
- for k, v in possibles.items():
+ for k, _v in possibles.items():
if k in done:
continue
sequences[k] = _find_sequence(k, known, done)
@@ -403,7 +399,7 @@ def _find_sequence(node_name, known, done):
)
if not sequences:
- raise RuntimeError( # pragma: no cover
+ raise RuntimeError(
"Unexpected empty sequence (len(possibles)=%d, "
"len(done)=%d, len(nodes)=%d). This is usually due to "
"a name used both as result name and node node. "
@@ -434,7 +430,7 @@ def _find_sequence(node_name, known, done):
best = k
if best is None:
- raise RuntimeError( # pragma: no cover
+ raise RuntimeError(
f"Wrong implementation (len(sequence)={len(sequences)})."
)
if verbose:
@@ -453,7 +449,7 @@ def _find_sequence(node_name, known, done):
known |= set(v.output)
if len(new_nodes) != len(nodes):
- raise RuntimeError( # pragma: no cover
+ raise RuntimeError(
"The returned new nodes are different. "
"len(nodes=%d) != %d=len(new_nodes). done=\n%r"
"\n%s\n----------\n%s"
@@ -486,7 +482,7 @@ def _find_sequence(node_name, known, done):
n0s = set(n.name for n in nodes)
n1s = set(n.name for n in new_nodes)
if n0s != n1s:
- raise RuntimeError( # pragma: no cover
+ raise RuntimeError(
"The returned new nodes are different.\n"
"%r !=\n%r\ndone=\n%r"
"\n----------\n%s\n----------\n%s"
@@ -758,7 +754,7 @@ def str_node(indent, node):
try:
val = str(to_array(att.t).tolist())
except TypeError as e:
- raise TypeError( # pragma: no cover
+ raise TypeError(
"Unable to display tensor type %r.\n%s"
% (att.type, str(att))
) from e
@@ -828,7 +824,10 @@ def str_node(indent, node):
rows.append(f"opset: domain={opset.domain!r} version={opset.version!r}")
if hasattr(model, "graph"):
if model.doc_string:
- rows.append(f"doc_string: {model.doc_string}")
+ if len(model.doc_string) < 55:
+ rows.append(f"doc_string: {model.doc_string}")
+ else:
+ rows.append(f"doc_string: {model.doc_string[:55]}...")
main_model = model
model = model.graph
else:
@@ -853,9 +852,7 @@ def str_node(indent, node):
if isinstance(att, str):
rows.append(f"attribute: {att!r}")
else:
- raise NotImplementedError( # pragma: no cover
- "Not yet introduced in onnx."
- )
+ raise NotImplementedError("Not yet introduced in onnx.")
# initializer
if hasattr(model, "initializer"):
@@ -867,9 +864,16 @@ def str_node(indent, node):
else:
content = ""
line_name_new[init.name] = len(rows)
+ if init.doc_string:
+ t = (
+ f"init: name={init.name!r} type={_get_type(init)} "
+ f"shape={_get_shape(init)}{content}"
+ )
+ rows.append(f"{t}{' ' * max(0, 70 - len(t))}-- {init.doc_string}")
+ continue
rows.append(
- "init: name=%r type=%r shape=%r%s"
- % (init.name, _get_type(init), _get_shape(init), content)
+ f"init: name={init.name!r} type={_get_type(init)} "
+ f"shape={_get_shape(init)}{content}"
)
if level == 0:
rows.append("----- main graph ----")
@@ -894,7 +898,7 @@ def str_node(indent, node):
try:
nodes = reorder_nodes_for_display(model.node, verbose=verbose)
- except RuntimeError as e: # pragma: no cover
+ except RuntimeError as e:
if raise_exc:
raise e
else:
@@ -924,9 +928,7 @@ def str_node(indent, node):
indent = mi
if previous_indent is not None and indent < previous_indent:
if verbose:
- print( # pragma: no cover
- f"[onnx_simple_text_plot] break2 {node.op_type}"
- )
+ print(f"[onnx_simple_text_plot] break2 {node.op_type}")
add_break = True
if not add_break and previous_out is not None:
if not (set(node.input) & previous_out):
@@ -947,7 +949,7 @@ def str_node(indent, node):
rows.append(str_node(indent if use_indentation else 0, node))
indents[name] = indent
- for i, o in enumerate(node.output):
+ for _i, o in enumerate(node.output):
indents[o] = indent + 1
previous_indent = indents[name]
@@ -1052,7 +1054,10 @@ def _mark_link(rows, lengths, r1, r2, d):
for fct in main_model.functions:
rows.append(f"----- function name={fct.name} domain={fct.domain}")
if fct.doc_string:
- rows.append(f"----- doc_string: {fct.doc_string}")
+ if len(fct.doc_string) < 55:
+ rows.append(f"----- doc_string: {fct.doc_string}")
+ else:
+ rows.append(f"----- doc_string: {fct.doc_string[:55]}...")
res = onnx_simple_text_plot(
fct,
verbose=verbose,
@@ -1111,10 +1116,19 @@ def onnx_text_plot_io(model, verbose=False, att_display=None):
)
# initializer
for init in model.initializer:
+
+ if init.doc_string:
+ t = (
+ f"init: name={init.name!r} type={_get_type(init)} "
+ f"shape={_get_shape(init)}"
+ )
+ rows.append(f"{t}{' ' * max(0, 70 - len(t))}-- {init.doc_string}")
+ continue
rows.append(
- "init: name=%r type=%r shape=%r"
- % (init.name, _get_type(init), _get_shape(init))
+ f"init: name={init.name!r} type={_get_type(init)} "
+ f"shape={_get_shape(init)}"
)
+
# outputs
for out in model.output:
rows.append(
diff --git a/onnx_array_api/profiling.py b/onnx_array_api/profiling.py
index 52c464a..ab2cc6b 100644
--- a/onnx_array_api/profiling.py
+++ b/onnx_array_api/profiling.py
@@ -73,8 +73,8 @@ def _get_root(node, stor=None):
stor.append(node)
if not node.called_by:
return node
- if len(node.called_by) == 1:
- return _get_root(node.called_by[0], stor=stor)
+ if len(node.called_by) == 0:
+ return None
res = None
for ct in node.called_by:
k = id(node), id(ct)
@@ -247,8 +247,7 @@ def depth_first(node, roots_keys, indent=0):
else:
if filter_node is not None and not filter_node(n):
continue
- for t in depth_first(n, roots_keys, indent + 1):
- yield t
+ yield from depth_first(n, roots_keys, indent + 1)
if filter_node is None:
filter_node = ProfileNode.filter_node_
@@ -472,7 +471,7 @@ def add_rows(rows, d):
def profile2df(
ps: Stats,
as_df: bool = True,
- clean_text: bool = None,
+ clean_text: Optional[bool] = None,
verbose: bool = False,
fLOG=None,
):
@@ -740,7 +739,7 @@ def fct4():
node.add_called_by(child)
child.add_calls_to(node, vv)
- for k, v in nodes.items():
+ for _k, v in nodes.items():
root = v.get_root()
break
diff --git a/onnx_array_api/reference/__init__.py b/onnx_array_api/reference/__init__.py
index d8c5aa5..fd1d27c 100644
--- a/onnx_array_api/reference/__init__.py
+++ b/onnx_array_api/reference/__init__.py
@@ -11,6 +11,13 @@
)
from onnx.reference.op_run import to_array_extended
from .evaluator import ExtendedReferenceEvaluator
+from .evaluator_yield import (
+ DistanceExecution,
+ ResultExecution,
+ ResultType,
+ YieldEvaluator,
+ compare_onnx_execution,
+)
def from_array_extended(tensor: np.array, name: Optional[str] = None) -> TensorProto:
diff --git a/onnx_array_api/reference/evaluator.py b/onnx_array_api/reference/evaluator.py
index e20be76..89b5a84 100644
--- a/onnx_array_api/reference/evaluator.py
+++ b/onnx_array_api/reference/evaluator.py
@@ -7,6 +7,10 @@
from .ops.op_cast_like import CastLike_15, CastLike_19
from .ops.op_concat import Concat
from .ops.op_constant_of_shape import ConstantOfShape
+from .ops.op_fused_matmul import FusedMatMul
+from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost
+from .ops.op_quick_gelu import QuickGelu
+from .ops.op_scatter_elements import ScatterElements
logger = getLogger("onnx-array-api-eval")
@@ -32,6 +36,11 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
CastLike_15,
CastLike_19,
ConstantOfShape,
+ FusedMatMul,
+ MemcpyFromHost,
+ MemcpyToHost,
+ QuickGelu,
+ ScatterElements,
]
@staticmethod
@@ -108,4 +117,7 @@ def run(self, *args, **kwargs):
"""
See :meth:`onnx.reference.ReferenceEvaluator.run`.
"""
+ if len(args) == 1 and isinstance(args[0], list):
+ feeds = dict(zip(self.input_names, args[0]))
+ return self.run(None, feeds, **kwargs)
return ReferenceEvaluator.run(self, *args, **kwargs)
diff --git a/onnx_array_api/reference/evaluator_yield.py b/onnx_array_api/reference/evaluator_yield.py
new file mode 100644
index 0000000..b53c27d
--- /dev/null
+++ b/onnx_array_api/reference/evaluator_yield.py
@@ -0,0 +1,680 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Iterator, Optional, Tuple, Union
+from enum import IntEnum
+import numpy as np
+from onnx import ModelProto, TensorProto, ValueInfoProto, load
+from onnx.reference import ReferenceEvaluator
+from onnx.helper import tensor_dtype_to_np_dtype
+from onnx.shape_inference import infer_shapes
+from . import to_array_extended
+from .evaluator import ExtendedReferenceEvaluator
+
+
+def _align(res: str, limit: int) -> str:
+ if len(res) == limit:
+ return res
+ if len(res) > limit:
+ return res[:limit]
+ return res + " " * (limit - len(res))
+
+
+class ResultType(IntEnum):
+ RESULT = 1
+ INITIALIZER = 2
+ SPARSE_INITIALIZER = 4
+ INPUT = 8
+ OUTPUT = 16
+ NODE = 32
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}.{self._name_}"
+
+
+def _dimension_to_str(d):
+ if isinstance(d, int):
+ return str(d)
+ try:
+ int(d)
+ except ValueError:
+ return d
+ return f"{d!r}"
+
+
+def _rank_to_str(shape):
+ if shape:
+ return f"{len(shape)}:"
+ return " "
+
+
+@dataclass
+class ResultExecution:
+ """
+ The description of a result.
+ """
+
+ kind: ResultType
+ dtype: object
+ shape: tuple
+ summary: str
+ op_type: str
+ name: str
+ value: Optional[Any] = None
+
+ def __len__(self) -> int:
+ return 6
+
+ def __getitem__(self, i: int) -> Any:
+ if i == 0:
+ return self.kind
+ if i == 1:
+ return self.dtype
+ if i == 2:
+ return self.shape
+ if i == 3:
+ return self.summary
+ if i == 4:
+ return self.op_type
+ if i == 5:
+ return self.name
+ raise IndexError(f"i={i} out of boundary")
+
+ def __str__(self):
+ dtype = self.dtype if self.dtype != 0 else ""
+ els = [
+ _align(self.kind._name_, 6),
+ _align(str(dtype).replace("dtype(", "").replace(")", ""), 8),
+ _rank_to_str(self.shape)
+ + _align(
+ "x".join(
+ "" if self.shape is None else map(_dimension_to_str, self.shape)
+ ),
+ 18,
+ ),
+ self.summary,
+ _align(self.op_type or "", 15),
+ self.name or "",
+ ]
+ return " ".join(els)
+
+
+def make_summary(value: Any, length: int = 4, modulo: int = 26) -> str:
+ """
+ Create a short string summarizing the value (discretization).
+
+ :param value: array
+ :param length: number of value to produce
+ :param module: discretization parameter
+ :return: short string
+ """
+ if isinstance(value, np.float32):
+ # This should not happen.
+ value = np.array(value)
+ assert isinstance(
+ value, np.ndarray
+ ), f"Unexpected type {type(value)} for value, it must be a numpy array."
+ value4 = np.zeros(length, dtype=np.float64)
+ if value.size <= length:
+ value4[: value.size] = value.flatten().astype(np.float64)
+ else:
+ if value.size % length != 0:
+ value2 = np.zeros(
+ value.size + length - value.size % length, dtype=np.float64
+ )
+ value2[: value.size] = value.flatten().astype(np.float64)
+ else:
+ value2 = value.flatten().astype(np.float64)
+ value4 = value2.reshape((4, -1)).sum(axis=1)
+ value4 = np.where(np.abs(value4) < 1e10, value4, np.nan)
+ s = []
+ for v in value4:
+ s.append("?" if np.isnan(v) else (chr(65 + int(v) % modulo)))
+ return "".join(s)
+
+
+class YieldEvaluator:
+ """
+ This class implements method `enumerate_results` which iterates on
+ intermediates results. By default, it uses
+ :class:`onnx_array_api.reference.ExtendedReferenceEvaluator`.
+
+ :param onnx_model: model to run
+ :param recursive: dig into subgraph and functions as well
+ :param cls: evaluator to use, default value is :class:`ExtendedReferenceEvaluator
+ `
+ """
+
+ def __init__(
+ self,
+ onnx_model: ModelProto,
+ recursive: bool = False,
+ cls: Optional[type[ExtendedReferenceEvaluator]] = None,
+ ):
+ assert not recursive, "recursive=True is not yet implemented"
+ self.onnx_model = onnx_model
+ self.evaluator = (
+ cls(onnx_model)
+ if cls is not None
+ else ExtendedReferenceEvaluator(onnx_model)
+ )
+
+ def enumerate_results(
+ self,
+ output_names: Optional[List[str]] = None,
+ feed_inputs: Optional[Dict[str, Any]] = None,
+ raise_exc: bool = True,
+ ) -> Iterator[Tuple[ResultType, str, Any]]:
+ """
+ Executes the onnx model and enumerate all the intermediate results.
+
+ Args:
+ output_names: requested outputs by names, None for all
+ feed_inputs: dictionary `{ input name: input value }`
+
+ Returns:
+ iterator on tuple(result kind, name, value, node.op_type or None)
+ """
+ assert isinstance(self.evaluator, ReferenceEvaluator), (
+ f"This implementation only works with "
+ f"ReferenceEvaluator not {type(self.evaluator)}"
+ )
+ attributes = {}
+ if output_names is None:
+ output_names = self.evaluator.output_names
+
+ results = {"": None}
+ results.update(self.evaluator.rt_inits_)
+ results.update(feed_inputs)
+ # step 0: initializer
+ for k, v in self.evaluator.rt_inits_.items():
+ yield ResultType.INITIALIZER, k, v, None
+ # step 1: inputs
+ for k, v in feed_inputs.items():
+ yield ResultType.INPUT, k, v, None
+
+ # step 2: execute nodes
+ yield_output = True
+ for node in self.evaluator.rt_nodes_:
+ for i in node.input:
+ if i not in results:
+ raise RuntimeError(
+ f"Unable to find input {i!r} "
+ f"in known results {sorted(results)}, "
+ f"self.rt_inits_ has {sorted(self.evaluator.rt_inits_)}, "
+ f"feed_inputs has {sorted(feed_inputs)}."
+ )
+ inputs = [results[i] for i in node.input]
+ linked_attributes = {}
+ if node.has_linked_attribute and attributes:
+ linked_attributes["linked_attributes"] = attributes
+
+ try:
+ if node.need_context():
+ outputs = node.run(*inputs, context=results, **linked_attributes)
+ else:
+ outputs = node.run(*inputs, **linked_attributes)
+ except Exception:
+ if raise_exc:
+ # ExtendedReferenceEvaluator(self.onnx_model, verbose=10).run(
+ # None, feed_inputs
+ # )
+ raise
+ yield_output = False
+ break
+
+ for name, value in zip(node.output, outputs):
+ yield ResultType.RESULT, name, value, node.op_type
+ results[name] = value
+
+ # step 3: outputs
+ if yield_output:
+ for name in output_names:
+ if name not in results:
+ raise RuntimeError(
+ f"Unable to find output name {name!r} in {sorted(results)}, "
+ f"proto is\n{self.proto_}"
+ )
+ yield ResultType.OUTPUT, name, results[name], None
+
+ def enumerate_summarized(
+ self,
+ output_names: Optional[List[str]] = None,
+ feed_inputs: Optional[Dict[str, Any]] = None,
+ raise_exc: bool = True,
+ keep_tensor: bool = False,
+ ) -> Iterator[ResultExecution]:
+ """
+ Executes the onnx model and enumerate intermediate results without their names.
+
+ :param output_names: requested outputs by names, None for all
+ :param feed_inputs: dictionary ``{ input name: input value }``
+ :param raise_exc: raises an exception if the execution fails or stop where it is
+ :param keep_tensor: keep the tensor in order to compute precise distances
+ :return: iterator on ResultExecution
+ """
+ for kind, name, value, op_type in self.enumerate_results(
+ output_names, feed_inputs, raise_exc=raise_exc
+ ):
+ summary = make_summary(value)
+ yield ResultExecution(
+ kind,
+ value.dtype,
+ value.shape,
+ summary,
+ op_type,
+ name,
+ value=value if keep_tensor else None,
+ )
+
+
+def discrepancies(
+ expected: np.ndarray, value: np.ndarray, eps: float = 1e-7
+) -> Dict[str, float]:
+ """
+ Computes absolute error and relative error between two matrices.
+ """
+ assert (
+ expected.size == value.size
+ ), f"Incompatible shapes v1.shape={expected.shape}, v2.shape={value.shape}"
+ expected = expected.ravel().astype(np.float32)
+ value = value.ravel().astype(np.float32)
+ diff = np.abs(expected - value)
+ rel = diff / (np.abs(expected) + eps)
+ return dict(aerr=float(diff.max()), rerr=float(rel.max()))
+
+
+class DistanceExecution:
+ """
+ Computes a distance between two results.
+ """
+
+ float_types = {
+ np.float16,
+ np.float32,
+ np.float64,
+ np.dtype("float16"),
+ np.dtype("float32"),
+ np.dtype("float64"),
+ }
+
+ def __init__(self, max_lag: int = 50):
+ self.kind_cost = 1000
+ self.type_cost = 10
+ self.rank_cost = 100
+ self.op_type_cost = 10
+ self.max_lag = max_lag
+ self.insert_cost = 1000
+
+ def distance_pair(self, r1: ResultExecution, r2: ResultExecution) -> float:
+ """
+ (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs"),
+
+ :param r1: first result
+ :param r2: second result
+ :return: distance
+ """
+ d = 0
+ if r1[0] != r2[0]:
+ # difference type
+ d += self.kind_cost
+ if r1[1] != r2[1]:
+ d += self._cost_type(r1[1], r2[1]) * self.type_cost
+ if r1[2] != r2[2]:
+ d += self._cost_shape(r1[2], r2[2])
+ if r1[3] != r2[3]:
+ d += self._cost_summary(r1[3], r2[3])
+ if r1[4] != r2[4]:
+ d += self.op_type_cost
+ return d
+
+ def _cost_type(self, t1: "np.dtype", t2: "np.dtype") -> float:
+ if t1 in self.float_types and t2 in self.float_types:
+ return 0.2
+ return 1
+
+ def _cost_shape(self, s1: Tuple[int, ...], s2: Tuple[int, ...]) -> float:
+ if s1 is None or s2 is None:
+ return self.rank_cost
+ if any(isinstance(s, str) for s in s1) or any(isinstance(s, str) for s in s2):
+ # dynamic shapes
+ if len(s1) != len(s2):
+ return self.rank_cost
+ d = 0
+ for i, j in zip(s1, s2):
+ if isinstance(i, int) and isinstance(j, int):
+ d += abs(i - j)
+ elif i != j:
+ d += self.rank_cost / 2
+ return d
+
+ d = abs(np.prod(s1) - np.prod(s2))
+ if len(s1) != len(s2):
+ return self.rank_cost + d
+ for i, j in zip(s1, s2):
+ d += abs(i - j)
+ return d
+
+ def _cost_summary(self, s1: str, s2: str) -> float:
+ if len(s1) != len(s2):
+ return 1e6
+ d = 0
+ for a, b in zip(s1, s2):
+ d += abs(ord(a) - ord(b))
+ return d
+
+ def distance_sequence(
+ self, s1: List[ResultExecution], s2: List[ResultExecution]
+ ) -> Tuple[float, List[Tuple[int, int]]]:
+ """
+ Computes the distance between two sequences of results.
+
+ :param s1: first sequence
+ :param s2: second sequence
+ :return: distance and alignment
+ """
+ delay = max(self.max_lag, abs(len(s2) - len(s1)) + 1)
+ distance = {(-1, -1): 0}
+ predecessor = {(-1, -1): None}
+ for i in range(len(s1)):
+ for j in range(max(0, i - delay), min(len(s2), i + delay)):
+ best = distance.get((i, j), 1e100)
+ pred = None
+ ki, kj = i - 1, j - 1
+ if (ki, kj) in distance:
+ d = distance[ki, kj] + self.distance_pair(s1[i], s2[j])
+ if d < best:
+ best = d
+ pred = (ki, kj)
+ ki, kj = i - 1, j
+ if (ki, kj) in distance:
+ d = distance[ki, kj] + self.insert_cost
+ if d < best:
+ best = d
+ pred = (ki, kj)
+ ki, kj = i, j - 1
+ if (ki, kj) in distance:
+ d = distance[ki, kj] + self.insert_cost
+ if d < best:
+ best = d
+ pred = (ki, kj)
+ distance[i, j] = best
+ predecessor[i, j] = pred
+
+ # reverse
+ way = []
+ last = len(s1) - 1, len(s2) - 1
+ while last is not None:
+ way.append(last)
+ last = predecessor[last]
+ return distance[len(s1) - 1, len(s2) - 1], list(reversed(way))[1:]
+
+ def to_str(
+ self,
+ s1: List[ResultExecution],
+ s2: List[ResultExecution],
+ alignment: List[Tuple[int, int]],
+ column_size: int = 60,
+ ) -> str:
+ """
+ Prints out the alignment between two sequences into a string.
+ :param s1: first sequence
+ :param s2: second sequence
+ :param alignment: alignment
+ :param column_size: column size
+ :return: test
+ """
+ rows = []
+ last = -1, -1
+ row_index = 1
+ for i, j in alignment:
+ assert i < len(s1), f"Unexpected value i={i} >= len(s1)={len(s1)}"
+ assert j < len(s2), f"Unexpected value i={j} >= len(s2)={len(s2)}"
+ expected = last[0] + 1, last[1] + 1
+
+ if expected == (i, j):
+ d1 = s1[i]
+ d2 = s2[j]
+ d = self.distance_pair(d1, d2)
+ symbol = "=" if d == 0 else "~"
+ line = (
+ f"{symbol} | {_align(str(d1), column_size)} | "
+ f"{_align(str(d2), column_size)}"
+ )
+ if (
+ d1.value is not None
+ and d2.value is not None
+ and d1.value.size == d2.value.size
+ ):
+ disc = discrepancies(d1.value, d2.value)
+ a, r = disc["aerr"], disc["rerr"]
+ line += f" | a={a:.5g} r={r:.5g}"
+ elif i == last[0]:
+ d2 = s2[j]
+ line = (
+ f"+ | {_align('', column_size)} | {_align(str(d2), column_size)} "
+ )
+ else:
+ d1 = s1[i]
+ line = f"- | {_align(str(d1), column_size)} | {_align('', column_size)}"
+ rows.append(f"{row_index:03d} {line}")
+ last = i, j
+ row_index += 1
+ return "\n".join(rows)
+
+
+def generate_input(info: ValueInfoProto) -> np.ndarray:
+ """
+ Generates one input.
+ """
+ elem_type = info.type.tensor_type.elem_type
+ shape = [
+ (getattr(d, "dim_value", None) or getattr(d, "dim_param")) # noqa: B009
+ for d in info.type.tensor_type.shape.dim
+ ]
+ new_shape = []
+ for sh in shape:
+ if isinstance(sh, str):
+ if len(new_shape) == 0:
+ new_shape.append(1)
+ else:
+ new_shape.append(16)
+ else:
+ new_shape.append(sh)
+ new_shape = tuple(new_shape)
+ p = np.prod(new_shape)
+ value = np.arange(p)
+ if elem_type == TensorProto.INT32:
+ return value.astype(np.int32).reshape(new_shape)
+ if elem_type == TensorProto.INT64:
+ return value.astype(np.int64).reshape(new_shape)
+ if elem_type == TensorProto.FLOAT:
+ return (value.astype(np.float32) / p).astype(np.float32).reshape(new_shape)
+ if elem_type == TensorProto.FLOAT16:
+ return (value.astype(np.float16) / p).astype(np.float16).reshape(new_shape)
+ if elem_type == TensorProto.DOUBLE:
+ return (value.astype(np.float64) / p).astype(np.float64).reshape(new_shape)
+ if elem_type == TensorProto.COMPLEX64:
+ return (value.astype(np.complex64) / p).astype(np.complex64).reshape(new_shape)
+ if elem_type == TensorProto.COMPLEX128:
+ return (
+ (value.astype(np.complex128) / p).astype(np.complex128).reshape(new_shape)
+ )
+ raise RuntimeError(f"Unexpected element_type {elem_type} for info={info}")
+
+
+def generate_inputs(model: ModelProto) -> List[np.ndarray]:
+ """
+ Generates inputs for a specific model.
+
+ :param model: ModelProto
+ :return: list of inputs
+ """
+ inputs = []
+ inits = set(i.name for i in model.graph.initializer)
+ for inp in model.graph.input:
+ if inp.name in inits:
+ break
+ inputs.append(generate_input(inp))
+ return inputs
+
+
+def _update_shape_types_with_proto(
+ proto: ModelProto,
+) -> Dict[str, Tuple[int, Tuple[Union[int, str], ...]]]:
+ """
+ Retrieves the shapes and types for a model.
+ """
+ assert isinstance(proto, ModelProto), f"Unexpected type {type(proto)} for proto"
+ res = {}
+
+ for val in proto.graph.input:
+ itype = val.type.tensor_type.elem_type
+ shape = tuple(
+ d.dim_param if d.dim_param else d.dim_value
+ for d in val.type.tensor_type.shape.dim
+ )
+ res[val.name] = [itype, shape]
+
+ for val in proto.graph.output:
+ itype = val.type.tensor_type.elem_type
+ shape = tuple(
+ d.dim_param if d.dim_param else d.dim_value
+ for d in val.type.tensor_type.shape.dim
+ )
+ res[val.name] = [itype, shape]
+
+ for val in proto.graph.initializer:
+ itype = val.data_type
+ shape = tuple(d for d in val.dims)
+ res[val.name] = [itype, shape]
+
+ new_proto = infer_shapes(proto)
+ for val in new_proto.graph.value_info:
+ itype = val.type.tensor_type.elem_type
+ shape = tuple(
+ d.dim_param if d.dim_param else d.dim_value
+ for d in val.type.tensor_type.shape.dim
+ )
+ res[val.name] = [itype, shape]
+
+ return res
+
+
+def _enumerate_result_no_execution(model: ModelProto) -> Iterator[ResultType]:
+ """
+ Produces a list of results based on a model in order to
+ trigger the edit distance comparison.
+ """
+ type_shape = _update_shape_types_with_proto(model)
+ for i in model.graph.initializer:
+ itype, shape = type_shape.get(i.name, (0, None))
+ dtype = tensor_dtype_to_np_dtype(itype)
+ yield ResultExecution(
+ ResultType.INITIALIZER,
+ dtype,
+ shape,
+ make_summary(to_array_extended(i)),
+ "INIT",
+ i.name,
+ )
+ for i in model.graph.input:
+ itype, shape = type_shape.get(i.name, (0, None))
+ dtype = tensor_dtype_to_np_dtype(itype)
+ yield ResultExecution(ResultType.INPUT, dtype, shape, "????", "INPUT", i.name)
+ for node in model.graph.node:
+ yield ResultExecution(ResultType.NODE, 0, None, "????", node.op_type, node.name)
+ for o in node.output:
+ itype, shape = type_shape.get(o, (0, None))
+ dtype = 0 if itype == 0 else tensor_dtype_to_np_dtype(itype)
+ yield ResultExecution(
+ ResultType.RESULT, dtype, shape, "????", node.op_type, o
+ )
+ for i in model.graph.output:
+ itype, shape = type_shape.get(i.name, (0, None))
+ dtype = tensor_dtype_to_np_dtype(itype)
+ yield ResultExecution(ResultType.OUTPUT, dtype, shape, "????", "OUTPUT", i.name)
+
+
+def compare_onnx_execution(
+ model1: ModelProto,
+ model2: ModelProto,
+ inputs: Optional[Union[List[Any], Tuple[Dict[str, Any]]]] = None,
+ verbose: int = 0,
+ raise_exc: bool = True,
+ mode: str = "execute",
+ keep_tensor: bool = False,
+ cls: Optional[type[ReferenceEvaluator]] = None,
+) -> Tuple[List[ResultExecution], List[ResultExecution], List[Tuple[int, int]]]:
+ """
+ Compares the execution of two onnx models.
+ The function assumes both models takes the same inputs.
+ See :ref:`l-onnx-diff-example` to see a full example using
+ this function.
+
+ :param model1: first model
+ :param model2: second model
+ :param inputs: inputs to use, a list of inputs if both models have
+ the same number of inputs or two dictionaries, one for each model
+ :param verbose: verbosity
+ :param raise_exc: raise exception if the execution fails or stop at the error
+ :param mode: the model should be executed but the function can be executed
+ but the comparison may append on nodes only
+ :param keep_tensor: keeps the tensor in order to compute a precise distance
+ :param cls: evaluator class to use
+ :return: four results, a sequence of results
+ for the first model and the second model,
+ the alignment between the two, DistanceExecution
+ """
+ assert mode in {"execute", "nodes"}, f"Unexpected value for mode={mode!r}."
+
+ if mode == "execute":
+ if inputs is None:
+ if verbose:
+ print("[compare_onnx_execution] generate inputs")
+ inputs = generate_inputs(model1)
+ if isinstance(inputs, tuple):
+ assert len(inputs) == 2, f"Unexpected number {len(inputs)} of inputs."
+ feeds1, feeds2 = inputs
+ else:
+ feeds1 = {i.name: v for i, v in zip(model1.graph.input, inputs)}
+ feeds2 = {i.name: v for i, v in zip(model2.graph.input, inputs)}
+ assert isinstance(feeds1, dict), f"Unexpected type {type(feeds1)} for inputs"
+ assert isinstance(feeds2, dict), f"Unexpected type {type(feeds2)} for inputs"
+ if verbose:
+ print(f"[compare_onnx_execution] execute with {len(inputs)} inputs")
+ print("[compare_onnx_execution] execute first model")
+ res1 = list(
+ YieldEvaluator(model1, cls=cls).enumerate_summarized(
+ None, feeds1, raise_exc=raise_exc, keep_tensor=keep_tensor
+ )
+ )
+ if verbose:
+ print(f"[compare_onnx_execution] got {len(res1)} results")
+ print("[compare_onnx_execution] execute second model")
+ res2 = list(
+ YieldEvaluator(model2, cls=cls).enumerate_summarized(
+ None, feeds2, raise_exc=raise_exc, keep_tensor=keep_tensor
+ )
+ )
+ elif mode == "nodes":
+ # No execution.
+ if verbose:
+ print("[compare_onnx_execution] loading first model")
+ proto1 = load(model1) if isinstance(model1, str) else model1
+ if verbose:
+ print("[compare_onnx_execution] loading second model")
+ proto2 = load(model2) if isinstance(model2, str) else model2
+ res1 = list(_enumerate_result_no_execution(proto1))
+ res2 = list(_enumerate_result_no_execution(proto2))
+ else:
+ return
+
+ if verbose:
+ print(f"[compare_onnx_execution] got {len(res1)} results (first model)")
+ print(f"[compare_onnx_execution] got {len(res2)} results (second model)")
+ print("[compare_onnx_execution] compute edit distance")
+ dc = DistanceExecution()
+ _, align = dc.distance_sequence(res1, res2)
+ if verbose:
+ print(f"[compare_onnx_execution] got {len(align)} pairs")
+ print("[compare_onnx_execution] done")
+ return res1, res2, align, dc
diff --git a/onnx_array_api/reference/ops/op_constant_of_shape.py b/onnx_array_api/reference/ops/op_constant_of_shape.py
index 00c6989..a54bb5a 100644
--- a/onnx_array_api/reference/ops/op_constant_of_shape.py
+++ b/onnx_array_api/reference/ops/op_constant_of_shape.py
@@ -19,6 +19,8 @@ def _process(value):
cst = np.int64(cst)
elif isinstance(cst, float):
cst = np.float64(cst)
+ elif isinstance(cst, complex):
+ cst = np.complex128(cst)
elif cst is None:
cst = np.float32(0)
if not isinstance(
@@ -27,6 +29,8 @@ def _process(value):
np.float16,
np.float32,
np.float64,
+ np.complex64,
+ np.complex128,
np.int64,
np.int32,
np.int16,
diff --git a/onnx_array_api/reference/ops/op_fused_matmul.py b/onnx_array_api/reference/ops/op_fused_matmul.py
new file mode 100644
index 0000000..1ee0f04
--- /dev/null
+++ b/onnx_array_api/reference/ops/op_fused_matmul.py
@@ -0,0 +1,35 @@
+import numpy as np
+from onnx.reference.op_run import OpRun
+
+
+class FusedMatMul(OpRun):
+ op_domain = "com.microsoft"
+
+ def _run(
+ self,
+ A,
+ B,
+ alpha: float = 1,
+ transA: int = 0,
+ transB: int = 0,
+ transBatchA: int = 0,
+ transBatchB: int = 0,
+ ):
+ assert (
+ transBatchA == 0
+ ), f"Not implemented for transBatchA==1 and {A.shape}x{B.shape}"
+ assert (
+ transBatchB == 0
+ ), f"Not implemented for transBatchB==1 and {A.shape}x{B.shape}"
+ if transA:
+ perm = list(range(len(A.shape)))
+ dim = len(perm)
+ perm[dim - 2], perm[dim - 1] = perm[dim - 1], perm[dim - 2]
+ A = np.transpose(A, perm)
+ if transB:
+ perm = list(range(len(B.shape)))
+ dim = len(perm)
+ perm[dim - 2], perm[dim - 1] = perm[dim - 1], perm[dim - 2]
+ B = np.transpose(B, perm)
+ a = np.array(alpha, dtype=A.dtype)
+ return (np.matmul(A, B) * a,)
diff --git a/onnx_array_api/reference/ops/op_memcpy_host.py b/onnx_array_api/reference/ops/op_memcpy_host.py
new file mode 100644
index 0000000..ac365e7
--- /dev/null
+++ b/onnx_array_api/reference/ops/op_memcpy_host.py
@@ -0,0 +1,11 @@
+from onnx.reference.op_run import OpRun
+
+
+class MemcpyFromHost(OpRun):
+ def _run(self, x):
+ return (x,)
+
+
+class MemcpyToHost(OpRun):
+ def _run(self, x):
+ return (x,)
diff --git a/onnx_array_api/reference/ops/op_quick_gelu.py b/onnx_array_api/reference/ops/op_quick_gelu.py
new file mode 100644
index 0000000..e30c5ec
--- /dev/null
+++ b/onnx_array_api/reference/ops/op_quick_gelu.py
@@ -0,0 +1,23 @@
+import numpy as np
+from onnx.reference.op_run import OpRun
+
+
+def sigmoid(x): # type: ignore
+ if x > 0:
+ return 1 / (1 + np.exp(-x))
+ return np.exp(x) / (1 + np.exp(x))
+
+
+class QuickGelu(OpRun):
+ op_domain = "com.microsoft"
+
+ def __init__(self, onnx_node, run_params): # type: ignore
+ OpRun.__init__(self, onnx_node, run_params)
+ self.vf = np.vectorize(sigmoid)
+
+ def _run(self, X, alpha=1.0):
+ if len(X.shape) == 0:
+ return ((X * sigmoid(X * alpha)).astype(X.dtype),)
+ if X.size == 0:
+ return (X,)
+ return ((X * self.vf(X * alpha)).astype(X.dtype),)
diff --git a/onnx_array_api/reference/ops/op_scatter_elements.py b/onnx_array_api/reference/ops/op_scatter_elements.py
new file mode 100644
index 0000000..c4b0efa
--- /dev/null
+++ b/onnx_array_api/reference/ops/op_scatter_elements.py
@@ -0,0 +1,98 @@
+import numpy as np
+
+from onnx.reference.op_run import OpRun
+
+
+def scatter_elements(data, indices, updates, axis=0, reduction=None): # type: ignore
+ if reduction == "add":
+
+ def f(x, y):
+ return x + y
+
+ elif reduction == "min":
+
+ def f(x, y):
+ return min(x, y)
+
+ elif reduction == "max":
+
+ def f(x, y):
+ return max(x, y)
+
+ else:
+
+ def f(x, y):
+ return y
+
+ if axis < 0:
+ axis = data.ndim + axis
+
+ if len(data.shape) == 1 and axis == 0:
+ scattered = np.copy(data)
+ for pos, up in zip(indices, updates):
+ scattered[pos] = f(scattered[pos], up)
+ return scattered
+
+ if len(indices.shape) == 2:
+ scattered = np.copy(data)
+ if axis == 0:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ scattered[indices[i, j], j] = f(
+ scattered[indices[i, j], j], updates[i, j]
+ )
+ else:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ scattered[i, indices[i, j]] = f(
+ scattered[i, indices[i, j]], updates[i, j]
+ )
+ return scattered
+
+ if len(indices.shape) == 3:
+ scattered = np.copy(data)
+ if axis == 0:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in range(indices.shape[2]):
+ scattered[indices[i, j, k], j, k] = f(
+ scattered[indices[i, j, k], j, k], updates[i, j, k]
+ )
+ elif axis == 1:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in range(indices.shape[2]):
+ scattered[i, indices[i, j, k], k] = f(
+ scattered[i, indices[i, j, k], k], updates[i, j, k]
+ )
+ elif axis == 2:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in range(indices.shape[2]):
+ scattered[i, j, indices[i, j, k]] = f(
+ scattered[i, j, indices[i, j, k]], updates[i, j, k]
+ )
+ return scattered
+
+ if len(indices.shape) == 4:
+ scattered = np.copy(data)
+ if axis == 3:
+ for a in range(indices.shape[0]):
+ for i in range(indices.shape[1]):
+ for j in range(indices.shape[2]):
+ for k in range(indices.shape[3]):
+ scattered[a, i, j, indices[a, i, j, k]] = f(
+ scattered[a, i, j, indices[a, i, j, k]],
+ updates[a, i, j, k],
+ )
+ return scattered
+
+ raise RuntimeError(
+ f"Not implemented for indices.shape={indices.shape} and axis={axis}"
+ )
+
+
+class ScatterElements(OpRun):
+ def _run(self, data, indices, updates, axis=None, reduction=None): # type: ignore
+ res = scatter_elements(data, indices, updates, axis=axis, reduction=reduction)
+ return (res,)
diff --git a/onnx_array_api/tools/__init__.py b/onnx_array_api/tools/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/onnx_array_api/tools/__init__.py
@@ -0,0 +1 @@
+
diff --git a/onnx_array_api/tools/replace_constants.py b/onnx_array_api/tools/replace_constants.py
new file mode 100644
index 0000000..daa4ca8
--- /dev/null
+++ b/onnx_array_api/tools/replace_constants.py
@@ -0,0 +1,227 @@
+import numpy as np
+from onnx import FunctionProto, ModelProto, GraphProto, AttributeProto
+from onnx.helper import (
+ make_model,
+ set_model_props,
+ make_graph,
+ make_node,
+ make_attribute,
+ make_function,
+ tensor_dtype_to_np_dtype,
+)
+from onnx.numpy_helper import from_array
+
+
+def replace_initializer_by_constant_of_shape(
+ onx, threshold=128, op_type="ConstantOfShape", domain=""
+):
+ """
+ Replaces initializers by nodes *ConstantOfShape* to reduce
+ the size and still write a unit test.
+
+ :param onx: ModelProto
+ :param threshold: every initializer under this threshold is not impacted
+ :param op_type: replace by this node
+ :param domain: replace by this domain
+ :return: onx, modified ModelProto
+ """
+ if isinstance(onx, FunctionProto):
+ modified = False
+ new_nodes = []
+ for node in onx.node:
+ if node.op_type == "Constant":
+ from onnx_array_api.reference import ExtendedReferenceEvaluator
+
+ ref = ExtendedReferenceEvaluator(node)
+ cst = ref.run(None, {})[0]
+
+ size = np.prod(cst.shape)
+ if size <= threshold:
+ new_nodes.append(node)
+ continue
+
+ new_name = f"{node.output[0]}__SHAPE"
+ new_nodes.append(
+ make_node(
+ "Constant",
+ [],
+ [new_name],
+ value=from_array(
+ np.array(cst.shape, dtype=np.int64), name=new_name
+ ),
+ )
+ )
+ dtype = cst.dtype
+ assert op_type != "Constant"
+ new_nodes.append(
+ make_node(
+ op_type,
+ [new_name],
+ node.output,
+ value=from_array(np.array([0.5], dtype=dtype)),
+ domain=domain,
+ )
+ )
+ modified = True
+ continue
+
+ new_nodes.append(node)
+
+ if not modified:
+ return onx
+
+ onxf = make_function(
+ domain=onx.domain,
+ fname=onx.name,
+ inputs=onx.input,
+ outputs=onx.output,
+ nodes=new_nodes,
+ doc_string=onx.doc_string,
+ overload=onx.overload,
+ opset_imports=[],
+ )
+ if onx.opset_import:
+ onxf.opset_import.extend(onx.opset_import)
+ if onx.value_info:
+ onxf.value_info.extend(onx.value_info)
+ if onx.attribute:
+ onxf.attribute.extend(onx.attribute)
+ if onx.attribute_proto:
+ onxf.attribute_proto.extend(onx.attribute_proto)
+ return onxf
+
+ if isinstance(onx, ModelProto):
+ new_graph = replace_initializer_by_constant_of_shape(
+ onx.graph, threshold=threshold, op_type=op_type, domain=domain
+ )
+ new_functions = [
+ replace_initializer_by_constant_of_shape(
+ f, threshold=threshold, op_type=op_type, domain=domain
+ )
+ for f in onx.functions
+ ]
+ model = make_model(
+ new_graph,
+ functions=new_functions,
+ producer_name=onx.producer_name,
+ producer_version=onx.producer_version,
+ ir_version=onx.ir_version,
+ doc_string=onx.doc_string,
+ domain=onx.domain,
+ model_version=onx.model_version,
+ )
+ if len(onx.metadata_props) > 0: # pragma: no cover
+ values = {p.key: p.value for p in onx.metadata_props}
+ set_model_props(model, values)
+
+ del model.opset_import[:] # pylint: disable=E1101
+ for oimp in onx.opset_import:
+ op_set = model.opset_import.add() # pylint: disable=E1101
+ if oimp.domain == "" and oimp.version < 9:
+ raise RuntimeError(
+ f"ConstantOfShape was introduced in "
+ f"opset 9 but opset is {oimp.version}."
+ )
+ op_set.domain = oimp.domain
+ op_set.version = oimp.version
+ return model
+
+ if not isinstance(onx, GraphProto):
+ raise TypeError(f"onx should be a GraphProto as this stage not {type(onx)}.")
+
+ new_nodes = []
+ removed = set()
+ additional_inputs = []
+
+ new_inits = []
+ for init in onx.initializer:
+ dims = tuple(init.dims)
+ size = np.prod(dims)
+ if size <= threshold:
+ new_inits.append(init)
+ continue
+ new_name = f"{init.name}__SHAPE"
+ new_inits.append(
+ from_array(np.array(list(dims), dtype=np.int64), name=new_name)
+ )
+ dtype = tensor_dtype_to_np_dtype(init.data_type)
+ node = make_node(
+ op_type,
+ [new_name],
+ [init.name],
+ value=from_array(np.array([0.5], dtype=dtype)),
+ domain=domain,
+ )
+ new_nodes.append(node)
+ removed.add(init.name)
+
+ new_sparse_inits = []
+ for init in onx.sparse_initializer:
+ dims = tuple(init.dims)
+ size = np.prod(dims)
+ if size <= threshold:
+ new_sparse_inits.append(init)
+ continue
+ raise NotImplementedError(
+ f"This feature is not yet implemented for sparse initializer"
+ f"(name={init.name!r})."
+ )
+
+ for node in onx.node:
+ if node.op_type == "Constant":
+ from onnx_array_api.reference import ExtendedReferenceEvaluator
+
+ ref = ExtendedReferenceEvaluator(node)
+ cst = ref.run(None, {})[0]
+
+ size = np.prod(cst.shape)
+ if size <= threshold:
+ new_nodes.append(node)
+ continue
+
+ new_name = f"{node.output[0]}__SHAPE"
+ new_inits.append(
+ from_array(np.array(cst.shape, dtype=np.int64), name=new_name)
+ )
+ dtype = cst.dtype
+ new_nodes.append(
+ make_node(
+ op_type,
+ [new_name],
+ node.output,
+ value=from_array(np.array([0.5], dtype=dtype)),
+ domain=domain,
+ )
+ )
+ continue
+
+ modified = False
+ atts = []
+ for att in node.attribute:
+ if (
+ att.type == AttributeProto.GRAPH
+ and hasattr(att, "g")
+ and att.g is not None
+ ):
+ modified = True
+ g = replace_initializer_by_constant_of_shape(
+ att.g, threshold=threshold, op_type=op_type, domain=domain
+ )
+ att = make_attribute(att.name, g)
+ atts.append(att)
+ if modified:
+ new_node = make_node(node.op_type, node.input, node.output)
+ new_node.attribute.extend(atts)
+ new_nodes.append(new_node)
+ else:
+ new_nodes.append(node)
+
+ graph = make_graph(
+ new_nodes,
+ onx.name,
+ [i for i in onx.input if i.name not in removed] + additional_inputs,
+ onx.output,
+ initializer=new_inits,
+ sparse_initializer=new_sparse_inits,
+ )
+ return graph
diff --git a/onnx_array_api/translate_api/__init__.py b/onnx_array_api/translate_api/__init__.py
new file mode 100644
index 0000000..a9a8932
--- /dev/null
+++ b/onnx_array_api/translate_api/__init__.py
@@ -0,0 +1,94 @@
+from onnx import ModelProto
+from .translate import Translater
+from .inner_emitter import InnerEmitter, InnerEmitterShortInitializer
+from .builder_emitter import BuilderEmitter
+
+
+def translate(proto: ModelProto, single_line: bool = False, api: str = "light") -> str:
+ """
+ Translates an ONNX proto into a code using :ref:`l-light-api`
+ to describe the ONNX graph.
+
+ :param proto: model to translate
+ :param single_line: as a single line or not
+ :param api: API to export into,
+ default is `"light"` and this is handle by class
+ :class:`onnx_array_api.translate_api.light_emitter.LightEmitter`,
+ another value is `"onnx"` which is the inner API implemented
+ in onnx package, `"builder"` follows the syntax for the
+ class :class:`onnx_array_api.graph_api.GraphBuilder`,
+ `"onnx-short"` replaces long initializer with random values
+ :return: code
+
+ .. runpython::
+ :showcode:
+
+ from onnx_array_api.light_api import start
+ from onnx_array_api.translate_api import translate
+
+ onx = (
+ start()
+ .vin("X")
+ .reshape((-1, 1))
+ .Transpose(perm=[1, 0])
+ .rename("Y")
+ .vout()
+ .to_onnx()
+ )
+ code = translate(onx)
+ print(code)
+
+ The inner API from onnx package is also available.
+
+ .. runpython::
+ :showcode:
+
+ from onnx_array_api.light_api import start
+ from onnx_array_api.translate_api import translate
+
+ onx = (
+ start()
+ .vin("X")
+ .reshape((-1, 1))
+ .Transpose(perm=[1, 0])
+ .rename("Y")
+ .vout()
+ .to_onnx()
+ )
+ code = translate(onx, api="onnx")
+ print(code)
+
+ The :class:`GraphBuilder
+ ` API returns this:
+
+ .. runpython::
+ :showcode:
+
+ from onnx_array_api.light_api import start
+ from onnx_array_api.translate_api import translate
+
+ onx = (
+ start()
+ .vin("X")
+ .reshape((-1, 1))
+ .Transpose(perm=[1, 0])
+ .rename("Y")
+ .vout()
+ .to_onnx()
+ )
+ code = translate(onx, api="builder")
+ print(code)
+ """
+ if api == "light":
+ tr = Translater(proto)
+ return tr.export(single_line=single_line, as_str=True)
+ if api == "onnx":
+ tr = Translater(proto, emitter=InnerEmitter())
+ return tr.export(as_str=True)
+ if api == "onnx-short":
+ tr = Translater(proto, emitter=InnerEmitterShortInitializer())
+ return tr.export(as_str=True)
+ if api == "builder":
+ tr = Translater(proto, emitter=BuilderEmitter())
+ return tr.export(as_str=True)
+ raise ValueError(f"Unexpected value {api!r} for api.")
diff --git a/onnx_array_api/light_api/emitter.py b/onnx_array_api/translate_api/base_emitter.py
similarity index 53%
rename from onnx_array_api/light_api/emitter.py
rename to onnx_array_api/translate_api/base_emitter.py
index a1b0e40..e8d3811 100644
--- a/onnx_array_api/light_api/emitter.py
+++ b/onnx_array_api/translate_api/base_emitter.py
@@ -1,9 +1,8 @@
import inspect
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from enum import IntEnum
import numpy as np
from onnx import AttributeProto
-from .annotations import ELEMENT_TYPE_NAME
class EventType(IntEnum):
@@ -11,13 +10,25 @@ class EventType(IntEnum):
INPUT = 1
OUTPUT = 2
NODE = 3
- TO_ONNX = 4
+ TO_ONNX_MODEL = 4
BEGIN_GRAPH = 5
END_GRAPH = 6
BEGIN_FUNCTION = 7
END_FUNCTION = 8
INITIALIZER = 9
SPARSE_INITIALIZER = 10
+ FUNCTION_INPUT = 11
+ FUNCTION_OUTPUT = 12
+ FUNCTION_ATTRIBUTES = 13
+ TO_ONNX_FUNCTION = 14
+ BEGIN_SIGNATURE = 15
+ END_SIGNATURE = 16
+ BEGIN_RETURN = 17
+ END_RETURN = 18
+ BEGIN_FUNCTION_SIGNATURE = 19
+ END_FUNCTION_SIGNATURE = 20
+ BEGIN_FUNCTION_RETURN = 21
+ END_FUNCTION_RETURN = 22
@classmethod
def to_str(cls, self) -> str:
@@ -54,8 +65,11 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
if event == EventType.START:
return self._emit_start(**kwargs)
- if event == EventType.TO_ONNX:
- return self._emit_to_onnx(**kwargs)
+ if event == EventType.TO_ONNX_MODEL:
+ return self._emit_to_onnx_model(**kwargs)
+
+ if event == EventType.TO_ONNX_FUNCTION:
+ return self._emit_to_onnx_function(**kwargs)
if event == EventType.BEGIN_GRAPH:
return self._emit_begin_graph(**kwargs)
@@ -63,6 +77,45 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
if event == EventType.END_GRAPH:
return self._emit_end_graph(**kwargs)
+ if event == EventType.BEGIN_FUNCTION:
+ return self._emit_begin_function(**kwargs)
+
+ if event == EventType.BEGIN_FUNCTION_SIGNATURE:
+ return self._emit_begin_function_signature(**kwargs)
+
+ if event == EventType.END_FUNCTION_SIGNATURE:
+ return self._emit_end_function_signature(**kwargs)
+
+ if event == EventType.END_FUNCTION:
+ return self._emit_end_function(**kwargs)
+
+ if event == EventType.FUNCTION_INPUT:
+ return self._emit_function_input(**kwargs)
+
+ if event == EventType.FUNCTION_OUTPUT:
+ return self._emit_function_output(**kwargs)
+
+ if event == EventType.FUNCTION_ATTRIBUTES:
+ return self._emit_function_attributes(**kwargs)
+
+ if event == EventType.BEGIN_SIGNATURE:
+ return self._emit_begin_signature(**kwargs)
+
+ if event == EventType.END_SIGNATURE:
+ return self._emit_end_signature(**kwargs)
+
+ if event == EventType.BEGIN_RETURN:
+ return self._emit_begin_return(**kwargs)
+
+ if event == EventType.END_RETURN:
+ return self._emit_end_return(**kwargs)
+
+ if event == EventType.BEGIN_FUNCTION_RETURN:
+ return self._emit_begin_function_return(**kwargs)
+
+ if event == EventType.END_FUNCTION_RETURN:
+ return self._emit_end_function_return(**kwargs)
+
raise ValueError(f"Unexpected event {EventType.to_str(event)}.")
def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
@@ -104,11 +157,27 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
srows = ".".join(rows[:-1])
return [], f"g().{srows}"
+ if isinstance(value, tuple) and len(value) == 2 and value[1] is None:
+ # in a function, an attribute receiving a value from an attribute
+ v = value[0]
+ name = v.name
+ ref = v.ref_attr_name
+ dt = v.type
+ return [], self._make_attribute(name=name, ref_attr_name=ref, attr_type=dt)
+
raise ValueError(
f"Unable to render an attribute {type(v)}, "
f"attribute type={value[0].type}, "
f"dtype={getattr(v, 'dtype', '-')}, "
- f"shape={getattr(v, 'shape', '-')}, {value}."
+ f"shape={getattr(v, 'shape', '-')}, type(value)={type(value)}, "
+ f"value={value!r}."
+ )
+
+ def _make_attribute(
+ self, name: str, attr_type: int, ref_attr_name: Optional[str] = None
+ ) -> str:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
)
def join(self, rows: List[str], single_line: bool = False) -> str:
@@ -121,7 +190,12 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
)
- def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]:
+ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]:
raise NotImplementedError(
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
)
@@ -161,100 +235,46 @@ def _emit_sparse_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
)
+ def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_begin_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
-class Emitter(BaseEmitter):
- """
- Converts event into proper code.
- """
+ def _emit_end_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
- def join(self, rows: List[str], single_line: bool = False) -> str:
- "Join the rows"
- if single_line:
- return ".".join(rows)
- return "".join(["(\n ", "\n .".join(rows), "\n)"])
+ def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
- def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
- opsets = kwargs.get("opsets", {})
- opset = opsets.get("", None)
- if opset is not None:
- del opsets[""]
- args = []
- if opset:
- args.append(f"opset={opset}")
- if opsets:
- args.append(f"opsets={opsets}")
- return [f"start({', '.join(args)})"]
-
- def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]:
- return ["to_onnx()"]
+ def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
- def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+ def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError(
+ f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
+ )
+
+ def _emit_begin_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
return []
- def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+ def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
return []
- def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
- name = kwargs["name"]
- value = kwargs["value"]
- repl = {"bool": "bool_", "object": "object_", "str": "str_"}
- sdtype = repl.get(str(value.dtype), str(str(value.dtype)))
- return [
- f"cst(np.array({value.tolist()}, dtype=np.{sdtype}))",
- f"rename({name!r})",
- ]
+ def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
- def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
- name = kwargs["name"]
- elem_type = kwargs.get("elem_type", None)
- shape = kwargs.get("shape", None)
- if elem_type and shape:
- return [
- f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})"
- ]
- if elem_type:
- return [
- f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})"
- ]
- return [f"vin({name!r})"]
+ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
- def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
- inst = []
- if "name" in kwargs:
- name = kwargs["name"]
- inst.append(f"bring({name!r})")
- elem_type = kwargs.get("elem_type", None)
- shape = kwargs.get("shape", None)
- if elem_type and shape:
- inst.append(
- f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r})"
- )
- elif elem_type:
- inst.append(f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})")
- else:
- inst.append("vout()")
- return inst
+ def _emit_begin_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
- def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
- op_type = kwargs["op_type"]
- inputs = kwargs["inputs"]
- outputs = kwargs["outputs"]
- if kwargs.get("domain", "") != "":
- domain = kwargs["domain"]
- op_type = f"{domain}.{op_type}"
- atts = kwargs.get("atts", {})
- args = []
- for k, v in atts.items():
- before, vatt = self.render_attribute_value(v)
- if before:
- raise NotImplementedError("Graph attribute not supported yet.")
- args.append(f"{k}={vatt}")
-
- str_inputs = ", ".join([f"{i!r}" for i in inputs])
- inst = [f"bring({str_inputs})", f"{op_type}({', '.join(args)})"]
- if len(outputs) == 1:
- inst.append(f"rename({outputs[0]!r})")
- else:
- str_outputs = ", ".join([f"{o!r}" for o in outputs])
- inst.append(f"rename({str_outputs})")
- return inst
+ def _emit_end_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
diff --git a/onnx_array_api/translate_api/builder_emitter.py b/onnx_array_api/translate_api/builder_emitter.py
new file mode 100644
index 0000000..19dd7f9
--- /dev/null
+++ b/onnx_array_api/translate_api/builder_emitter.py
@@ -0,0 +1,242 @@
+from typing import Any, Dict, List
+from onnx import TensorProto
+from onnx.numpy_helper import to_array
+from .base_emitter import BaseEmitter
+
+_types = {
+ TensorProto.DOUBLE: "DOUBLE",
+ TensorProto.FLOAT: "FLOAT",
+ TensorProto.FLOAT16: "FLOAT16",
+ TensorProto.INT64: "INT64",
+ TensorProto.INT32: "INT32",
+ TensorProto.INT16: "INT16",
+ TensorProto.UINT64: "UINT64",
+ TensorProto.UINT32: "UINT32",
+ TensorProto.UINT16: "UINT16",
+ TensorProto.STRING: "STRING",
+ TensorProto.BOOL: "BOOL",
+}
+
+
+def _itype_to_string(itype: int) -> str:
+ return _types[itype]
+
+
+class BuilderEmitter(BaseEmitter):
+ """
+ Converts event into proper code.
+ """
+
+ def __init__(self, make_model_function: str = ""):
+ super().__init__()
+ self.make_model_function = make_model_function
+
+ def join(self, rows: List[str], single_line: bool = False) -> str:
+ "Join the rows"
+ assert (
+ not single_line
+ ), f"The emitter {type(self)} does not work with single_line=True."
+ return "\n".join(rows)
+
+ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
+ self.opsets = kwargs.get("opsets", {})
+ self.ir_version = kwargs.get("ir_version", None)
+ self.function_calls = []
+ return []
+
+ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
+ inps = ", ".join(["g.op", *[f'"{i}"' for i in self.inputs]])
+ inputs = []
+ for inp, stype, shape in self.inputs_full_:
+ inputs.append(f'g.make_tensor_input("{inp}", TensorProto.{stype}, {shape})')
+ outputs = []
+ for inp, stype, shape in self.outputs_full_:
+ outputs.append(
+ f'g.make_tensor_output("{inp}", TensorProto.{stype}, '
+ f"{shape}, is_dimension=False, indexed=False)"
+ )
+ rows = [
+ "",
+ (
+ f"g = GraphBuilder({self.opsets}, ir_version={self.ir_version})"
+ if self.ir_version
+ else f"GraphBuilder({self.opsets})"
+ ),
+ *inputs,
+ f"{self.name}({inps})",
+ *outputs,
+ *self.function_calls,
+ "model = g.to_onnx()",
+ ]
+ if self.make_model_function:
+ rows = [
+ "",
+ "",
+ f'def {self.make_model_function}() -> "ModelProto":',
+ *[" " + _ for _ in rows[1:]],
+ " return model",
+ "",
+ "",
+ f"model = {self.make_model_function}()",
+ ]
+ return rows
+
+ def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+ self.inputs = []
+ self.inputs_full = []
+ self.outputs = []
+ self.inits = []
+ self.inputs_full_ = []
+ self.outputs_full_ = []
+ self.name = kwargs.get("name", "make_graph")
+ return []
+
+ def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
+ init = kwargs["init"]
+ if isinstance(init, TensorProto):
+ assert (
+ kwargs["name"] == init.name
+ ), f"Name mismatch init.name={init.name!r}, name={kwargs['name']!r}"
+ self.inits.append(init)
+ return []
+ raise AssertionError(f"Unsupported type for an initializer {type(init)}")
+
+ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+ name = kwargs["name"]
+ itype = kwargs.get("elem_type", 0)
+ shape = kwargs.get("shape", None)
+ name = self._clean_result_name(name)
+ if itype == 0:
+ inp = name or "X"
+ else:
+ if shape is None:
+ inp = f'{name}: "{_itype_to_string(itype)}"'
+ else:
+ inp = (
+ f'{name}: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"'
+ )
+ self.inputs_full.append(inp)
+ self.inputs.append(name)
+ self.inputs_full_.append((name, _itype_to_string(itype), shape))
+ return []
+
+ def _emit_begin_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+ rows = ["", f"def {self.name}(", ' op: "GraphBuilder",']
+ for i in self.inputs_full:
+ rows.append(f" {i},")
+ rows.append("):")
+ for init in self.inits:
+ val = to_array(init)
+ stype = str(val.dtype).split(".")[-1]
+ name = self._clean_result_name(init.name)
+ rows.append(f" {name} = np.array({val.tolist()}, dtype=np.{stype})")
+ return rows
+
+ def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+ outs = ", ".join(self.outputs)
+ return [f" return {outs}"]
+
+ def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+ name = kwargs["name"]
+ name = self._clean_result_name(name)
+ itype = kwargs.get("elem_type", 0)
+ shape = kwargs.get("shape", None)
+ self.outputs.append(name)
+ self.outputs_full_.append((name, _itype_to_string(itype), shape))
+ return [f' op.Identity({name}, outputs=["{name}"])']
+
+ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
+ op_type = kwargs["op_type"]
+ inputs = kwargs["inputs"]
+ outputs = kwargs["outputs"]
+ domain = kwargs.get("domain", "")
+ atts = kwargs.get("atts", {})
+ args = []
+ for k, v in atts.items():
+ before, vatt = self.render_attribute_value(v)
+ if before:
+ raise NotImplementedError("Graph attribute not supported yet.")
+ args.append(f"{k}={vatt}")
+
+ cleaned_outputs = list(map(self._clean_result_name, outputs))
+ outs = ", ".join(cleaned_outputs)
+ inps = ", ".join(map(self._clean_result_name, inputs))
+ op_type = self._emit_node_type(op_type, domain)
+ # Let's add output names to make it easier to debug.
+ soutputs = f", outputs={cleaned_outputs}"
+ sdomain = soutputs if not domain else f", domain={domain!r}{soutputs}"
+ if args:
+ sargs = ", ".join(args)
+ if inps:
+ row = f" {outs} = op.{op_type}({inps}, {sargs}{sdomain})"
+ else:
+ row = f" {outs} = op.{op_type}({sargs}{sdomain})"
+ else:
+ row = f" {outs} = op.{op_type}({inps}{sdomain})"
+ return [row]
+
+ def _clean_result_name(self, name):
+ return name
+
+ def _emit_node_type(self, op_type, domain):
+ return op_type
+
+ def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ self.f_inputs = []
+ self.f_outputs = []
+ self.f_inits = []
+ self.f_name = kwargs["name"]
+ self.f_domain = kwargs["domain"]
+ self.f_attributes = []
+ self.f_opsets = kwargs["opsets"]
+ return []
+
+ def _emit_begin_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_end_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+ self.f_call_name = f"make_{self.f_domain}_{self.f_name}"
+ return [
+ "",
+ "",
+ f'def {self.f_call_name}(g: "GraphBuilder"):',
+ f" gr = GraphBuilder({self.f_opsets}, as_function=True)",
+ *[f" {name} = gr.make_tensor_input({name!r})" for name in self.f_inputs],
+ " op = gr.op",
+ ]
+
+ def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return [" return gr"]
+
+ def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+ self.f_inputs.append(kwargs["name"])
+ return []
+
+ def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+ self.f_outputs.append(kwargs["name"])
+ return []
+
+ def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]:
+ raise NotImplementedError("Function attribute are not implemented yet.")
+
+ def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ self.function_calls.append(f"{self.f_call_name}(g)")
+ return [
+ *[f" gr.make_tensor_output({name})" for name in self.f_outputs],
+ " g.add_function(builder=gr)",
+ ]
+
+ def _emit_begin_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_end_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
diff --git a/onnx_array_api/translate_api/inner_emitter.py b/onnx_array_api/translate_api/inner_emitter.py
new file mode 100644
index 0000000..de63dcc
--- /dev/null
+++ b/onnx_array_api/translate_api/inner_emitter.py
@@ -0,0 +1,266 @@
+from typing import Any, Dict, List, Optional, Tuple
+from onnx import AttributeProto
+from ..annotations import ELEMENT_TYPE_NAME
+from .base_emitter import BaseEmitter
+from .translate import Translater
+
+
+class InnerEmitter(BaseEmitter):
+ """
+ Converts event into proper code.
+ """
+
+ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
+ """
+ Renders an attribute value into a string.
+
+ :param value: value to converter
+ :return: rows to append before, actual value
+ """
+ if value[0].type == AttributeProto.GRAPH:
+ tr = Translater(value[0].g, emitter=self)
+ rows = tr.export(as_str=False, single_line=False)
+ new_rows = [f"def _make_local_graph_{value[0].name}():"]
+ for line in rows:
+ if "make_model" in line:
+ break
+ new_rows.append(" " + line)
+ new_rows.append(" return graph")
+ new_rows.append(f"{value[0].name} = _make_local_graph_{value[0].name}()")
+ return new_rows, value[0].name
+
+ return super().render_attribute_value(value)
+
+ def _make_attribute(
+ self, name: str, attr_type: int, ref_attr_name: Optional[str] = None
+ ) -> str:
+ if ref_attr_name is None:
+ raise NotImplementedError(
+ f"Cannot create attribute with name={name!r}, attr_type={attr_type}."
+ )
+ return (
+ f"make_ref_attribute(key={name!r}, attr_type={attr_type}, "
+ f"ref_attr_name={ref_attr_name!r})"
+ )
+
+ def join(self, rows: List[str], single_line: bool = False) -> str:
+ "Returns the separators. `single_line` is unused."
+ return "\n".join(rows)
+
+ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
+ lines = ["opset_imports = ["]
+ opsets = kwargs.get("opsets", {})
+ for k, v in opsets.items():
+ lines.append(f" make_opsetid({k!r}, {v!r}),")
+ lines.append("]")
+ return lines
+
+ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
+ lines = [
+ "model = make_model(",
+ " graph,",
+ " functions=functions,",
+ " opset_imports=opset_imports",
+ ")",
+ ]
+ return lines
+
+ def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+ lines = [
+ "inputs = []",
+ "outputs = []",
+ "nodes = []",
+ "initializers = []",
+ "sparse_initializers = []",
+ "functions = []",
+ ]
+ return lines
+
+ def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+ name = kwargs.get("name", "noname")
+ lines = [
+ "graph = make_graph(",
+ " nodes,",
+ f" {name!r},",
+ " inputs,",
+ " outputs,",
+ " initializers,",
+ " sparse_initializer=sparse_initializers,",
+ ")",
+ ]
+ return lines
+
+ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
+ name = kwargs["name"]
+ value = kwargs["value"]
+ repl = {"bool": "bool_", "object": "object_", "str": "str_"}
+ fra = "from_array"
+ sdtype = repl.get(str(value.dtype), str(value.dtype))
+ if sdtype.startswith("("):
+ from onnx.reference.custom_element_types import float8e4m3fn
+
+ if sdtype == str(float8e4m3fn):
+ sdtype = "float8e4m3fn"
+ fra = "from_array_extended"
+ else:
+ raise NotImplementedError(f"Unexpected dtype={sdtype}.")
+ else:
+ sdtype = f"np.{sdtype}"
+
+ return [
+ "initializers.append(",
+ f" {fra}(",
+ f" np.array({value.tolist()}, dtype={sdtype}),",
+ f" name={name!r}",
+ " )",
+ ")",
+ ]
+
+ def _emit_io(self, container: str, **kwargs: Dict[str, Any]) -> List[str]:
+ name = kwargs["name"]
+ elem_type = kwargs.get("elem_type", None)
+ shape = kwargs.get("shape", None)
+ if elem_type and shape:
+ return [
+ f"{container}.append(make_tensor_value_info({name!r}, "
+ f"TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape={shape!r}))"
+ ]
+ if elem_type:
+ return [
+ f"{container}.append(make_tensor_value_info({name!r}, "
+ f"TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, shape=[]))"
+ ]
+ return [
+ f"{container}.append(make_tensor_value_info({name!r}, "
+ f"TensorProto.UNDEFINED, []))"
+ ]
+
+ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return self._emit_io("inputs", **kwargs)
+
+ def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return self._emit_io("outputs", **kwargs)
+
+ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
+ op_type = kwargs["op_type"]
+ inputs = kwargs["inputs"]
+ outputs = kwargs["outputs"]
+ if kwargs.get("domain", "") != "":
+ domain = kwargs["domain"]
+
+ before_lines = []
+ lines = [
+ "nodes.append(",
+ " make_node_extended(",
+ f" {op_type!r},",
+ f" {inputs},",
+ f" {outputs},",
+ ]
+ domain = kwargs.get("domain", "")
+ if domain:
+ lines.append(f" domain={domain!r},")
+ atts = kwargs.get("atts", {})
+ for k, v in atts.items():
+ before, value = self.render_attribute_value(v)
+ before_lines.extend(before)
+ lines.append(f" {k}={value},")
+ lines[-1] = lines[-1][:-1]
+ lines.extend([" )", ")"])
+ return before_lines + lines
+
+ def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ lines = [
+ "",
+ f"name_f = {kwargs['name']!r}",
+ f"domain_f = {kwargs['domain']!r}",
+ "nodes = []",
+ "inputs = []",
+ "outputs = []",
+ "atts = []",
+ ]
+ return lines
+
+ def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return [f"inputs.append({kwargs['name']!r})"]
+
+ def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return [f"outputs.append({kwargs['name']!r})"]
+
+ def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]:
+ atts = kwargs["attributes"]
+ if isinstance(atts, list) and all(isinstance(t, str) for t in atts):
+ return [f"atts.extend({atts!r})"]
+ raise NotImplementedError(f"Unable to process function attributes {atts!r}.")
+
+ def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ lines = [
+ "functions.append(",
+ " make_function(",
+ " domain_f, ",
+ " name_f, ",
+ " inputs, ",
+ " outputs, ",
+ " nodes, ",
+ " attributes=atts, ",
+ " opset_imports=opset_imports,",
+ " )",
+ ")",
+ ]
+ return lines
+
+
+class InnerEmitterShortInitializer(InnerEmitter):
+ """
+ Converts event into proper code.
+ Initializer are replaced by random values if too big.
+ """
+
+ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
+ name = kwargs["name"]
+ value = kwargs["value"]
+ repl = {"bool": "bool_", "object": "object_", "str": "str_"}
+ fra = "from_array"
+ sdtype = repl.get(str(value.dtype), str(value.dtype))
+ if sdtype.startswith("("):
+ from onnx.reference.custom_element_types import float8e4m3fn
+
+ if sdtype == str(float8e4m3fn):
+ sdtype = "float8e4m3fn"
+ fra = "from_array_extended"
+ else:
+ raise NotImplementedError(f"Unexpected dtype={sdtype}.")
+ else:
+ sdtype = f"np.{sdtype}"
+ if value.size <= 16:
+ return [
+ "initializers.append(",
+ f" {fra}(",
+ f" np.array({value.tolist()}, dtype={sdtype}),",
+ f" name={name!r}",
+ " )",
+ ")",
+ ]
+ if "int" in sdtype:
+ return [
+ f"value = np.random.randint(0, 10, size={value.shape})"
+ f".astype({sdtype})",
+ "initializers.append(",
+ f" {fra}(",
+ f" np.array(value, dtype={sdtype}),",
+ f" name={name!r}",
+ " )",
+ ")",
+ ]
+ return [
+ f"value = np.random.randn({', '.join(map(str,value.shape))})"
+ f".astype({sdtype})",
+ "initializers.append(",
+ f" {fra}(",
+ f" np.array(value, dtype={sdtype}),",
+ f" name={name!r}",
+ " )",
+ ")",
+ ]
diff --git a/onnx_array_api/translate_api/light_emitter.py b/onnx_array_api/translate_api/light_emitter.py
new file mode 100644
index 0000000..9c58830
--- /dev/null
+++ b/onnx_array_api/translate_api/light_emitter.py
@@ -0,0 +1,106 @@
+from typing import Any, Dict, List
+from ..annotations import ELEMENT_TYPE_NAME
+from .base_emitter import BaseEmitter
+
+
+class LightEmitter(BaseEmitter):
+ """
+ Converts event into proper code.
+ """
+
+ def join(self, rows: List[str], single_line: bool = False) -> str:
+ "Join the rows"
+ if single_line:
+ return ".".join(rows)
+ return "".join(["(\n ", "\n .".join(rows), "\n)"])
+
+ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
+ opsets = kwargs.get("opsets", {})
+ opset = opsets.get("", None)
+ if opset is not None:
+ del opsets[""]
+ args = []
+ if opset:
+ args.append(f"opset={opset}")
+ if opsets:
+ args.append(f"opsets={opsets}")
+ return [f"start({', '.join(args)})"]
+
+ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return ["to_onnx()"]
+
+ def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
+ return []
+
+ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
+ name = kwargs["name"]
+ value = kwargs["value"]
+ repl = {"bool": "bool_", "object": "object_", "str": "str_"}
+ sdtype = repl.get(str(value.dtype), str(str(value.dtype)))
+ return [
+ f"cst(np.array({value.tolist()}, dtype=np.{sdtype}))",
+ f"rename({name!r})",
+ ]
+
+ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
+ name = kwargs["name"]
+ elem_type = kwargs.get("elem_type", None)
+ shape = kwargs.get("shape", None)
+ if elem_type and shape:
+ return [
+ f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, "
+ f"shape={shape!r})"
+ ]
+ if elem_type:
+ return [
+ f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})"
+ ]
+ return [f"vin({name!r})"]
+
+ def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
+ inst = []
+ if "name" in kwargs:
+ name = kwargs["name"]
+ inst.append(f"bring({name!r})")
+ elem_type = kwargs.get("elem_type", None)
+ shape = kwargs.get("shape", None)
+ if elem_type and shape:
+ inst.append(
+ f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, "
+ f"shape={shape!r})"
+ )
+ elif elem_type:
+ inst.append(f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})")
+ else:
+ inst.append("vout()")
+ return inst
+
+ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
+ op_type = kwargs["op_type"]
+ inputs = kwargs["inputs"]
+ outputs = kwargs["outputs"]
+ if kwargs.get("domain", "") != "":
+ domain = kwargs["domain"]
+ op_type = f"{domain}.{op_type}"
+ atts = kwargs.get("atts", {})
+ args = []
+ for k, v in atts.items():
+ before, vatt = self.render_attribute_value(v)
+ if before:
+ raise NotImplementedError("Graph attribute not supported yet.")
+ args.append(f"{k}={vatt}")
+
+ str_inputs = ", ".join([f"{i!r}" for i in inputs])
+ inst = [f"bring({str_inputs})", f"{op_type}({', '.join(args)})"]
+ if len(outputs) == 1:
+ inst.append(f"rename({outputs[0]!r})")
+ else:
+ str_outputs = ", ".join([f"{o!r}" for o in outputs])
+ inst.append(f"rename({str_outputs})")
+ return inst
diff --git a/onnx_array_api/translate_api/make_helper.py b/onnx_array_api/translate_api/make_helper.py
new file mode 100644
index 0000000..8b2703c
--- /dev/null
+++ b/onnx_array_api/translate_api/make_helper.py
@@ -0,0 +1,65 @@
+from typing import Any, Optional, Sequence
+from onnx import AttributeProto, NodeProto
+from onnx.helper import make_attribute
+
+
+def make_ref_attribute(
+ key: str, attr_type: int, ref_attr_name: Optional[str] = None
+) -> AttributeProto:
+ """
+ Creates an attribute.
+
+ :param key: atttribute name
+ :param attr_type: attribute type
+ :param ref_attr_name: if not None, link this attribute
+ to a function attribute
+ :return: attribute
+ """
+ att = AttributeProto()
+ att.name = key
+ att.type = attr_type
+ att.ref_attr_name = ref_attr_name
+ return att
+
+
+def make_node_extended(
+ op_type: str,
+ inputs: Sequence[str],
+ outputs: Sequence[str],
+ name: Optional[str] = None,
+ doc_string: Optional[str] = None,
+ domain: Optional[str] = None,
+ **kwargs: Any,
+) -> NodeProto:
+ """
+ Constructs a NodeProto.
+
+ :param op_type: The name of the operator to construct
+ :param inputs: list of input names
+ :param outputs: list of output names
+ :param name: optional unique identifier for NodeProto
+ :param doc_string: optional documentation string for NodeProto
+ :param domain: optional domain for NodeProto.
+ If it's None, we will just use default domain (which is empty)
+ :param kwargs: the attributes of the node.
+ :return: node proto
+ """
+ node = NodeProto()
+ node.op_type = op_type
+ node.input.extend(inputs)
+ node.output.extend(outputs)
+ if name:
+ node.name = name
+ if doc_string:
+ node.doc_string = doc_string
+ if domain is not None:
+ node.domain = domain
+ if kwargs:
+ for key, value in sorted(kwargs.items()):
+ if value is None:
+ continue
+ if isinstance(value, AttributeProto):
+ node.attribute.append(value)
+ else:
+ node.attribute.append(make_attribute(key, value))
+ return node
diff --git a/onnx_array_api/light_api/translate.py b/onnx_array_api/translate_api/translate.py
similarity index 67%
rename from onnx_array_api/light_api/translate.py
rename to onnx_array_api/translate_api/translate.py
index a61ce24..81d515a 100644
--- a/onnx_array_api/light_api/translate.py
+++ b/onnx_array_api/translate_api/translate.py
@@ -2,7 +2,9 @@
import numpy as np
from onnx import AttributeProto, FunctionProto, GraphProto, ModelProto, NodeProto
from onnx.numpy_helper import to_array
-from .emitter import EventType, Emitter
+from ..reference import to_array_extended
+from .base_emitter import EventType
+from .light_emitter import LightEmitter
class Translater:
@@ -13,10 +15,10 @@ class Translater:
def __init__(
self,
proto: Union[ModelProto, FunctionProto, GraphProto],
- emitter: Optional[Emitter] = None,
+ emitter: Optional[LightEmitter] = None,
):
self.proto_ = proto
- self.emitter = emitter or Emitter()
+ self.emitter = emitter or LightEmitter()
def __repr__(self) -> str:
return f"{self.__class__.__name__}(<{type(self.proto_)})"
@@ -30,14 +32,22 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
:return: list of instructions
"""
rows = []
+ last_event = None
if isinstance(self.proto_, ModelProto):
opsets = {d.domain: d.version for d in self.proto_.opset_import}
- rows.extend(self.emitter(EventType.START, opsets=opsets))
+ rows.extend(
+ self.emitter(
+ EventType.START, opsets=opsets, ir_version=self.proto_.ir_version
+ )
+ )
inputs = self.proto_.graph.input
outputs = self.proto_.graph.output
nodes = self.proto_.graph.node
initializers = self.proto_.graph.initializer
sparse_initializers = self.proto_.graph.sparse_initializer
+ attributes = []
+ last_event = EventType.TO_ONNX_MODEL
+ is_function = False
elif isinstance(self.proto_, (FunctionProto, GraphProto)):
inputs = self.proto_.input
outputs = self.proto_.output
@@ -48,30 +58,56 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
else:
initializers = []
sparse_initializers = []
+ attributes = (
+ self.proto_.attribute if hasattr(self.proto_, "attribute") else []
+ )
+ is_function = isinstance(self.proto_, FunctionProto)
+ last_event = (
+ EventType.TO_ONNX_FUNCTION if is_function else EventType.TO_ONNX_MODEL
+ )
else:
raise ValueError(f"Unexpected type {type(self.proto_)} for proto.")
if sparse_initializers:
raise NotImplementedError("Sparse initializer not supported yet.")
- rows.extend(
- self.emitter(
- EventType.BEGIN_FUNCTION
- if isinstance(self.proto_, FunctionProto)
- else EventType.BEGIN_GRAPH
+ if is_function:
+ rows.extend(
+ self.emitter(
+ EventType.BEGIN_FUNCTION,
+ name=self.proto_.name,
+ domain=self.proto_.domain,
+ opsets={d.domain: d.version for d in self.proto_.opset_import},
+ )
+ )
+ elif isinstance(self.proto_, GraphProto):
+ rows.extend(self.emitter(EventType.BEGIN_GRAPH, name=self.proto_.name))
+ else:
+ rows.extend(
+ self.emitter(EventType.BEGIN_GRAPH, name=self.proto_.graph.name)
)
- )
for i in initializers:
rows.extend(
self.emitter(
- EventType.INITIALIZER, name=i.name, init=i, value=to_array(i)
+ EventType.INITIALIZER,
+ name=i.name,
+ init=i,
+ value=to_array_extended(i),
)
)
+ rows.extend(
+ self.emitter(
+ EventType.BEGIN_FUNCTION_SIGNATURE
+ if is_function
+ else EventType.BEGIN_SIGNATURE
+ )
+ )
+
for i in inputs:
- if isinstance(i, str):
- rows.extend(self.emitter(EventType.INPUT, name=i))
+ if is_function:
+ rows.extend(self.emitter(EventType.FUNCTION_INPUT, name=i))
else:
rows.extend(
self.emitter(
@@ -85,6 +121,19 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
)
)
+ if is_function and attributes:
+ rows.extend(
+ self.emitter(EventType.FUNCTION_ATTRIBUTES, attributes=list(attributes))
+ )
+
+ rows.extend(
+ self.emitter(
+ EventType.END_FUNCTION_SIGNATURE
+ if is_function
+ else EventType.END_SIGNATURE
+ )
+ )
+
for node in nodes:
atts = self.extract_attributes(node)
rows.extend(
@@ -98,9 +147,17 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
)
)
+ rows.extend(
+ self.emitter(
+ EventType.BEGIN_FUNCTION_RETURN
+ if is_function
+ else EventType.BEGIN_RETURN
+ )
+ )
+
for o in outputs:
- if isinstance(o, str):
- rows.extend(self.emitter(EventType.INPUT, name=o))
+ if is_function:
+ rows.extend(self.emitter(EventType.FUNCTION_OUTPUT, name=o))
else:
rows.extend(
self.emitter(
@@ -113,23 +170,32 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
),
)
)
+
+ rows.extend(
+ self.emitter(
+ EventType.END_FUNCTION_RETURN if is_function else EventType.END_RETURN
+ )
+ )
+
if isinstance(self.proto_, (GraphProto, FunctionProto)):
name = self.proto_.name
else:
name = self.proto_.graph.name
+
rows.extend(
self.emitter(
- EventType.END_FUNCTION
- if isinstance(self.proto_, FunctionProto)
- else EventType.END_GRAPH,
+ EventType.END_FUNCTION if is_function else EventType.END_GRAPH,
name=name,
)
)
if isinstance(self.proto_, ModelProto) and len(self.proto_.functions) > 0:
- raise NotImplementedError("Local functions are not yet implemented.")
+ for fu in self.proto_.functions:
+ cl = self.__class__(fu, self.emitter)
+ text = cl.export(False, single_line=False)
+ rows.extend(text)
- rows.extend(self.emitter(EventType.TO_ONNX))
+ rows.extend(self.emitter(last_event))
if as_str:
return self.emitter.join(rows, single_line=single_line)
return rows
diff --git a/onnx_array_api/validation/docs.py b/onnx_array_api/validation/docs.py
index d1a8422..c5f937f 100644
--- a/onnx_array_api/validation/docs.py
+++ b/onnx_array_api/validation/docs.py
@@ -30,7 +30,9 @@ def make_euclidean(
n2 = oh.make_node("Pow", ["dxy", "two"], ["dxy2"])
n3 = oh.make_node("ReduceSum", ["dxy2"], [output_name])
graph = oh.make_graph([n1, n2, n3], "euclidian", [X, Y], [Z], [two])
- model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", opset)])
+ model = oh.make_model(
+ graph, opset_imports=[oh.make_opsetid("", opset)], ir_version=9
+ )
return model
diff --git a/onnx_array_api/validation/f8.py b/onnx_array_api/validation/f8.py
index ecd68f8..13b778d 100644
--- a/onnx_array_api/validation/f8.py
+++ b/onnx_array_api/validation/f8.py
@@ -9,8 +9,6 @@ class UndefinedCastError(FloatingPointError):
Unable to case a number.
"""
- pass
-
def display_int(ival, sign=1, exponent=8, mantissa=23):
"""
@@ -317,25 +315,23 @@ def fe5m2_to_float32(ival: int, fn: bool = False, uz: bool = False) -> float:
class CastFloat8Sets:
values_e4m3fn = list(
sorted(
- (fe4m3_to_float32_float(i), i) for i in range(0, 256) if i not in (255, 127)
+ (fe4m3_to_float32_float(i), i) for i in range(256) if i not in (255, 127)
)
)
values_e4m3fnuz = list(
- sorted(
- (fe4m3_to_float32_float(i, uz=True), i) for i in range(0, 256) if i != 0x80
- )
+ sorted((fe4m3_to_float32_float(i, uz=True), i) for i in range(256) if i != 0x80)
)
values_e5m2 = list(
sorted(
(fe5m2_to_float32_float(i), i)
- for i in range(0, 256)
+ for i in range(256)
if i not in {253, 254, 255, 125, 126, 127}
)
)
values_e5m2fnuz = list(
sorted(
(fe5m2_to_float32_float(i, fn=True, uz=True), i)
- for i in range(0, 256)
+ for i in range(256)
if i != 0x80
)
)
diff --git a/onnx_array_api/validation/tools.py b/onnx_array_api/validation/tools.py
index 6cd1da3..cbb02c1 100644
--- a/onnx_array_api/validation/tools.py
+++ b/onnx_array_api/validation/tools.py
@@ -20,7 +20,7 @@
def randomize_proto(
- onx: Union[ModelProto, GraphProto, FunctionProto, NodeProto, TensorProto]
+ onx: Union[ModelProto, GraphProto, FunctionProto, NodeProto, TensorProto],
) -> Union[ModelProto, GraphProto, FunctionProto, NodeProto, TensorProto]:
"""
Randomizes float initializers or constant nodes.
diff --git a/pyproject.toml b/pyproject.toml
index 4101adf..a465006 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,19 +11,46 @@ exclude = [
# Same as Black.
line-length = 88
-[tool.ruff.mccabe]
-# Unlike Flake8, default to a complexity level of 10.
-max-complexity = 10
+[tool.ruff.lint]
+select = [
+ "B", # flake8-bugbear
+ "C4", # flake8-comprehensions
+ #"D", # pydocstyle
+ "E", # pycodestyle
+ "F", # Pyflakes
+ "G", # flake8-logging-format
+ #"I", # isort
+ "ISC", # flake8-implicit-str-concat
+ "LOG", # flake8-logging
+ #"N", # pep8-naming
+ #"NPY", # modern numpy
+ #"PERF", # Perflint
+ "PIE", # flake8-pie
+ "PYI", # flake8-pyi
+ "RUF", # Ruff-specific rules
+ "SIM", # flake8-simplify
+ "SLOT", # flake8-slot
+ "T10", # flake8-debugger
+ #"TID", # Disallow relative imports
+ #"TRY", # flake8-try-except-raise
+ "UP", # pyupgrade
+ "W", # pycodestyle
+ "YTT", # flake8-2020
+]
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
+"**" = ["B905", "C401", "C408", "C413", "PYI041", "RUF012", "RUF100", "RUF010", "SIM108", "SIM910", "SIM110", "SIM102", "SIM114", "SIM103", "UP015", "UP027", "UP031", "UP034", "UP032", "UP006", "UP035", "UP007", "UP038"]
+"**/plot*.py" = ["B018"]
"_doc/examples/plot_first_example.py" = ["E402", "F811"]
"_doc/examples/plot_onnxruntime.py" = ["E402", "F811"]
"onnx_array_api/array_api/_onnx_common.py" = ["F821"]
+"onnx_array_api/graph_api/__init__.py" = ["F401"]
"onnx_array_api/light_api/__init__.py" = ["F401"]
"onnx_array_api/light_api/_op_var.py" = ["F821"]
"onnx_array_api/light_api/_op_vars.py" = ["F821"]
-"onnx_array_api/light_api/annotations.py" = ["F821"]
+"onnx_array_api/annotations.py" = ["F821"]
"onnx_array_api/light_api/model.py" = ["F821"]
+"onnx_array_api/translate_api/__init__.py" = ["F401"]
"onnx_array_api/npx/__init__.py" = ["F401", "F403"]
"onnx_array_api/npx/npx_functions.py" = ["F821"]
"onnx_array_api/npx/npx_functions_test.py" = ["F821"]
@@ -32,4 +59,5 @@ max-complexity = 10
"onnx_array_api/profiling.py" = ["E731"]
"onnx_array_api/reference/__init__.py" = ["F401"]
"_unittests/ut_npx/test_npx.py" = ["F821"]
+"_unittests/ut_translate_api/test_translate_classic.py" = ["E501"]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 5804529..de339f5 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,3 +1,5 @@
+array_api_compat
+array_api_strict
autopep8
black
coverage
@@ -11,7 +13,7 @@ lightgbm
matplotlib
ml-dtypes
git+https://github.com/onnx/onnxmltools.git
-onnxruntime>=1.16.1
+onnxruntime>=1.17.0
openpyxl
packaging
pandas
diff --git a/requirements.txt b/requirements.txt
index 4680cfc..4396e32 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,3 @@
-array_api_compat
numpy
onnx>=1.15.0
scipy
diff --git a/setup.py b/setup.py
index bc4e87e..b4cced8 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
import os
from setuptools import setup
@@ -63,9 +62,10 @@
"Operating System :: Unix",
"Operating System :: MacOS",
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
],
)