diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 554f796..4c61d99 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.2.0
+++++
+* :pr:`42`: first sketch for a very simple API to create an onnx graph in one or two lines
* :pr:`27`: add function from_array_extended to convert
an array to a TensorProto, including bfloat16 and float 8 types
* :pr:`24`: add ExtendedReferenceEvaluator to support scenario
diff --git a/README.rst b/README.rst
index cc3efd6..4525fe9 100644
--- a/README.rst
+++ b/README.rst
@@ -114,5 +114,4 @@ It supports eager mode as well:
The library is released on
`pypi/onnx-array-api `_
and its documentation is published at
-`(Numpy) Array API for ONNX
-`_.
+`(Numpy) Array API for ONNX `_.
diff --git a/_doc/api/index.rst b/_doc/api/index.rst
index d52b616..181a459 100644
--- a/_doc/api/index.rst
+++ b/_doc/api/index.rst
@@ -7,6 +7,7 @@ API
:maxdepth: 1
array_api
+ light_api
npx_core_api
npx_functions
npx_jit_eager
diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst
new file mode 100644
index 0000000..9c46e3a
--- /dev/null
+++ b/_doc/api/light_api.rst
@@ -0,0 +1,32 @@
+========================
+onnx_array_api.light_api
+========================
+
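+The module implements a light API to build an onnx graph in one
+or two instructions. A minimal sketch, adapted from the unit tests:
+
+.. code-block:: python
+
+    from onnx_array_api.light_api import start
+
+    onx = start().vin("X").Neg().rename("Y").vout().to_onnx()
+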
+start
+=====
+
+.. autofunction:: onnx_array_api.light_api.start
+
+OnnxGraph
+=========
+
+.. autoclass:: onnx_array_api.light_api.OnnxGraph
+ :members:
+
+BaseVar
+=======
+
+.. autoclass:: onnx_array_api.light_api.var.BaseVar
+ :members:
+
+Var
+===
+
+.. autoclass:: onnx_array_api.light_api.Var
+ :members:
+
+Vars
+====
+
+.. autoclass:: onnx_array_api.light_api.Vars
+ :members:
diff --git a/_unittests/ut_light_api/test_light_api.py b/_unittests/ut_light_api/test_light_api.py
new file mode 100644
index 0000000..3feaa2a
--- /dev/null
+++ b/_unittests/ut_light_api/test_light_api.py
@@ -0,0 +1,407 @@
+import unittest
+from typing import Optional
+import numpy as np
+from onnx import ModelProto
+from onnx.defs import (
+ get_all_schemas_with_history,
+ onnx_opset_version,
+ OpSchema,
+ get_schema,
+ SchemaError,
+)
+from onnx.reference import ReferenceEvaluator
+from onnx_array_api.ext_test_case import ExtTestCase
+from onnx_array_api.light_api import start, OnnxGraph, Var
+from onnx_array_api.light_api._op_var import OpsVar
+from onnx_array_api.light_api._op_vars import OpsVars
+
+OPSET_API = min(19, onnx_opset_version() - 1)
+
+
+def make_method(schema: OpSchema) -> Optional[str]:
+ if schema.min_output != schema.max_output:
+ return None
+
+ kwargs = []
+ names = []
+ defaults_none = []
+ for v in schema.attributes.values():
+ names.append(v.name)
+ if v.default_value is None:
+ kwargs.append(f"{v.name}=None")
+ elif v.type.value == OpSchema.AttrType.FLOAT:
+ kwargs.append(f"{v.name}: float={v.default_value.f}")
+ elif v.type.value == OpSchema.AttrType.INT:
+ kwargs.append(f"{v.name}: int={v.default_value.i}")
+ elif v.type.value == OpSchema.AttrType.INTS:
+ kwargs.append(f"{v.name}: Optional[List[int]]=None")
+ defaults_none.append(
+ f" {v.name} = {v.name} or {v.default_value.ints}"
+ )
+ elif v.type.value == OpSchema.AttrType.STRING:
+ kwargs.append(f"{v.name}: str={v.default_value.s!r}")
+ else:
+ raise AssertionError(
+ f"Operator {schema.domain}:{schema.name} has attribute "
+ f"{v.name!r} with type {v.type}."
+ )
+
+ if max(schema.min_output, schema.max_output) > 1:
+ ann = "Vars"
+ else:
+ ann = "Var"
+ code = [f' def {schema.name}(self, {", ".join(kwargs)})->"{ann}":']
+ if defaults_none:
+ code.extend(defaults_none)
+
+ n_inputs = schema.max_input
+ eol = ", ".join(f"{n}={n}" for n in names)
+ if schema.domain == "":
+ if n_inputs == 1:
+ code.append(f' return self.make_node("{schema.name}", self, {eol})')
+ else:
+ code.append(
+ f' return self.make_node("{schema.name}", *self.vars_, {eol})'
+ )
+ else:
+ raise AssertionError(
+ f"Not implemented yet for operator {schema.domain}:{schema.name}."
+ )
+
+ return "\n".join(code)
+
+
+class TestLightApi(ExtTestCase):
+ def list_ops_missing(self, n_inputs):
+ schemas = {}
+ for schema in get_all_schemas_with_history():
+ if (
+ schema.domain != ""
+ or "Sequence" in schema.name
+ or "Optional" in schema.name
+ ):
+ continue
+ key = schema.domain, schema.name
+ if key not in schemas or schemas[key].since_version < schema.since_version:
+ schemas[key] = schema
+ expected = set(_[1] for _ in list(sorted(schemas)))
+ missing = []
+ for ex in expected:
+ if (
+ not hasattr(Var, ex)
+ and not hasattr(OpsVar, ex)
+ and not hasattr(OpsVars, ex)
+ ):
+ missing.append(ex)
+ if missing:
+ methods = []
+ new_missing = []
+ for m in sorted(missing):
+ try:
+ schema = get_schema(m, OPSET_API)
+ except SchemaError:
+ continue
+ if m in {
+ "Constant",
+ "ConstantOfShape",
+ "If",
+ "Max",
+ "MaxPool",
+ "Mean",
+ "Min",
+ "StringNormalizer",
+ "Sum",
+ "TfIdfVectorizer",
+ "Unique",
+ # 2
+ "BatchNormalization",
+ "Dropout",
+ "GRU",
+ "LSTM",
+ "LayerNormalization",
+ "Loop",
+ "RNN",
+ "Scan",
+ "SoftmaxCrossEntropyLoss",
+ "Split",
+ }:
+ continue
+ if schema.min_input == schema.max_input == 1:
+ if n_inputs != 1:
+ continue
+ else:
+ if n_inputs == 1:
+ continue
+ code = make_method(schema)
+ if code is not None:
+ methods.append(code)
+ methods.append("")
+ new_missing.append(m)
+ text = "\n".join(methods)
+ if len(new_missing) > 0:
+ raise AssertionError(
+ f"n_inputs={n_inputs}: missing method for operators "
+ f"{new_missing}\n{text}"
+ )
+
+ def test_list_ops_missing(self):
+ self.list_ops_missing(1)
+ self.list_ops_missing(2)
+
+ def test_list_ops_uni(self):
+ schemas = {}
+ for schema in get_all_schemas_with_history():
+ if (
+ schema.domain != ""
+ or "Sequence" in schema.name
+ or "Optional" in schema.name
+ ):
+ continue
+ if (
+ schema.min_input
+ == schema.max_input
+ == 1
+ == schema.max_output
+ == schema.min_output
+ and len(schema.attributes) == 0
+ ):
+ key = schema.domain, schema.name
+ if (
+ key not in schemas
+ or schemas[key].since_version < schema.since_version
+ ):
+ schemas[key] = schema
+ expected = set(_[1] for _ in list(sorted(schemas)))
+ for ex in expected:
+ self.assertHasAttr(OpsVar, ex)
+
+ def test_list_ops_bi(self):
+ schemas = {}
+ for schema in get_all_schemas_with_history():
+ if (
+ schema.domain != ""
+ or "Sequence" in schema.name
+ or "Optional" in schema.name
+ ):
+ continue
+ if (
+ (schema.min_input == schema.max_input == 2)
+ and (1 == schema.max_output == schema.min_output)
+ and len(schema.attributes) == 0
+ ):
+ key = schema.domain, schema.name
+ if (
+ key not in schemas
+ or schemas[key].since_version < schema.since_version
+ ):
+ schemas[key] = schema
+ expected = set(_[1] for _ in list(sorted(schemas)))
+ for ex in expected:
+ self.assertHasAttr(OpsVars, ex)
+
+ def test_neg(self):
+ onx = start()
+ self.assertIsInstance(onx, OnnxGraph)
+ r = repr(onx)
+ self.assertEqual("OnnxGraph()", r)
+ v = start().vin("X")
+ self.assertIsInstance(v, Var)
+ self.assertEqual(["X"], v.parent.input_names)
+ s = str(v)
+ self.assertEqual("X:FLOAT", s)
+ onx = start().vin("X").Neg().rename("Y").vout().to_onnx()
+ self.assertIsInstance(onx, ModelProto)
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a})[0]
+ self.assertEqualArray(-a, got)
+
+ def test_exp(self):
+ onx = start().vin("X").Exp().rename("Y").vout().to_onnx()
+ self.assertIsInstance(onx, ModelProto)
+ self.assertIn("Exp", str(onx))
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a})[0]
+ self.assertEqualArray(np.exp(a), got)
+
+ def test_transpose(self):
+ onx = (
+ start()
+ .vin("X")
+ .reshape((-1, 1))
+ .Transpose(perm=[1, 0])
+ .rename("Y")
+ .vout()
+ .to_onnx()
+ )
+ self.assertIsInstance(onx, ModelProto)
+ self.assertIn("Transpose", str(onx))
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a})[0]
+ self.assertEqualArray(a.reshape((-1, 1)).T, got)
+
+ def test_add(self):
+ onx = start()
+ onx = (
+ start().vin("X").vin("Y").bring("X", "Y").Add().rename("Z").vout().to_onnx()
+ )
+ self.assertIsInstance(onx, ModelProto)
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a, "Y": a + 1})[0]
+ self.assertEqualArray(a * 2 + 1, got)
+
+ def test_mul(self):
+ onx = start()
+ onx = (
+ start().vin("X").vin("Y").bring("X", "Y").Mul().rename("Z").vout().to_onnx()
+ )
+ self.assertIsInstance(onx, ModelProto)
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a, "Y": a + 1})[0]
+ self.assertEqualArray(a * (a + 1), got)
+
+ def test_add_constant(self):
+ onx = start()
+ onx = (
+ start()
+ .vin("X")
+ .cst(np.array([1], dtype=np.float32), "one")
+ .bring("X", "one")
+ .Add()
+ .rename("Z")
+ .vout()
+ .to_onnx()
+ )
+ self.assertIsInstance(onx, ModelProto)
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a, "Y": a + 1})[0]
+ self.assertEqualArray(a + 1, got)
+
+ def test_left_bring(self):
+ onx = start()
+ onx = (
+ start()
+ .vin("X")
+ .cst(np.array([1], dtype=np.float32), "one")
+ .left_bring("X")
+ .Add()
+ .rename("Z")
+ .vout()
+ .to_onnx()
+ )
+ self.assertIsInstance(onx, ModelProto)
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a, "Y": a + 1})[0]
+ self.assertEqualArray(a + 1, got)
+
+ def test_right_bring(self):
+ onx = (
+ start()
+ .vin("S")
+ .vin("X")
+ .right_bring("S")
+ .Reshape()
+ .rename("Z")
+ .vout()
+ .to_onnx()
+ )
+ self.assertIsInstance(onx, ModelProto)
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a, "S": np.array([-1], dtype=np.int64)})[0]
+ self.assertEqualArray(a.ravel(), got)
+
+ def test_reshape_1(self):
+ onx = (
+ start()
+ .vin("X")
+ .vin("S")
+ .bring("X", "S")
+ .Reshape()
+ .rename("Z")
+ .vout()
+ .to_onnx()
+ )
+ self.assertIsInstance(onx, ModelProto)
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a, "S": np.array([-1], dtype=np.int64)})[0]
+ self.assertEqualArray(a.ravel(), got)
+
+ def test_reshape_2(self):
+ x = start().vin("X").vin("S").v("X")
+ self.assertIsInstance(x, Var)
+ self.assertEqual(x.name, "X")
+ g = start()
+ g.vin("X").vin("S").v("X").reshape("S").rename("Z").vout()
+ self.assertEqual(["Z"], g.output_names)
+ onx = start().vin("X").vin("S").v("X").reshape("S").rename("Z").vout().to_onnx()
+ self.assertIsInstance(onx, ModelProto)
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a, "S": np.array([-1], dtype=np.int64)})[0]
+ self.assertEqualArray(a.ravel(), got)
+
+ def test_operator_float(self):
+ for f in [
+ lambda x, y: x + y,
+ lambda x, y: x - y,
+ lambda x, y: x * y,
+ lambda x, y: x / y,
+ lambda x, y: x == y,
+ lambda x, y: x < y,
+ lambda x, y: x <= y,
+ lambda x, y: x > y,
+ lambda x, y: x >= y,
+ lambda x, y: x != y,
+ lambda x, y: x @ y,
+ ]:
+ g = start()
+ x = g.vin("X")
+ y = g.vin("Y")
+ onx = f(x, y).rename("Z").vout().to_onnx()
+ self.assertIsInstance(onx, ModelProto)
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.float32)
+ got = ref.run(None, {"X": a, "Y": a + 1})[0]
+ self.assertEqualArray(f(a, a + 1), got)
+
+ def test_operator_int(self):
+ for f in [
+ lambda x, y: x % y,
+ lambda x, y: x**y,
+ ]:
+ g = start()
+ x = g.vin("X", np.int64)
+ y = g.vin("Y", np.int64)
+ onx = f(x, y).rename("Z").vout(np.int64).to_onnx()
+ self.assertIsInstance(onx, ModelProto)
+ ref = ReferenceEvaluator(onx)
+ a = np.arange(10).astype(np.int64)
+ got = ref.run(None, {"X": a, "Y": a + 1})[0]
+ self.assertEqualArray(f(a, a + 1), got)
+
+ def test_operator_bool(self):
+ for f in [
+ lambda x, y: x != y,
+ ]:
+ g = start()
+ x = g.vin("X", np.bool_)
+ y = g.vin("Y", np.bool_)
+ onx = f(x, y).rename("Z").vout(np.bool_).to_onnx()
+ self.assertIsInstance(onx, ModelProto)
+ ref = ReferenceEvaluator(onx)
+ a = (np.arange(10).astype(np.int64) % 2).astype(np.bool_)
+ b = (np.arange(10).astype(np.int64) % 3).astype(np.bool_)
+ got = ref.run(None, {"X": a, "Y": b})[0]
+ self.assertEqualArray(f(a, b), got)
+
+
+if __name__ == "__main__":
+ unittest.main(verbosity=2)
diff --git a/onnx_array_api/array_api/__init__.py b/onnx_array_api/array_api/__init__.py
index f23f18c..f4b3c4d 100644
--- a/onnx_array_api/array_api/__init__.py
+++ b/onnx_array_api/array_api/__init__.py
@@ -107,7 +107,8 @@ def wrap(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
else:
b = a
new_args.append(b)
- return f(TEagerTensor, *new_args, **kwargs)
+ res = f(TEagerTensor, *new_args, **kwargs)
+ return res
wrap.__doc__ = f.__doc__
return wrap
diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py
index 898fc64..6e8ee6d 100644
--- a/onnx_array_api/array_api/_onnx_common.py
+++ b/onnx_array_api/array_api/_onnx_common.py
@@ -96,10 +96,14 @@ def asarray(
if all(map(lambda x: isinstance(x, bool), a)):
v = TEagerTensor(np.array(a, dtype=np.bool_))
elif all(map(lambda x: isinstance(x, int), a)):
- if all(map(lambda x: x >= 0, a)):
- v = TEagerTensor(np.array(a, dtype=np.uint64))
- else:
- v = TEagerTensor(np.array(a, dtype=np.int64))
+ try:
+ cvt = np.array(a, dtype=np.int64)
+ except OverflowError as e:
+ if all(map(lambda x: x >= 0, a)):
+ cvt = np.array(a, dtype=np.uint64)
+ else:
+ raise e
+ v = TEagerTensor(cvt)
else:
v = TEagerTensor(np.array(a))
elif isinstance(a, np.ndarray):
diff --git a/onnx_array_api/ext_test_case.py b/onnx_array_api/ext_test_case.py
index ab72c57..6726008 100644
--- a/onnx_array_api/ext_test_case.py
+++ b/onnx_array_api/ext_test_case.py
@@ -214,6 +214,10 @@ def assertEmpty(self, value: Any):
return
raise AssertionError(f"value is not empty: {value!r}.")
+ def assertHasAttr(self, cls: type, name: str):
+ if not hasattr(cls, name):
+ raise AssertionError(f"Class {cls} has no attribute {name!r}.")
+
def assertNotEmpty(self, value: Any):
if value is None:
raise AssertionError(f"value is empty: {value!r}.")
diff --git a/onnx_array_api/light_api/__init__.py b/onnx_array_api/light_api/__init__.py
new file mode 100644
index 0000000..272ea0d
--- /dev/null
+++ b/onnx_array_api/light_api/__init__.py
@@ -0,0 +1,41 @@
+from typing import Dict, Optional
+from .model import OnnxGraph
+from .var import Var, Vars
+
+
+def start(
+ opset: Optional[int] = None,
+ opsets: Optional[Dict[str, int]] = None,
+ is_function: bool = False,
+) -> OnnxGraph:
+ """
+ Starts an onnx model.
+
+ :param opset: main opset version
+    :param opsets: other opsets as a dictionary
+    :param is_function: if True, the graph is built as a
+        :class:`onnx.FunctionProto` instead of a :class:`onnx.ModelProto`
+ :return: an instance of :class:`onnx_array_api.light_api.OnnxGraph`
+
+ A very simple model:
+
+ .. runpython::
+ :showcode:
+
+ from onnx_array_api.light_api import start
+
+ onx = start().vin("X").Neg().rename("Y").vout().to_onnx()
+ print(onx)
+
+    Another example with operator Add:
+
+ .. runpython::
+ :showcode:
+
+ from onnx_array_api.light_api import start
+
+ onx = (
+ start().vin("X").vin("Y").bring("X", "Y").Add().rename("Z").vout().to_onnx()
+ )
+ print(onx)
+ """
+ return OnnxGraph(opset=opset, opsets=opsets, is_function=is_function)
diff --git a/onnx_array_api/light_api/_op_var.py b/onnx_array_api/light_api/_op_var.py
new file mode 100644
index 0000000..6b511c5
--- /dev/null
+++ b/onnx_array_api/light_api/_op_var.py
@@ -0,0 +1,259 @@
+from typing import List, Optional
+
+
+class OpsVar:
+ """
+ Operators taking only one input.
+ """
+
+ def ArgMax(
+ self, axis: int = 0, keepdims: int = 1, select_last_index: int = 0
+ ) -> "Var":
+ return self.make_node(
+ "ArgMax",
+ self,
+ axis=axis,
+ keepdims=keepdims,
+ select_last_index=select_last_index,
+ )
+
+ def ArgMin(
+ self, axis: int = 0, keepdims: int = 1, select_last_index: int = 0
+ ) -> "Var":
+ return self.make_node(
+ "ArgMin",
+ self,
+ axis=axis,
+ keepdims=keepdims,
+ select_last_index=select_last_index,
+ )
+
+ def AveragePool(
+ self,
+ auto_pad: str = b"NOTSET",
+ ceil_mode: int = 0,
+ count_include_pad: int = 0,
+ dilations: Optional[List[int]] = None,
+ kernel_shape: Optional[List[int]] = None,
+ pads: Optional[List[int]] = None,
+ strides: Optional[List[int]] = None,
+ ) -> "Var":
+ dilations = dilations or []
+ kernel_shape = kernel_shape or []
+ pads = pads or []
+ strides = strides or []
+ return self.make_node(
+ "AveragePool",
+ self,
+ auto_pad=auto_pad,
+ ceil_mode=ceil_mode,
+ count_include_pad=count_include_pad,
+ dilations=dilations,
+ kernel_shape=kernel_shape,
+ pads=pads,
+ strides=strides,
+ )
+
+ def Bernoulli(self, dtype: int = 0, seed: float = 0.0) -> "Var":
+ return self.make_node("Bernoulli", self, dtype=dtype, seed=seed)
+
+ def BlackmanWindow(self, output_datatype: int = 1, periodic: int = 1) -> "Var":
+ return self.make_node(
+ "BlackmanWindow", self, output_datatype=output_datatype, periodic=periodic
+ )
+
+ def Cast(self, saturate: int = 1, to: int = 0) -> "Var":
+ return self.make_node("Cast", self, saturate=saturate, to=to)
+
+ def Celu(self, alpha: float = 1.0) -> "Var":
+ return self.make_node("Celu", self, alpha=alpha)
+
+ def DepthToSpace(self, blocksize: int = 0, mode: str = b"DCR") -> "Var":
+ return self.make_node("DepthToSpace", self, blocksize=blocksize, mode=mode)
+
+ def DynamicQuantizeLinear(
+ self,
+ ) -> "Vars":
+ return self.make_node(
+ "DynamicQuantizeLinear",
+ self,
+ )
+
+ def Elu(self, alpha: float = 1.0) -> "Var":
+ return self.make_node("Elu", self, alpha=alpha)
+
+ def EyeLike(self, dtype: int = 0, k: int = 0) -> "Var":
+ return self.make_node("EyeLike", self, dtype=dtype, k=k)
+
+ def Flatten(self, axis: int = 1) -> "Var":
+ return self.make_node("Flatten", self, axis=axis)
+
+ def GlobalLpPool(self, p: int = 2) -> "Var":
+ return self.make_node("GlobalLpPool", self, p=p)
+
+ def HammingWindow(self, output_datatype: int = 1, periodic: int = 1) -> "Var":
+ return self.make_node(
+ "HammingWindow", self, output_datatype=output_datatype, periodic=periodic
+ )
+
+ def HannWindow(self, output_datatype: int = 1, periodic: int = 1) -> "Var":
+ return self.make_node(
+ "HannWindow", self, output_datatype=output_datatype, periodic=periodic
+ )
+
+ def HardSigmoid(
+ self, alpha: float = 0.20000000298023224, beta: float = 0.5
+ ) -> "Var":
+ return self.make_node("HardSigmoid", self, alpha=alpha, beta=beta)
+
+ def Hardmax(self, axis: int = -1) -> "Var":
+ return self.make_node("Hardmax", self, axis=axis)
+
+ def IsInf(self, detect_negative: int = 1, detect_positive: int = 1) -> "Var":
+ return self.make_node(
+ "IsInf",
+ self,
+ detect_negative=detect_negative,
+ detect_positive=detect_positive,
+ )
+
+ def LRN(
+ self,
+ alpha: float = 9.999999747378752e-05,
+ beta: float = 0.75,
+ bias: float = 1.0,
+ size: int = 0,
+ ) -> "Var":
+ return self.make_node("LRN", self, alpha=alpha, beta=beta, bias=bias, size=size)
+
+ def LeakyRelu(self, alpha: float = 0.009999999776482582) -> "Var":
+ return self.make_node("LeakyRelu", self, alpha=alpha)
+
+ def LogSoftmax(self, axis: int = -1) -> "Var":
+ return self.make_node("LogSoftmax", self, axis=axis)
+
+ def LpNormalization(self, axis: int = -1, p: int = 2) -> "Var":
+ return self.make_node("LpNormalization", self, axis=axis, p=p)
+
+ def LpPool(
+ self,
+ auto_pad: str = b"NOTSET",
+ ceil_mode: int = 0,
+ dilations: Optional[List[int]] = None,
+ kernel_shape: Optional[List[int]] = None,
+ p: int = 2,
+ pads: Optional[List[int]] = None,
+ strides: Optional[List[int]] = None,
+ ) -> "Var":
+ dilations = dilations or []
+ kernel_shape = kernel_shape or []
+ pads = pads or []
+ strides = strides or []
+ return self.make_node(
+ "LpPool",
+ self,
+ auto_pad=auto_pad,
+ ceil_mode=ceil_mode,
+ dilations=dilations,
+ kernel_shape=kernel_shape,
+ p=p,
+ pads=pads,
+ strides=strides,
+ )
+
+ def MeanVarianceNormalization(self, axes: Optional[List[int]] = None) -> "Var":
+ axes = axes or [0, 2, 3]
+ return self.make_node("MeanVarianceNormalization", self, axes=axes)
+
+ def Multinomial(
+ self, dtype: int = 6, sample_size: int = 1, seed: float = 0.0
+ ) -> "Var":
+ return self.make_node(
+ "Multinomial", self, dtype=dtype, sample_size=sample_size, seed=seed
+ )
+
+ def RandomNormalLike(
+ self, dtype: int = 0, mean: float = 0.0, scale: float = 1.0, seed: float = 0.0
+ ) -> "Var":
+ return self.make_node(
+ "RandomNormalLike", self, dtype=dtype, mean=mean, scale=scale, seed=seed
+ )
+
+ def RandomUniformLike(
+ self, dtype: int = 0, high: float = 1.0, low: float = 0.0, seed: float = 0.0
+ ) -> "Var":
+ return self.make_node(
+ "RandomUniformLike", self, dtype=dtype, high=high, low=low, seed=seed
+ )
+
+ def Selu(
+ self, alpha: float = 1.6732631921768188, gamma: float = 1.0507010221481323
+ ) -> "Var":
+ return self.make_node("Selu", self, alpha=alpha, gamma=gamma)
+
+ def Shrink(self, bias: float = 0.0, lambd: float = 0.5) -> "Var":
+ return self.make_node("Shrink", self, bias=bias, lambd=lambd)
+
+ def Softmax(self, axis: int = -1) -> "Var":
+ return self.make_node("Softmax", self, axis=axis)
+
+ def SpaceToDepth(self, blocksize: int = 0) -> "Var":
+ return self.make_node("SpaceToDepth", self, blocksize=blocksize)
+
+ def ThresholdedRelu(self, alpha: float = 1.0) -> "Var":
+ return self.make_node("ThresholdedRelu", self, alpha=alpha)
+
+ def Transpose(self, perm: Optional[List[int]] = None) -> "Var":
+ perm = perm or []
+ return self.make_node("Transpose", self, perm=perm)
+
+
+def _complete():
+ ops_to_add = [
+ "Abs",
+ "Acos",
+ "Acosh",
+ "Asin",
+ "Asinh",
+ "Atan",
+ "Atanh",
+ "BitwiseNot",
+ "Ceil",
+ "Cos",
+ "Cosh",
+ "Det",
+ "Erf",
+ "Exp",
+ "Floor",
+ "GlobalAveragePool",
+ "GlobalMaxPool",
+ "HardSwish",
+ "Identity",
+ "IsNaN",
+ "Log",
+ "Mish",
+ "Neg",
+ "NonZero",
+ "Not",
+ "Reciprocal",
+ "Relu",
+ "Round",
+ "Shape",
+ "Sigmoid",
+ "Sign",
+ "Sin",
+ "Sinh",
+ "Size",
+ "Softplus",
+ "Softsign",
+ "Sqrt",
+ "Tan",
+ "Tanh",
+ ]
+ for name in ops_to_add:
+ if hasattr(OpsVar, name):
+ continue
+ setattr(OpsVar, name, lambda self, op_type=name: self.make_node(op_type, self))
+
+
+_complete()
diff --git a/onnx_array_api/light_api/_op_vars.py b/onnx_array_api/light_api/_op_vars.py
new file mode 100644
index 0000000..77dbac6
--- /dev/null
+++ b/onnx_array_api/light_api/_op_vars.py
@@ -0,0 +1,573 @@
+from typing import List, Optional
+
+
+class OpsVars:
+ """
+ Operators taking multiple inputs.
+ """
+
+ def BitShift(self, direction: str = b"") -> "Var":
+ return self.make_node("BitShift", *self.vars_, direction=direction)
+
+ def CenterCropPad(self, axes: Optional[List[int]] = None) -> "Var":
+ axes = axes or []
+ return self.make_node("CenterCropPad", *self.vars_, axes=axes)
+
+ def Clip(
+ self,
+ ) -> "Var":
+ return self.make_node(
+ "Clip",
+ *self.vars_,
+ )
+
+ def Col2Im(
+ self,
+ dilations: Optional[List[int]] = None,
+ pads: Optional[List[int]] = None,
+ strides: Optional[List[int]] = None,
+ ) -> "Var":
+ dilations = dilations or []
+ pads = pads or []
+ strides = strides or []
+ return self.make_node(
+ "Col2Im", *self.vars_, dilations=dilations, pads=pads, strides=strides
+ )
+
+ def Compress(self, axis: int = 0) -> "Var":
+ return self.make_node("Compress", *self.vars_, axis=axis)
+
+ def Concat(self, axis: int = 0) -> "Var":
+ return self.make_node("Concat", *self.vars_, axis=axis)
+
+ def Conv(
+ self,
+ auto_pad: str = b"NOTSET",
+ dilations: Optional[List[int]] = None,
+ group: int = 1,
+ kernel_shape: Optional[List[int]] = None,
+ pads: Optional[List[int]] = None,
+ strides: Optional[List[int]] = None,
+ ) -> "Var":
+ dilations = dilations or []
+ kernel_shape = kernel_shape or []
+ pads = pads or []
+ strides = strides or []
+ return self.make_node(
+ "Conv",
+ *self.vars_,
+ auto_pad=auto_pad,
+ dilations=dilations,
+ group=group,
+ kernel_shape=kernel_shape,
+ pads=pads,
+ strides=strides,
+ )
+
+ def ConvInteger(
+ self,
+ auto_pad: str = b"NOTSET",
+ dilations: Optional[List[int]] = None,
+ group: int = 1,
+ kernel_shape: Optional[List[int]] = None,
+ pads: Optional[List[int]] = None,
+ strides: Optional[List[int]] = None,
+ ) -> "Var":
+ dilations = dilations or []
+ kernel_shape = kernel_shape or []
+ pads = pads or []
+ strides = strides or []
+ return self.make_node(
+ "ConvInteger",
+ *self.vars_,
+ auto_pad=auto_pad,
+ dilations=dilations,
+ group=group,
+ kernel_shape=kernel_shape,
+ pads=pads,
+ strides=strides,
+ )
+
+ def ConvTranspose(
+ self,
+ auto_pad: str = b"NOTSET",
+ dilations: Optional[List[int]] = None,
+ group: int = 1,
+ kernel_shape: Optional[List[int]] = None,
+ output_padding: Optional[List[int]] = None,
+ output_shape: Optional[List[int]] = None,
+ pads: Optional[List[int]] = None,
+ strides: Optional[List[int]] = None,
+ ) -> "Var":
+ dilations = dilations or []
+ kernel_shape = kernel_shape or []
+ output_padding = output_padding or []
+ output_shape = output_shape or []
+ pads = pads or []
+ strides = strides or []
+ return self.make_node(
+ "ConvTranspose",
+ *self.vars_,
+ auto_pad=auto_pad,
+ dilations=dilations,
+ group=group,
+ kernel_shape=kernel_shape,
+ output_padding=output_padding,
+ output_shape=output_shape,
+ pads=pads,
+ strides=strides,
+ )
+
+ def CumSum(self, exclusive: int = 0, reverse: int = 0) -> "Var":
+ return self.make_node(
+ "CumSum", *self.vars_, exclusive=exclusive, reverse=reverse
+ )
+
+ def DFT(self, axis: int = 1, inverse: int = 0, onesided: int = 0) -> "Var":
+ return self.make_node(
+ "DFT", *self.vars_, axis=axis, inverse=inverse, onesided=onesided
+ )
+
+ def DeformConv(
+ self,
+ dilations: Optional[List[int]] = None,
+ group: int = 1,
+ kernel_shape: Optional[List[int]] = None,
+ offset_group: int = 1,
+ pads: Optional[List[int]] = None,
+ strides: Optional[List[int]] = None,
+ ) -> "Var":
+ dilations = dilations or []
+ kernel_shape = kernel_shape or []
+ pads = pads or []
+ strides = strides or []
+ return self.make_node(
+ "DeformConv",
+ *self.vars_,
+ dilations=dilations,
+ group=group,
+ kernel_shape=kernel_shape,
+ offset_group=offset_group,
+ pads=pads,
+ strides=strides,
+ )
+
+ def DequantizeLinear(self, axis: int = 1) -> "Var":
+ return self.make_node("DequantizeLinear", *self.vars_, axis=axis)
+
+ def Einsum(self, equation: str = b"") -> "Var":
+ return self.make_node("Einsum", *self.vars_, equation=equation)
+
+ def Gather(self, axis: int = 0) -> "Var":
+ return self.make_node("Gather", *self.vars_, axis=axis)
+
+ def GatherElements(self, axis: int = 0) -> "Var":
+ return self.make_node("GatherElements", *self.vars_, axis=axis)
+
+ def Gemm(
+ self, alpha: float = 1.0, beta: float = 1.0, transA: int = 0, transB: int = 0
+ ) -> "Var":
+ return self.make_node(
+ "Gemm", *self.vars_, alpha=alpha, beta=beta, transA=transA, transB=transB
+ )
+
+ def GridSample(
+ self,
+ align_corners: int = 0,
+ mode: str = b"bilinear",
+ padding_mode: str = b"zeros",
+ ) -> "Var":
+ return self.make_node(
+ "GridSample",
+ *self.vars_,
+ align_corners=align_corners,
+ mode=mode,
+ padding_mode=padding_mode,
+ )
+
+ def GroupNormalization(
+ self, epsilon: float = 9.999999747378752e-06, num_groups: int = 0
+ ) -> "Var":
+ return self.make_node(
+ "GroupNormalization", *self.vars_, epsilon=epsilon, num_groups=num_groups
+ )
+
+ def InstanceNormalization(self, epsilon: float = 9.999999747378752e-06) -> "Var":
+ return self.make_node("InstanceNormalization", *self.vars_, epsilon=epsilon)
+
+ def MatMulInteger(
+ self,
+ ) -> "Var":
+ return self.make_node(
+ "MatMulInteger",
+ *self.vars_,
+ )
+
+ def MaxRoiPool(
+ self, pooled_shape: Optional[List[int]] = None, spatial_scale: float = 1.0
+ ) -> "Var":
+ pooled_shape = pooled_shape or []
+ return self.make_node(
+ "MaxRoiPool",
+ *self.vars_,
+ pooled_shape=pooled_shape,
+ spatial_scale=spatial_scale,
+ )
+
+ def MaxUnpool(
+ self,
+ kernel_shape: Optional[List[int]] = None,
+ pads: Optional[List[int]] = None,
+ strides: Optional[List[int]] = None,
+ ) -> "Var":
+ kernel_shape = kernel_shape or []
+ pads = pads or []
+ strides = strides or []
+ return self.make_node(
+ "MaxUnpool",
+ *self.vars_,
+ kernel_shape=kernel_shape,
+ pads=pads,
+ strides=strides,
+ )
+
+ def MelWeightMatrix(self, output_datatype: int = 1) -> "Var":
+ return self.make_node(
+ "MelWeightMatrix", *self.vars_, output_datatype=output_datatype
+ )
+
+ def Mod(self, fmod: int = 0) -> "Var":
+ return self.make_node("Mod", *self.vars_, fmod=fmod)
+
+ def NegativeLogLikelihoodLoss(
+ self, ignore_index: int = 0, reduction: str = b"mean"
+ ) -> "Var":
+ return self.make_node(
+ "NegativeLogLikelihoodLoss",
+ *self.vars_,
+ ignore_index=ignore_index,
+ reduction=reduction,
+ )
+
+ def NonMaxSuppression(self, center_point_box: int = 0) -> "Var":
+ return self.make_node(
+ "NonMaxSuppression", *self.vars_, center_point_box=center_point_box
+ )
+
+ def OneHot(self, axis: int = -1) -> "Var":
+ return self.make_node("OneHot", *self.vars_, axis=axis)
+
+ def Pad(self, mode: str = b"constant") -> "Var":
+ return self.make_node("Pad", *self.vars_, mode=mode)
+
+ def QLinearConv(
+ self,
+ auto_pad: str = b"NOTSET",
+ dilations: Optional[List[int]] = None,
+ group: int = 1,
+ kernel_shape: Optional[List[int]] = None,
+ pads: Optional[List[int]] = None,
+ strides: Optional[List[int]] = None,
+ ) -> "Var":
+ dilations = dilations or []
+ kernel_shape = kernel_shape or []
+ pads = pads or []
+ strides = strides or []
+ return self.make_node(
+ "QLinearConv",
+ *self.vars_,
+ auto_pad=auto_pad,
+ dilations=dilations,
+ group=group,
+ kernel_shape=kernel_shape,
+ pads=pads,
+ strides=strides,
+ )
+
+ def QLinearMatMul(
+ self,
+ ) -> "Var":
+ return self.make_node(
+ "QLinearMatMul",
+ *self.vars_,
+ )
+
+ def QuantizeLinear(self, axis: int = 1, saturate: int = 1) -> "Var":
+ return self.make_node(
+ "QuantizeLinear", *self.vars_, axis=axis, saturate=saturate
+ )
+
+ def RandomNormal(
+ self,
+ dtype: int = 1,
+ mean: float = 0.0,
+ scale: float = 1.0,
+ seed: float = 0.0,
+ shape: Optional[List[int]] = None,
+ ) -> "Var":
+ shape = shape or []
+ return self.make_node(
+ "RandomNormal",
+ *self.vars_,
+ dtype=dtype,
+ mean=mean,
+ scale=scale,
+ seed=seed,
+ shape=shape,
+ )
+
+ def RandomUniform(
+ self,
+ dtype: int = 1,
+ high: float = 1.0,
+ low: float = 0.0,
+ seed: float = 0.0,
+ shape: Optional[List[int]] = None,
+ ) -> "Var":
+ shape = shape or []
+ return self.make_node(
+ "RandomUniform",
+ *self.vars_,
+ dtype=dtype,
+ high=high,
+ low=low,
+ seed=seed,
+ shape=shape,
+ )
+
+ def Range(
+ self,
+ ) -> "Var":
+ return self.make_node(
+ "Range",
+ *self.vars_,
+ )
+
+ def ReduceL1(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var":
+ return self.make_node(
+ "ReduceL1",
+ *self.vars_,
+ keepdims=keepdims,
+ noop_with_empty_axes=noop_with_empty_axes,
+ )
+
+ def ReduceL2(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var":
+ return self.make_node(
+ "ReduceL2",
+ *self.vars_,
+ keepdims=keepdims,
+ noop_with_empty_axes=noop_with_empty_axes,
+ )
+
+ def ReduceLogSum(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var":
+ return self.make_node(
+ "ReduceLogSum",
+ *self.vars_,
+ keepdims=keepdims,
+ noop_with_empty_axes=noop_with_empty_axes,
+ )
+
+ def ReduceLogSumExp(
+ self, keepdims: int = 1, noop_with_empty_axes: int = 0
+ ) -> "Var":
+ return self.make_node(
+ "ReduceLogSumExp",
+ *self.vars_,
+ keepdims=keepdims,
+ noop_with_empty_axes=noop_with_empty_axes,
+ )
+
+ def ReduceMax(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var":
+ return self.make_node(
+ "ReduceMax",
+ *self.vars_,
+ keepdims=keepdims,
+ noop_with_empty_axes=noop_with_empty_axes,
+ )
+
+ def ReduceMean(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var":
+ return self.make_node(
+ "ReduceMean",
+ *self.vars_,
+ keepdims=keepdims,
+ noop_with_empty_axes=noop_with_empty_axes,
+ )
+
+ def ReduceMin(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var":
+ return self.make_node(
+ "ReduceMin",
+ *self.vars_,
+ keepdims=keepdims,
+ noop_with_empty_axes=noop_with_empty_axes,
+ )
+
+ def ReduceProd(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var":
+ return self.make_node(
+ "ReduceProd",
+ *self.vars_,
+ keepdims=keepdims,
+ noop_with_empty_axes=noop_with_empty_axes,
+ )
+
+ def ReduceSum(self, keepdims: int = 1, noop_with_empty_axes: int = 0) -> "Var":
+ return self.make_node(
+ "ReduceSum",
+ *self.vars_,
+ keepdims=keepdims,
+ noop_with_empty_axes=noop_with_empty_axes,
+ )
+
+ def ReduceSumSquare(
+ self, keepdims: int = 1, noop_with_empty_axes: int = 0
+ ) -> "Var":
+ return self.make_node(
+ "ReduceSumSquare",
+ *self.vars_,
+ keepdims=keepdims,
+ noop_with_empty_axes=noop_with_empty_axes,
+ )
+
+ def Resize(
+ self,
+ antialias: int = 0,
+ axes: Optional[List[int]] = None,
+ coordinate_transformation_mode: str = b"half_pixel",
+ cubic_coeff_a: float = -0.75,
+ exclude_outside: int = 0,
+ extrapolation_value: float = 0.0,
+ keep_aspect_ratio_policy: str = b"stretch",
+ mode: str = b"nearest",
+ nearest_mode: str = b"round_prefer_floor",
+ ) -> "Var":
+ axes = axes or []
+ return self.make_node(
+ "Resize",
+ *self.vars_,
+ antialias=antialias,
+ axes=axes,
+ coordinate_transformation_mode=coordinate_transformation_mode,
+ cubic_coeff_a=cubic_coeff_a,
+ exclude_outside=exclude_outside,
+ extrapolation_value=extrapolation_value,
+ keep_aspect_ratio_policy=keep_aspect_ratio_policy,
+ mode=mode,
+ nearest_mode=nearest_mode,
+ )
+
+ def RoiAlign(
+ self,
+ coordinate_transformation_mode: str = b"half_pixel",
+ mode: str = b"avg",
+ output_height: int = 1,
+ output_width: int = 1,
+ sampling_ratio: int = 0,
+ spatial_scale: float = 1.0,
+ ) -> "Var":
+ return self.make_node(
+ "RoiAlign",
+ *self.vars_,
+ coordinate_transformation_mode=coordinate_transformation_mode,
+ mode=mode,
+ output_height=output_height,
+ output_width=output_width,
+ sampling_ratio=sampling_ratio,
+ spatial_scale=spatial_scale,
+ )
+
+ def STFT(self, onesided: int = 1) -> "Var":
+ return self.make_node("STFT", *self.vars_, onesided=onesided)
+
+ def Scatter(self, axis: int = 0) -> "Var":
+ return self.make_node("Scatter", *self.vars_, axis=axis)
+
+ def ScatterElements(self, axis: int = 0, reduction: str = b"none") -> "Var":
+ return self.make_node(
+ "ScatterElements", *self.vars_, axis=axis, reduction=reduction
+ )
+
+ def ScatterND(self, reduction: str = b"none") -> "Var":
+ return self.make_node("ScatterND", *self.vars_, reduction=reduction)
+
+ def Slice(
+ self,
+ ) -> "Var":
+ return self.make_node(
+ "Slice",
+ *self.vars_,
+ )
+
+ def TopK(self, axis: int = -1, largest: int = 1, sorted: int = 1) -> "Vars":
+ return self.make_node(
+ "TopK", *self.vars_, axis=axis, largest=largest, sorted=sorted
+ )
+
+ def Trilu(self, upper: int = 1) -> "Var":
+ return self.make_node("Trilu", *self.vars_, upper=upper)
+
+ def Upsample(self, mode: str = b"nearest") -> "Var":
+ return self.make_node("Upsample", *self.vars_, mode=mode)
+
+ def Where(
+ self,
+ ) -> "Var":
+ return self.make_node(
+ "Where",
+ *self.vars_,
+ )
+
+
+def _complete():
+ ops_to_add = [
+ "Add",
+ "And",
+ "BitwiseAnd",
+ "BitwiseOr",
+ "BitwiseXor",
+ "CastLike",
+ "Div",
+ "Equal",
+ "Expand",
+ "GatherND",
+ "Greater",
+ "GreaterOrEqual",
+ "Less",
+ "LessOrEqual",
+ "MatMul",
+ "Mul",
+ "Or",
+ "PRelu",
+ "Pow",
+ "Reshape",
+ "StringConcat",
+ "Sub",
+ "Tile",
+ "Unsqueeze",
+ "Xor",
+ ]
+
+ for name in ops_to_add:
+ if hasattr(OpsVars, name):
+ continue
+ setattr(
+ OpsVars,
+ name,
+ lambda self, op_type=name: self._check_nin(2).make_node(
+ op_type, *self.vars_
+ ),
+ )
+
+ ops_to_add = [
+ "Squeeze",
+ ]
+
+ for name in ops_to_add:
+ if hasattr(OpsVars, name):
+ continue
+ setattr(
+ OpsVars,
+ name,
+ lambda self, op_type=name: self.make_node(op_type, *self.vars_),
+ )
+
+
+_complete()
diff --git a/onnx_array_api/light_api/annotations.py b/onnx_array_api/light_api/annotations.py
new file mode 100644
index 0000000..8d473fd
--- /dev/null
+++ b/onnx_array_api/light_api/annotations.py
@@ -0,0 +1,54 @@
+from typing import Tuple, Union
+import numpy as np
+from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, TensorShapeProto
+from onnx.helper import np_dtype_to_tensor_dtype
+
+NP_DTYPE = np.dtype
+ELEMENT_TYPE = Union[int, NP_DTYPE]
+SHAPE_TYPE = Tuple[int, ...]
+VAR_CONSTANT_TYPE = Union["Var", TensorProto, np.ndarray]
+GRAPH_PROTO = Union[FunctionProto, GraphProto, ModelProto]
+
+ELEMENT_TYPE_NAME = {
+ getattr(TensorProto, k): k
+ for k in dir(TensorProto)
+ if isinstance(getattr(TensorProto, k), int)
+}
+
+_type_numpy = {
+ np.float32: TensorProto.FLOAT,
+ np.float64: TensorProto.DOUBLE,
+ np.float16: TensorProto.FLOAT16,
+ np.int8: TensorProto.INT8,
+ np.int16: TensorProto.INT16,
+ np.int32: TensorProto.INT32,
+ np.int64: TensorProto.INT64,
+ np.uint8: TensorProto.UINT8,
+ np.uint16: TensorProto.UINT16,
+ np.uint32: TensorProto.UINT32,
+ np.uint64: TensorProto.UINT64,
+ np.bool_: TensorProto.BOOL,
+ np.str_: TensorProto.STRING,
+}
+
+
+def elem_type_int(elem_type: ELEMENT_TYPE) -> int:
+ """
+ Converts an element type into an onnx element type (int).
+
+ :param elem_type: integer or numpy type
+ :return: int
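+
+    For example::
+
+        elem_type_int(np.float32)  # -> TensorProto.FLOAT
+        elem_type_int(7)           # an integer is returned unchanged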
+ """
+ if isinstance(elem_type, int):
+ return elem_type
+ if elem_type in _type_numpy:
+ return _type_numpy[elem_type]
+ return np_dtype_to_tensor_dtype(elem_type)
+
+
+def make_shape(shape: TensorShapeProto) -> SHAPE_TYPE:
+ "Extracts a shape from a tensor type."
+ if hasattr(shape, "dims"):
+ res = [(d.dim_value if d.dim_value else d.dim_param) for d in shape.dims]
+ return tuple(res)
+ return None
diff --git a/onnx_array_api/light_api/model.py b/onnx_array_api/light_api/model.py
new file mode 100644
index 0000000..def6cc1
--- /dev/null
+++ b/onnx_array_api/light_api/model.py
@@ -0,0 +1,352 @@
+from typing import Any, Dict, List, Optional, Union
+import numpy as np
+from onnx import NodeProto, SparseTensorProto, TensorProto, ValueInfoProto
+from onnx.checker import check_model
+from onnx.defs import onnx_opset_version
+from onnx.helper import (
+ make_graph,
+ make_model,
+ make_node,
+ make_opsetid,
+ make_tensor_value_info,
+ make_tensor_type_proto,
+)
+from onnx.numpy_helper import from_array
+from .annotations import (
+ elem_type_int,
+ make_shape,
+ GRAPH_PROTO,
+ ELEMENT_TYPE,
+ SHAPE_TYPE,
+ VAR_CONSTANT_TYPE,
+)
+
+
+class OnnxGraph:
+ """
+    Contains every piece needed to create an onnx model in a single instruction.
+ This API is meant to be light and allows the description of a graph.
+
+ :param opset: main opset version
+    :param opsets: other opsets as a dictionary
+    :param is_function: if True, the graph is built as a
+        :class:`onnx.FunctionProto` instead of a :class:`onnx.ModelProto`
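+
+    A minimal sketch, adapted from the unit tests
+    (:func:`onnx_array_api.light_api.start` simply returns such an instance):
+
+    .. code-block:: python
+
+        from onnx_array_api.light_api import OnnxGraph
+
+        g = OnnxGraph()
+        g.vin("X").Neg().rename("Y").vout()
+        onx = g.to_onnx()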
+ """
+
+ def __init__(
+ self,
+ opset: Optional[int] = None,
+ opsets: Optional[Dict[str, int]] = None,
+ is_function: bool = False,
+ ):
+ if opsets is not None and "" in opsets:
+ if opset is None:
+ opset = opsets[""]
+ elif opset != opsets[""]:
+ raise ValueError(
+                    "The main opset cannot be specified twice with different values."
+ )
+ if is_function:
+ raise NotImplementedError(
+ "The first version of this API does not support functions."
+ )
+ self.is_function = is_function
+ self.opsets = opsets
+ self.opset = opset
+ self.nodes: List[Union[NodeProto, TensorProto]] = []
+ self.inputs: List[ValueInfoProto] = []
+ self.outputs: List[ValueInfoProto] = []
+ self.initializers: List[TensorProto] = []
+ self.unique_names_: Dict[str, Any] = {}
+ self.renames_: Dict[str, str] = {}
+
+ def __repr__(self) -> str:
+ "usual"
+ sts = [f"{self.__class__.__name__}("]
+ els = [
+ repr(getattr(self, o))
+ for o in ["opset", "opsets"]
+ if getattr(self, o) is not None
+ ]
+ if self.is_function:
+ els.append("is_function=True")
+ sts.append(", ".join(els))
+ sts.append(")")
+ return "".join(sts)
+
+ @property
+ def input_names(self) -> List[str]:
+ "Returns the input names"
+ return [v.name for v in self.inputs]
+
+ @property
+ def output_names(self) -> List[str]:
+ "Returns the output names"
+ return [v.name for v in self.outputs]
+
+ def has_name(self, name: str) -> bool:
+ "Tells if a name is already used."
+ return name in self.unique_names_
+
+ def unique_name(self, prefix="r", value: Optional[Any] = None) -> str:
+ """
+ Returns a unique name.
+
+ :param prefix: prefix
+ :param value: this name is mapped to this value
+ :return: unique name
+ """
+ name = prefix
+ i = len(self.unique_names_)
+ while name in self.unique_names_:
+            name = f"{prefix}{i}"
+ i += 1
+ self.unique_names_[name] = value
+ return name
+
+ def make_input(
+ self,
+ name: str,
+ elem_type: ELEMENT_TYPE = TensorProto.FLOAT,
+ shape: Optional[SHAPE_TYPE] = None,
+ ) -> ValueInfoProto:
+ """
+ Adds an input to the graph.
+
+ :param name: input name
+ :param elem_type: element type (the input is assumed to be a tensor)
+ :param shape: shape
+ :return: an instance of ValueInfoProto
+ """
+ if self.has_name(name):
+ raise ValueError(f"Name {name!r} is already taken.")
+ var = make_tensor_value_info(name, elem_type, shape)
+ self.inputs.append(var)
+ self.unique_names_[name] = var
+ return var
+
+ def vin(
+ self,
+ name: str,
+ elem_type: ELEMENT_TYPE = TensorProto.FLOAT,
+ shape: Optional[SHAPE_TYPE] = None,
+ ) -> "Var":
+ """
+ Declares a new input to the graph.
+
+ :param name: input name
+ :param elem_type: element_type
+ :param shape: shape
+ :return: instance of :class:`onnx_array_api.light_api.Var`
+ """
+ from .var import Var
+
+ proto = self.make_input(name, elem_type=elem_type_int(elem_type), shape=shape)
+ return Var(
+ self,
+ proto.name,
+ elem_type=proto.type.tensor_type.elem_type,
+ shape=make_shape(proto.type.tensor_type.shape),
+ )
+
+ def make_output(
+ self,
+ name: str,
+ elem_type: ELEMENT_TYPE = TensorProto.FLOAT,
+ shape: Optional[SHAPE_TYPE] = None,
+ ) -> ValueInfoProto:
+ """
+ Adds an output to the graph.
+
+        :param name: output name
+        :param elem_type: element type (the output is assumed to be a tensor)
+ :param shape: shape
+ :return: an instance of ValueInfoProto
+ """
+ if not self.has_name(name):
+ raise ValueError(f"Name {name!r} does not exist.")
+ var = make_tensor_value_info(name, elem_type_int(elem_type), shape)
+ self.outputs.append(var)
+ self.unique_names_[name] = var
+ return var
+
+ def make_constant(
+ self, value: np.ndarray, name: Optional[str] = None
+ ) -> TensorProto:
+ "Adds an initializer to the graph."
+ if self.is_function:
+ raise NotImplementedError(
+ "Adding a constant to a FunctionProto is not supported yet."
+ )
+ if isinstance(value, np.ndarray):
+ if name is None:
+ name = self.unique_name()
+ elif self.has_name(name):
+ raise RuntimeError(f"Name {name!r} already exists.")
+ tensor = from_array(value, name=name)
+ self.unique_names_[name] = tensor
+ self.initializers.append(tensor)
+ return tensor
+ raise TypeError(f"Unexpected type {type(value)} for constant {name!r}.")
+
+ def make_node(
+ self,
+ op_type: str,
+ *inputs: List[VAR_CONSTANT_TYPE],
+ domain: str = "",
+ n_outputs: int = 1,
+ output_names: Optional[List[str]] = None,
+ **kwargs: Dict[str, Any],
+ ) -> NodeProto:
+ """
+ Creates a node.
+
+ :param op_type: operator type
+        :param inputs: node inputs
+ :param domain: domain
+ :param n_outputs: number of outputs
+ :param output_names: output names, if not specified, outputs are given
+ unique names
+ :param kwargs: node attributes
+ :return: NodeProto
+ """
+ if output_names is None:
+ output_names = [self.unique_name(value=i) for i in range(n_outputs)]
+ elif n_outputs != len(output_names):
+ raise ValueError(
+ f"Expecting {n_outputs} outputs but received {output_names}."
+ )
+ input_names = []
+ for i in inputs:
+ if hasattr(i, "name"):
+ input_names.append(i.name)
+ elif isinstance(i, np.ndarray):
+                input_names.append(self.make_constant(i).name)
+ else:
+ raise TypeError(f"Unexpected type {type(i)} for one input.")
+
+ node = make_node(op_type, input_names, output_names, domain=domain, **kwargs)
+ self.nodes.append(node)
+ return node
+
+ def true_name(self, name: str) -> str:
+ """
+ Some names were renamed. If name is one of them, the function
+ returns the new name.
+ """
+ while name in self.renames_:
+ name = self.renames_[name]
+ return name
+
+ def get_var(self, name: str) -> "Var":
+ from .var import Var
+
+ tr = self.true_name(name)
+ proto = self.unique_names_[tr]
+ if isinstance(proto, ValueInfoProto):
+ return Var(
+ self,
+ proto.name,
+ elem_type=proto.type.tensor_type.elem_type,
+ shape=make_shape(proto.type.tensor_type.shape),
+ )
+ if isinstance(proto, TensorProto):
+ return Var(
+ self, proto.name, elem_type=proto.data_type, shape=tuple(proto.dims)
+ )
+ raise TypeError(f"Unexpected type {type(proto)} for name {name!r}.")
+
+ def rename(self, old_name: str, new_name: str):
+ """
+        Renames a variable. The renaming is not applied immediately;
+        it is stored and resolved when the graph is converted into onnx.
+
+ :param old_name: old name
+ :param new_name: new name
+ """
+ if not self.has_name(old_name):
+ raise RuntimeError(f"Name {old_name!r} does not exist.")
+ if self.has_name(new_name):
+            raise RuntimeError(f"Name {new_name!r} already exists.")
+ self.unique_names_[new_name] = self.unique_names_[old_name]
+ self.renames_[old_name] = new_name
+
+ def _fix_name_tensor(
+ self, obj: Union[TensorProto, SparseTensorProto, ValueInfoProto]
+ ) -> Union[TensorProto, SparseTensorProto, ValueInfoProto]:
+ true_name = self.true_name(obj.name)
+ if true_name != obj.name:
+ obj.name = true_name
+ return obj
+
+ def _fix_name_tensor_input(
+ self, obj: Union[TensorProto, SparseTensorProto, ValueInfoProto]
+ ) -> Union[TensorProto, SparseTensorProto, ValueInfoProto]:
+ obj = self._fix_name_tensor(obj)
+ shape = make_shape(obj.type.tensor_type.shape)
+ if shape is None:
+ tensor_type_proto = make_tensor_type_proto(
+ obj.type.tensor_type.elem_type, []
+ )
+ obj.type.CopyFrom(tensor_type_proto)
+ return obj
+
+ def _fix_name_tensor_output(
+ self, obj: Union[TensorProto, SparseTensorProto, ValueInfoProto]
+ ) -> Union[TensorProto, SparseTensorProto, ValueInfoProto]:
+ obj = self._fix_name_tensor(obj)
+ shape = make_shape(obj.type.tensor_type.shape)
+ if shape is None:
+ tensor_type_proto = make_tensor_type_proto(
+ obj.type.tensor_type.elem_type, []
+ )
+ obj.type.CopyFrom(tensor_type_proto)
+ return obj
+
+ def _fix_name_node(self, obj: NodeProto) -> NodeProto:
+ new_inputs = [self.true_name(i) for i in obj.input]
+ if new_inputs != obj.input:
+ del obj.input[:]
+ obj.input.extend(new_inputs)
+ new_outputs = [self.true_name(o) for o in obj.output]
+ if new_outputs != obj.output:
+ del obj.output[:]
+ obj.output.extend(new_outputs)
+ return obj
+
+ def _check_input(self, i):
+ "Checks one input is fully specified."
+ if i.type.tensor_type.elem_type <= 0:
+ raise ValueError(f"Input {i.name!r} has no element type.")
+ return i
+
+ def to_onnx(self) -> GRAPH_PROTO:
+ """
+        Converts the graph into an onnx model.
+ """
+ if self.is_function:
+            raise NotImplementedError(
+                "Converting the graph into a FunctionProto is not implemented yet."
+            )
+ dense = [
+ self._fix_name_tensor(i)
+ for i in self.initializers
+ if isinstance(i, TensorProto)
+ ]
+ sparse = [
+ self._fix_name_tensor(i)
+ for i in self.initializers
+ if isinstance(i, SparseTensorProto)
+ ]
+ graph = make_graph(
+ [self._fix_name_node(n) for n in self.nodes],
+ "light_api",
+ [self._check_input(self._fix_name_tensor_input(i)) for i in self.inputs],
+ [self._fix_name_tensor_output(o) for o in self.outputs],
+ dense,
+ sparse,
+ )
+ opsets = [make_opsetid("", self.opset or onnx_opset_version() - 1)]
+ if self.opsets:
+ for k, v in self.opsets.items():
+ opsets.append(make_opsetid(k, v))
+ model = make_model(graph, opset_imports=opsets)
+ check_model(model)
+ return model
diff --git a/onnx_array_api/light_api/var.py b/onnx_array_api/light_api/var.py
new file mode 100644
index 0000000..9fc9b85
--- /dev/null
+++ b/onnx_array_api/light_api/var.py
@@ -0,0 +1,300 @@
+from typing import Any, Dict, List, Optional, Union
+import numpy as np
+from onnx import TensorProto
+from .annotations import (
+ elem_type_int,
+ make_shape,
+ ELEMENT_TYPE,
+ ELEMENT_TYPE_NAME,
+ GRAPH_PROTO,
+ SHAPE_TYPE,
+ VAR_CONSTANT_TYPE,
+)
+from .model import OnnxGraph
+from ._op_var import OpsVar
+from ._op_vars import OpsVars
+
+
+class BaseVar:
+ """
+    Represents an input, an initializer, the result of a node, an output,
+    or multiple variables.
+
+ :param parent: the graph containing the Variable
+ """
+
+ def __init__(
+ self,
+ parent: OnnxGraph,
+ ):
+ self.parent = parent
+
+ def make_node(
+ self,
+ op_type: str,
+ *inputs: List[VAR_CONSTANT_TYPE],
+ domain: str = "",
+ n_outputs: int = 1,
+ output_names: Optional[List[str]] = None,
+ **kwargs: Dict[str, Any],
+ ) -> Union["Var", "Vars"]:
+ """
+        Creates a node from the given inputs and wraps the outputs into variables.
+
+ :param op_type: operator type
+        :param inputs: node inputs
+ :param domain: domain
+ :param n_outputs: number of outputs
+ :param output_names: output names, if not specified, outputs are given
+ unique names
+ :param kwargs: node attributes
+ :return: instance of :class:`onnx_array_api.light_api.Var` or
+ :class:`onnx_array_api.light_api.Vars`
+ """
+ node_proto = self.parent.make_node(
+ op_type,
+ *inputs,
+ domain=domain,
+ n_outputs=n_outputs,
+ output_names=output_names,
+ **kwargs,
+ )
+ names = node_proto.output
+ if len(names) == 1:
+ return Var(self.parent, names[0])
+        return Vars(self.parent, *map(lambda v: Var(self.parent, v), names))
+
+ def vin(
+ self,
+ name: str,
+ elem_type: ELEMENT_TYPE = TensorProto.FLOAT,
+ shape: Optional[SHAPE_TYPE] = None,
+ ) -> "Var":
+ """
+ Declares a new input to the graph.
+
+ :param name: input name
+ :param elem_type: element_type
+ :param shape: shape
+ :return: instance of :class:`onnx_array_api.light_api.Var`
+ """
+ return self.parent.vin(name, elem_type=elem_type, shape=shape)
+
+ def cst(self, value: np.ndarray, name: Optional[str] = None) -> "Var":
+ """
+        Adds an initializer.
+
+ :param value: constant tensor
+        :param name: constant name, a unique name is generated if None
+ :return: instance of :class:`onnx_array_api.light_api.Var`
+ """
+ c = self.parent.make_constant(value, name=name)
+ return Var(self.parent, c.name, elem_type=c.data_type, shape=tuple(c.dims))
+
+ def vout(
+ self,
+ elem_type: ELEMENT_TYPE = TensorProto.FLOAT,
+ shape: Optional[SHAPE_TYPE] = None,
+ ) -> "Var":
+ """
+ Declares a new output to the graph.
+
+ :param elem_type: element_type
+ :param shape: shape
+ :return: instance of :class:`onnx_array_api.light_api.Var`
+ """
+ output = self.parent.make_output(self.name, elem_type=elem_type, shape=shape)
+ return Var(
+ self.parent,
+            output.name,
+ elem_type=output.type.tensor_type.elem_type,
+ shape=make_shape(output.type.tensor_type.shape),
+ )
+
+ def v(self, name: str) -> "Var":
+ """
+        Retrieves another variable of the graph by its name.
+
+ :param name: name of the variable
+ :return: instance of :class:`onnx_array_api.light_api.Var`
+ """
+ return self.parent.get_var(name)
+
+ def bring(self, *vars: List[Union[str, "Var"]]) -> "Vars":
+ """
+        Creates a set of variables as an instance of
+ :class:`onnx_array_api.light_api.Vars`.
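+
+        For example, the unit tests feed a binary operator this way::
+
+            start().vin("X").vin("Y").bring("X", "Y").Add().rename("Z").vout()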
+ """
+ return Vars(self.parent, *vars)
+
+ def left_bring(self, *vars: List[Union[str, "Var"]]) -> "Vars":
+ """
+ Creates a set of variables as an instance of
+ :class:`onnx_array_api.light_api.Vars`.
+ `*vars` is added to the left, `self` is added to the right.
+ """
+ vs = [*vars, self]
+ return Vars(self.parent, *vs)
+
+ def right_bring(self, *vars: List[Union[str, "Var"]]) -> "Vars":
+ """
+ Creates a set of variables as an instance of
+ :class:`onnx_array_api.light_api.Vars`.
+ `*vars` is added to the right, `self` is added to the left.
+ """
+ vs = [self, *vars]
+ return Vars(self.parent, *vs)
+
+ def to_onnx(self) -> GRAPH_PROTO:
+ "Creates the onnx graph."
+ return self.parent.to_onnx()
+
+
+class Var(BaseVar, OpsVar):
+ """
+    Represents an input, an initializer, the result of a node or an output.
+
+ :param parent: graph the variable belongs to
+ :param name: input name
+ :param elem_type: element_type
+ :param shape: shape
+ """
+
+ def __init__(
+ self,
+ parent: OnnxGraph,
+ name: str,
+ elem_type: Optional[ELEMENT_TYPE] = 1,
+ shape: Optional[SHAPE_TYPE] = None,
+ ):
+ BaseVar.__init__(self, parent)
+ self.name_ = name
+ self.elem_type = elem_type
+ self.shape = shape
+
+ @property
+ def name(self):
+ "Returns the name of the variable or the new name if it was renamed."
+ return self.parent.true_name(self.name_)
+
+ def __str__(self) -> str:
+ "usual"
+ s = f"{self.name}"
+ if self.elem_type is None:
+ return s
+ s = f"{s}:{ELEMENT_TYPE_NAME[self.elem_type]}"
+ if self.shape is None:
+ return s
+ return f"{s}:[{''.join(map(str, self.shape))}]"
+
+ def rename(self, new_name: str) -> "Var":
+ "Renames a variable."
+ self.parent.rename(self.name, new_name)
+ return self
+
+ def to(self, to: ELEMENT_TYPE) -> "Var":
+ "Casts a tensor into another element type."
+ return self.Cast(to=elem_type_int(to))
+
+ def astype(self, to: ELEMENT_TYPE) -> "Var":
+ "Casts a tensor into another element type."
+ return self.Cast(to=elem_type_int(to))
+
+ def reshape(self, new_shape: VAR_CONSTANT_TYPE) -> "Var":
+ "Reshapes a variable."
+ if isinstance(new_shape, tuple):
+ cst = self.cst(np.array(new_shape, dtype=np.int64))
+ return self.bring(self, cst).Reshape()
+ return self.bring(self, new_shape).Reshape()
+
+ def __add__(self, var: VAR_CONSTANT_TYPE) -> "Var":
+ "Intuitive."
+ return self.bring(self, var).Add()
+
+ def __eq__(self, var: VAR_CONSTANT_TYPE) -> "Var":
+ "Intuitive."
+ return self.bring(self, var).Equal()
+
+    def __float__(self) -> "Var":
+        "Intuitive."
+        return self.Cast(to=TensorProto.FLOAT)
+
+ def __gt__(self, var: VAR_CONSTANT_TYPE) -> "Var":
+ "Intuitive."
+ return self.bring(self, var).Greater()
+
+ def __ge__(self, var: VAR_CONSTANT_TYPE) -> "Var":
+ "Intuitive."
+ return self.bring(self, var).GreaterOrEqual()
+
+    def __int__(self) -> "Var":
+        "Intuitive."
+        return self.Cast(to=TensorProto.INT64)
+
+ def __lt__(self, var: VAR_CONSTANT_TYPE) -> "Var":
+ "Intuitive."
+ return self.bring(self, var).Less()
+
+ def __le__(self, var: VAR_CONSTANT_TYPE) -> "Var":
+ "Intuitive."
+ return self.bring(self, var).LessOrEqual()
+
+ def __matmul__(self, var: VAR_CONSTANT_TYPE) -> "Var":
+ "Intuitive."
+ return self.bring(self, var).MatMul()
+
+ def __mod__(self, var: VAR_CONSTANT_TYPE) -> "Var":
+ "Intuitive."
+ return self.bring(self, var).Mod()
+
+ def __mul__(self, var: VAR_CONSTANT_TYPE) -> "Var":
+ "Intuitive."
+ return self.bring(self, var).Mul()
+
+ def __ne__(self, var: VAR_CONSTANT_TYPE) -> "Var":
+ "Intuitive."
+ return self.bring(self, var).Equal().Not()
+
+    def __neg__(self) -> "Var":
+        "Intuitive."
+        return self.Neg()
+
+ def __pow__(self, var: VAR_CONSTANT_TYPE) -> "Var":
+ "Intuitive."
+ return self.bring(self, var).Pow()
+
+ def __sub__(self, var: VAR_CONSTANT_TYPE) -> "Var":
+ "Intuitive."
+ return self.bring(self, var).Sub()
+
+ def __truediv__(self, var: VAR_CONSTANT_TYPE) -> "Var":
+ "Intuitive."
+ return self.bring(self, var).Div()
+
+
+class Vars(BaseVar, OpsVars):
+ """
+ Represents multiple Var.
+
+ :param parent: graph the variable belongs to
+ :param vars: list of names or variables
+ """
+
+ def __init__(self, parent, *vars: List[Union[str, Var]]):
+ BaseVar.__init__(self, parent)
+ self.vars_ = []
+ for v in vars:
+ if isinstance(v, str):
+ var = self.parent.get_var(v)
+ else:
+ var = v
+ self.vars_.append(var)
+
+ def __len__(self):
+ "Returns the number of variables."
+ return len(self.vars_)
+
+ def _check_nin(self, n_inputs):
+ if len(self) != n_inputs:
+ raise RuntimeError(f"Expecting {n_inputs} inputs not {len(self)}.")
+ return self
diff --git a/onnx_array_api/npx/npx_function_implementation.py b/onnx_array_api/npx/npx_function_implementation.py
index db9233f..b536888 100644
--- a/onnx_array_api/npx/npx_function_implementation.py
+++ b/onnx_array_api/npx/npx_function_implementation.py
@@ -14,7 +14,7 @@ def get_function_implementation(
**kwargs: Any,
) -> FunctionProto:
"""
- Returns a :epkg:`FunctionProto` for a specific proto.
+ Returns a :class:`onnx.FunctionProto` for a specific proto.
:param domop: domain, function
:param node_inputs: list of input names
diff --git a/onnx_array_api/npx/npx_helper.py b/onnx_array_api/npx/npx_helper.py
index b49ab02..f86aadc 100644
--- a/onnx_array_api/npx/npx_helper.py
+++ b/onnx_array_api/npx/npx_helper.py
@@ -19,9 +19,9 @@ def rename_in_onnx_graph(
"""
Renames input results in a GraphProto.
- :param graph: :epkg:`GraphProto`
+ :param graph: :class:`onnx.GraphProto`
:param replacements: replacements `{ old_name: new_name }`
- :return: modified :epkg:`GraphProto` or None if no modifications
+ :return: modified :class:`onnx.GraphProto` or None if no modifications
were detected
"""
@@ -153,8 +153,9 @@ def onnx_model_to_function(
:param inputs2par: dictionary to move some inputs as attributes
`{ name: None or default value }`
:return: function, other functions
+
.. warning::
- :epkg:`FunctionProto` does not support default values yet.
+ :class:`onnx.FunctionProto` does not support default values yet.
They are ignored.
"""
if isinstance(onx, ModelProto):
diff --git a/pyproject.toml b/pyproject.toml
index 7e15de0..3e85f19 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,11 @@ max-complexity = 10
"_doc/examples/plot_first_example.py" = ["E402", "F811"]
"_doc/examples/plot_onnxruntime.py" = ["E402", "F811"]
"onnx_array_api/array_api/_onnx_common.py" = ["F821"]
+"onnx_array_api/light_api/__init__.py" = ["F401"]
+"onnx_array_api/light_api/_op_var.py" = ["F821"]
+"onnx_array_api/light_api/_op_vars.py" = ["F821"]
+"onnx_array_api/light_api/annotations.py" = ["F821"]
+"onnx_array_api/light_api/model.py" = ["F821"]
"onnx_array_api/npx/__init__.py" = ["F401", "F403"]
"onnx_array_api/npx/npx_functions.py" = ["F821"]
"onnx_array_api/npx/npx_functions_test.py" = ["F821"]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index a65403b..c34f54c 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -3,6 +3,7 @@ black
coverage
flake8
furo
+google-re2
hypothesis
isort
joblib
@@ -14,6 +15,7 @@ onnxruntime
openpyxl
packaging
pandas
+Pillow
psutil
pytest
pytest-cov