diff --git a/_unittests/ut_light_api/test_light_api.py b/_unittests/ut_light_api/test_light_api.py index 6b22ae9..0483354 100644 --- a/_unittests/ut_light_api/test_light_api.py +++ b/_unittests/ut_light_api/test_light_api.py @@ -2,7 +2,7 @@ import unittest from typing import Callable, Optional import numpy as np -from onnx import GraphProto, ModelProto +from onnx import GraphProto, ModelProto, TensorProto from onnx.defs import ( get_all_schemas_with_history, onnx_opset_version, @@ -526,6 +526,18 @@ def test_input_shape(self): i = str(model.graph.input[0]).replace("\n", "").replace(" ", "") self.assertNotIn("shape{}", i) + def test_constant_of_shape(self): + onx = ( + start() + .vin("X", TensorProto.INT64, shape=[None, None]) + .ConstantOfShape() + .vout(shape=[]) + .to_onnx() + ) + ref = ReferenceEvaluator(onx) + got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0] + self.assertEqualArray(np.zeros((2, 3), dtype=np.float32), got) + if __name__ == "__main__": TestLightApi().test_add() diff --git a/onnx_array_api/light_api/__init__.py b/onnx_array_api/light_api/__init__.py index 3fe9489..83e8878 100644 --- a/onnx_array_api/light_api/__init__.py +++ b/onnx_array_api/light_api/__init__.py @@ -8,12 +8,14 @@ def start( opset: Optional[int] = None, opsets: Optional[Dict[str, int]] = None, + ir_version: Optional[int] = None, ) -> OnnxGraph: """ Starts an onnx model. :param opset: main opset version :param opsets: others opsets as a dictionary + :param ir_version: specify the ir_version as well :return: an instance of :class:`onnx_array_api.light_api.OnnxGraph` A very simple model: @@ -45,7 +47,7 @@ def start( ) print(onx) """ - return OnnxGraph(opset=opset, opsets=opsets) + return OnnxGraph(opset=opset, opsets=opsets, ir_version=ir_version) def g() -> OnnxGraph: diff --git a/onnx_array_api/light_api/_op_var.py b/onnx_array_api/light_api/_op_var.py index 27a04d1..3a74ed2 100644 --- a/onnx_array_api/light_api/_op_var.py +++ b/onnx_array_api/light_api/_op_var.py @@ -1,4 +1,6 @@ from typing import List, Optional, Union +import numpy as np +from ..reference import from_array_extended from ..annotations import AI_ONNX_ML, domain @@ -69,6 +71,11 @@ def Cast(self, saturate: int = 1, to: int = 0) -> "Var": def Celu(self, alpha: float = 1.0) -> "Var": return self.make_node("Celu", self, alpha=alpha) + def ConstantOfShape(self, value: Optional[np.array] = None) -> "Var": + if value is None: + return self.make_node("ConstantOfShape", self) + return self.make_node("ConstantOfShape", self, value=from_array_extended(value)) + def DepthToSpace(self, blocksize: int = 0, mode: str = "DCR") -> "Var": return self.make_node("DepthToSpace", self, blocksize=blocksize, mode=mode) diff --git a/onnx_array_api/light_api/model.py b/onnx_array_api/light_api/model.py index 5a7eef5..25194ac 100644 --- a/onnx_array_api/light_api/model.py +++ b/onnx_array_api/light_api/model.py @@ -42,6 +42,7 @@ class OnnxGraph: :param opset: main opset version :param opsets: other opsets as a dictionary + :param ir_version: to specify an ir_version :param is_function: a :class:`onnx.ModelProto` or a :class:`onnx.FunctionProto` """ @@ -49,6 +50,7 @@ def __init__( self, opset: Optional[int] = None, opsets: Optional[Dict[str, int]] = None, + ir_version: Optional[int] = None, proto_type: ProtoType = ProtoType.MODEL, ): if opsets is not None and "" in opsets: @@ -65,6 +67,7 @@ def __init__( self.proto_type = proto_type self.opsets = opsets self.opset = opset + self.ir_version = ir_version self.nodes: List[Union[NodeProto, TensorProto]] = [] self.inputs: List[ValueInfoProto] = [] self.outputs: List[ValueInfoProto] = [] @@ -402,6 +405,8 @@ def to_onnx(self) -> GRAPH_PROTO: # If no opsets, it a subgraph, not a model. return graph model = make_model(graph, opset_imports=opsets) + if self.ir_version: + model.ir_version = ir_version if not is_windows() or not is_azure(): # check_model fails sometimes on Windows check_model(model)