diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 706cfed..d382b74 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -1,6 +1,11 @@
 Change Logs
 ===========
 
+0.1.3
++++++
+
+* :pr:`45`: fixes light API for operators with two outputs
+
 0.1.2
 +++++
 
diff --git a/_unittests/ut_light_api/test_light_api.py b/_unittests/ut_light_api/test_light_api.py
index 3feaa2a..f99a4b5 100644
--- a/_unittests/ut_light_api/test_light_api.py
+++ b/_unittests/ut_light_api/test_light_api.py
@@ -402,6 +402,45 @@ def test_operator_bool(self):
         got = ref.run(None, {"X": a, "Y": b})[0]
         self.assertEqualArray(f(a, b), got)
 
+    def test_topk(self):
+        onx = (
+            start()
+            .vin("X", np.float32)
+            .vin("K", np.int64)
+            .bring("X", "K")
+            .TopK()
+            .rename("Values", "Indices")
+            .vout()
+            .to_onnx()
+        )
+        self.assertIsInstance(onx, ModelProto)
+        ref = ReferenceEvaluator(onx)
+        x = np.array([[0, 1, 2, 3], [9, 8, 7, 6]], dtype=np.float32)
+        k = np.array([2], dtype=np.int64)
+        got = ref.run(None, {"X": x, "K": k})
+        self.assertEqualArray(np.array([[3, 2], [9, 8]], dtype=np.float32), got[0])
+        self.assertEqualArray(np.array([[3, 2], [0, 1]], dtype=np.int64), got[1])
+
+    def test_topk_reverse(self):
+        onx = (
+            start()
+            .vin("X", np.float32)
+            .vin("K", np.int64)
+            .bring("X", "K")
+            .TopK(largest=0)
+            .rename("Values", "Indices")
+            .vout()
+            .to_onnx()
+        )
+        self.assertIsInstance(onx, ModelProto)
+        ref = ReferenceEvaluator(onx)
+        x = np.array([[0, 1, 2, 3], [9, 8, 7, 6]], dtype=np.float32)
+        k = np.array([2], dtype=np.int64)
+        got = ref.run(None, {"X": x, "K": k})
+        self.assertEqualArray(np.array([[0, 1], [6, 7]], dtype=np.float32), got[0])
+        self.assertEqualArray(np.array([[0, 1], [3, 2]], dtype=np.int64), got[1])
+
 
 if __name__ == "__main__":
+    # TestLightApi().test_topk()
     unittest.main(verbosity=2)
diff --git a/onnx_array_api/__init__.py b/onnx_array_api/__init__.py
index b2a711d..09a2edd 100644
--- a/onnx_array_api/__init__.py
+++ b/onnx_array_api/__init__.py
@@ -3,5 +3,5 @@
 APIs to create ONNX Graphs.
 """
 
-__version__ = "0.1.2"
+__version__ = "0.1.3"
 __author__ = "Xavier Dupré"
diff --git a/onnx_array_api/light_api/_op_var.py b/onnx_array_api/light_api/_op_var.py
index e2354eb..8b6b651 100644
--- a/onnx_array_api/light_api/_op_var.py
+++ b/onnx_array_api/light_api/_op_var.py
@@ -30,7 +30,7 @@ def ArgMin(
 
     def AveragePool(
         self,
-        auto_pad: str = b"NOTSET",
+        auto_pad: str = "NOTSET",
         ceil_mode: int = 0,
         count_include_pad: int = 0,
         dilations: Optional[List[int]] = None,
@@ -68,7 +68,7 @@ def Cast(self, saturate: int = 1, to: int = 0) -> "Var":
     def Celu(self, alpha: float = 1.0) -> "Var":
         return self.make_node("Celu", self, alpha=alpha)
 
-    def DepthToSpace(self, blocksize: int = 0, mode: str = b"DCR") -> "Var":
+    def DepthToSpace(self, blocksize: int = 0, mode: str = "DCR") -> "Var":
         return self.make_node("DepthToSpace", self, blocksize=blocksize, mode=mode)
 
     def DynamicQuantizeLinear(
@@ -137,7 +137,7 @@ def LpNormalization(self, axis: int = -1, p: int = 2) -> "Var":
 
     def LpPool(
         self,
-        auto_pad: str = b"NOTSET",
+        auto_pad: str = "NOTSET",
         ceil_mode: int = 0,
         dilations: Optional[List[int]] = None,
         kernel_shape: Optional[List[int]] = None,
diff --git a/onnx_array_api/light_api/_op_vars.py b/onnx_array_api/light_api/_op_vars.py
index 77dbac6..f4dee1c 100644
--- a/onnx_array_api/light_api/_op_vars.py
+++ b/onnx_array_api/light_api/_op_vars.py
@@ -6,7 +6,7 @@ class OpsVars:
     Operators taking multiple inputs.
""" - def BitShift(self, direction: str = b"") -> "Var": + def BitShift(self, direction: str = "") -> "Var": return self.make_node("BitShift", *self.vars_, direction=direction) def CenterCropPad(self, axes: Optional[List[int]] = None) -> "Var": @@ -42,7 +42,7 @@ def Concat(self, axis: int = 0) -> "Var": def Conv( self, - auto_pad: str = b"NOTSET", + auto_pad: str = "NOTSET", dilations: Optional[List[int]] = None, group: int = 1, kernel_shape: Optional[List[int]] = None, @@ -66,7 +66,7 @@ def Conv( def ConvInteger( self, - auto_pad: str = b"NOTSET", + auto_pad: str = "NOTSET", dilations: Optional[List[int]] = None, group: int = 1, kernel_shape: Optional[List[int]] = None, @@ -90,7 +90,7 @@ def ConvInteger( def ConvTranspose( self, - auto_pad: str = b"NOTSET", + auto_pad: str = "NOTSET", dilations: Optional[List[int]] = None, group: int = 1, kernel_shape: Optional[List[int]] = None, @@ -155,7 +155,7 @@ def DeformConv( def DequantizeLinear(self, axis: int = 1) -> "Var": return self.make_node("DequantizeLinear", *self.vars_, axis=axis) - def Einsum(self, equation: str = b"") -> "Var": + def Einsum(self, equation: str = "") -> "Var": return self.make_node("Einsum", *self.vars_, equation=equation) def Gather(self, axis: int = 0) -> "Var": @@ -174,8 +174,8 @@ def Gemm( def GridSample( self, align_corners: int = 0, - mode: str = b"bilinear", - padding_mode: str = b"zeros", + mode: str = "bilinear", + padding_mode: str = "zeros", ) -> "Var": return self.make_node( "GridSample", @@ -240,7 +240,7 @@ def Mod(self, fmod: int = 0) -> "Var": return self.make_node("Mod", *self.vars_, fmod=fmod) def NegativeLogLikelihoodLoss( - self, ignore_index: int = 0, reduction: str = b"mean" + self, ignore_index: int = 0, reduction: str = "mean" ) -> "Var": return self.make_node( "NegativeLogLikelihoodLoss", @@ -257,12 +257,12 @@ def NonMaxSuppression(self, center_point_box: int = 0) -> "Var": def OneHot(self, axis: int = -1) -> "Var": return self.make_node("OneHot", *self.vars_, axis=axis) - def Pad(self, mode: str = b"constant") -> "Var": + def Pad(self, mode: str = "constant") -> "Var": return self.make_node("Pad", *self.vars_, mode=mode) def QLinearConv( self, - auto_pad: str = b"NOTSET", + auto_pad: str = "NOTSET", dilations: Optional[List[int]] = None, group: int = 1, kernel_shape: Optional[List[int]] = None, @@ -431,13 +431,13 @@ def Resize( self, antialias: int = 0, axes: Optional[List[int]] = None, - coordinate_transformation_mode: str = b"half_pixel", + coordinate_transformation_mode: str = "half_pixel", cubic_coeff_a: float = -0.75, exclude_outside: int = 0, extrapolation_value: float = 0.0, - keep_aspect_ratio_policy: str = b"stretch", - mode: str = b"nearest", - nearest_mode: str = b"round_prefer_floor", + keep_aspect_ratio_policy: str = "stretch", + mode: str = "nearest", + nearest_mode: str = "round_prefer_floor", ) -> "Var": axes = axes or [] return self.make_node( @@ -456,8 +456,8 @@ def Resize( def RoiAlign( self, - coordinate_transformation_mode: str = b"half_pixel", - mode: str = b"avg", + coordinate_transformation_mode: str = "half_pixel", + mode: str = "avg", output_height: int = 1, output_width: int = 1, sampling_ratio: int = 0, @@ -480,12 +480,12 @@ def STFT(self, onesided: int = 1) -> "Var": def Scatter(self, axis: int = 0) -> "Var": return self.make_node("Scatter", *self.vars_, axis=axis) - def ScatterElements(self, axis: int = 0, reduction: str = b"none") -> "Var": + def ScatterElements(self, axis: int = 0, reduction: str = "none") -> "Var": return self.make_node( "ScatterElements", 
         )
 
-    def ScatterND(self, reduction: str = b"none") -> "Var":
+    def ScatterND(self, reduction: str = "none") -> "Var":
         return self.make_node("ScatterND", *self.vars_, reduction=reduction)
 
     def Slice(
@@ -498,13 +498,18 @@
 
     def TopK(self, axis: int = -1, largest: int = 1, sorted: int = 1) -> "Vars":
         return self.make_node(
-            "TopK", *self.vars_, axis=axis, largest=largest, sorted=sorted
+            "TopK",
+            *self.vars_,
+            axis=axis,
+            largest=largest,
+            sorted=sorted,
+            n_outputs=2,
         )
 
     def Trilu(self, upper: int = 1) -> "Var":
         return self.make_node("Trilu", *self.vars_, upper=upper)
 
-    def Upsample(self, mode: str = b"nearest") -> "Var":
+    def Upsample(self, mode: str = "nearest") -> "Var":
         return self.make_node("Upsample", *self.vars_, mode=mode)
 
     def Where(
diff --git a/onnx_array_api/light_api/model.py b/onnx_array_api/light_api/model.py
index 090e29c..d88be3a 100644
--- a/onnx_array_api/light_api/model.py
+++ b/onnx_array_api/light_api/model.py
@@ -28,8 +28,8 @@ class OnnxGraph:
     This API is meant to be light and allows the description of a graph.
 
     :param opset: main opset version
+    :param opsets: other opsets as a dictionary
    :param is_function: a :class:`onnx.ModelProto` or a :class:`onnx.FunctionProto`
-    :param opsets: others opsets as a dictionary
     """
 
     def __init__(
diff --git a/onnx_array_api/light_api/var.py b/onnx_array_api/light_api/var.py
index 6da1ee3..2c8b375 100644
--- a/onnx_array_api/light_api/var.py
+++ b/onnx_array_api/light_api/var.py
@@ -1,6 +1,7 @@
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 from onnx import TensorProto
+from onnx.defs import get_schema
 from .annotations import (
     elem_type_int,
     make_shape,
@@ -27,6 +28,8 @@ def __init__(
         self,
         parent: OnnxGraph,
     ):
+        if not isinstance(parent, OnnxGraph):
+            raise RuntimeError(f"Unexpected parent type {type(parent)}.")
         self.parent = parent
 
     def make_node(
@@ -51,6 +54,27 @@
         :return: instance of :class:`onnx_array_api.light_api.Var` or
             :class:`onnx_array_api.light_api.Vars`
         """
+        if domain in ("", "ai.onnx.ml"):
+            if self.parent.opset is None:
+                schema = get_schema(op_type, domain)
+            else:
+                schema = get_schema(op_type, self.parent.opset, domain)
+            if n_outputs < schema.min_output or n_outputs > schema.max_output:
+                raise RuntimeError(
+                    f"Unexpected number of outputs ({n_outputs}) "
+                    f"for node type {op_type!r}, domain={domain!r}, "
+                    f"version={self.parent.opset}, it should be in "
+                    f"[{schema.min_output}, {schema.max_output}]."
+                )
+            n_inputs = len(inputs)
+            if n_inputs < schema.min_input or n_inputs > schema.max_input:
+                raise RuntimeError(
+                    f"Unexpected number of inputs ({n_inputs}) "
+                    f"for node type {op_type!r}, domain={domain!r}, "
+                    f"version={self.parent.opset}, it should be in "
+                    f"[{schema.min_input}, {schema.max_input}]."
+                )
+
         node_proto = self.parent.make_node(
             op_type,
             *inputs,
@@ -60,9 +84,13 @@
             **kwargs,
         )
         names = node_proto.output
+        if n_outputs is not None and n_outputs != len(names):
+            raise RuntimeError(
+                f"Expects {n_outputs} outputs but output names are {names}."
+            )
         if len(names) == 1:
             return Var(self.parent, names[0])
-        return Vars(*map(lambda v: Var(self.parent, v), names))
+        return Vars(self.parent, *list(map(lambda v: Var(self.parent, v), names)))
 
     def vin(
         self,
@@ -91,26 +119,6 @@ def cst(self, value: np.ndarray, name: Optional[str] = None) -> "Var":
         c = self.parent.make_constant(value, name=name)
         return Var(self.parent, c.name, elem_type=c.data_type, shape=tuple(c.dims))
 
-    def vout(
-        self,
-        elem_type: ELEMENT_TYPE = TensorProto.FLOAT,
-        shape: Optional[SHAPE_TYPE] = None,
-    ) -> "Var":
-        """
-        Declares a new output to the graph.
-
-        :param elem_type: element_type
-        :param shape: shape
-        :return: instance of :class:`onnx_array_api.light_api.Var`
-        """
-        output = self.parent.make_output(self.name, elem_type=elem_type, shape=shape)
-        return Var(
-            self.parent,
-            output,
-            elem_type=output.type.tensor_type.elem_type,
-            shape=make_shape(output.type.tensor_type.shape),
-        )
-
     def v(self, name: str) -> "Var":
         """
         Retrieves another variable than this one.
@@ -127,6 +135,13 @@ def bring(self, *vars: List[Union[str, "Var"]]) -> "Vars":
         """
         return Vars(self.parent, *vars)
 
+    def vout(self, **kwargs: Dict[str, Any]) -> Union["Var", "Vars"]:
+        """
+        This method needs to be overwritten for Var and Vars depending
+        on the number of variables to declare as outputs.
+        """
+        raise RuntimeError(f"The method was not overwritten in class {type(self)}.")
+
     def left_bring(self, *vars: List[Union[str, "Var"]]) -> "Vars":
         """
         Creates a set of variables as an instance of
@@ -187,6 +202,26 @@ def __str__(self) -> str:
             return s
         return f"{s}:[{''.join(map(str, self.shape))}]"
 
+    def vout(
+        self,
+        elem_type: ELEMENT_TYPE = TensorProto.FLOAT,
+        shape: Optional[SHAPE_TYPE] = None,
+    ) -> "Var":
+        """
+        Declares a new output to the graph.
+
+        :param elem_type: element_type
+        :param shape: shape
+        :return: instance of :class:`onnx_array_api.light_api.Var`
+        """
+        output = self.parent.make_output(self.name, elem_type=elem_type, shape=shape)
+        return Var(
+            self.parent,
+            output,
+            elem_type=output.type.tensor_type.elem_type,
+            shape=make_shape(output.type.tensor_type.shape),
+        )
+
     def rename(self, new_name: str) -> "Var":
         "Renames a variable."
         self.parent.rename(self.name, new_name)
@@ -299,6 +334,39 @@ def _check_nin(self, n_inputs):
             raise RuntimeError(f"Expecting {n_inputs} inputs not {len(self)}.")
         return self
 
-    def rename(self, new_name: str) -> "Var":
+    def rename(self, *new_names: List[str]) -> "Vars":
         "Renames variables."
-        raise NotImplementedError("Not yet implemented.")
+        if len(new_names) != len(self):
+            raise ValueError(
+                f"Vars has {len(self)} elements but the method received {len(new_names)} names."
+            )
+        new_vars = []
+        for var, name in zip(self.vars_, new_names):
+            new_vars.append(var.rename(name))
+        return Vars(self.parent, *new_names)
+
+    def vout(
+        self,
+        *elem_type_shape: List[
+            Union[ELEMENT_TYPE, Tuple[ELEMENT_TYPE, Optional[SHAPE_TYPE]]]
+        ],
+    ) -> "Vars":
+        """
+        Declares a new output to the graph.
+
+        :param elem_type_shape: list of tuple(element_type, shape)
+        :return: instance of :class:`onnx_array_api.light_api.Vars`
+        """
+        vars = []
+        for i, v in enumerate(self.vars_):
+            if i < len(elem_type_shape):
+                if isinstance(elem_type_shape[i], tuple):
+                    elem_type, shape = elem_type_shape[i]
+                else:
+                    elem_type = elem_type_shape[i]
+                    shape = None
+            else:
+                elem_type = TensorProto.FLOAT
+                shape = None
+            vars.append(v.vout(elem_type=elem_type, shape=shape))
+        return Vars(self.parent, *vars)
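For context, the new tests above show what this change enables: a node with two outputs (here TopK) can now be built, renamed and declared as graph outputs through the light API. Below is a minimal standalone sketch of that flow; it assumes `start` is exposed by `onnx_array_api.light_api` and uses `onnx.reference.ReferenceEvaluator`, as the test module does (its imports are not part of this diff).

    import numpy as np
    from onnx.reference import ReferenceEvaluator
    from onnx_array_api.light_api import start

    # Build a graph around a two-output node (TopK).
    onx = (
        start()
        .vin("X", np.float32)
        .vin("K", np.int64)
        .bring("X", "K")
        .TopK()                       # declares n_outputs=2, so a Vars is returned
        .rename("Values", "Indices")  # one new name per output
        .vout()                       # declares every output, element type defaults to FLOAT
        .to_onnx()
    )

    ref = ReferenceEvaluator(onx)
    x = np.array([[0, 1, 2, 3], [9, 8, 7, 6]], dtype=np.float32)
    k = np.array([2], dtype=np.int64)
    values, indices = ref.run(None, {"X": x, "K": k})
    # values -> [[3, 2], [9, 8]], indices -> [[3, 2], [0, 1]] (same data as test_topk)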