Skip to content

Commit 032aff5

Browse files
committed
2 parents 4c12efd + a906010 commit 032aff5

File tree

5 files changed

+65
-3
lines changed

5 files changed

+65
-3
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
* :pr:`77`: supports ConcatOfShape and Slice with the light API
78
* :pr:`76`: add a mode to compare models without execution
89
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
910
* :pr:`71`: adds tools to compare two onnx graphs

_unittests/ut_light_api/test_light_api.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import unittest
33
from typing import Callable, Optional
44
import numpy as np
5-
from onnx import GraphProto, ModelProto
5+
from onnx import GraphProto, ModelProto, TensorProto
66
from onnx.defs import (
77
get_all_schemas_with_history,
88
onnx_opset_version,
@@ -526,7 +526,47 @@ def test_input_shape(self):
526526
i = str(model.graph.input[0]).replace("\n", "").replace(" ", "")
527527
self.assertNotIn("shape{}", i)
528528

529+
def test_constant_of_shape(self):
530+
onx = (
531+
start()
532+
.vin("X", TensorProto.INT64, shape=[None, None])
533+
.ConstantOfShape()
534+
.vout(shape=[])
535+
.to_onnx()
536+
)
537+
ref = ReferenceEvaluator(onx)
538+
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
539+
self.assertEqualArray(np.zeros((2, 3), dtype=np.float32), got)
540+
541+
def test_constant_of_shape_value(self):
542+
onx = (
543+
start()
544+
.vin("X", TensorProto.INT64, shape=[None, None])
545+
.ConstantOfShape(value=np.array([1], dtype=np.float32))
546+
.vout(shape=[])
547+
.to_onnx()
548+
)
549+
ref = ReferenceEvaluator(onx)
550+
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
551+
self.assertEqualArray(np.ones((2, 3), dtype=np.float32), got)
552+
553+
def test_slice(self):
554+
onx = (
555+
start(opset=18, ir_version=9)
556+
.cst(np.array([1], dtype=np.int64), name="one")
557+
.cst(np.array([2], dtype=np.int64), name="two")
558+
.vin("X", TensorProto.INT64, shape=[None, None])
559+
.ConstantOfShape(value=np.array([1], dtype=np.float32))
560+
.rename("CX")
561+
.bring("CX", "one", "two", "one")
562+
.Slice()
563+
.vout(shape=[])
564+
.to_onnx()
565+
)
566+
ref = ReferenceEvaluator(onx)
567+
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
568+
self.assertEqualArray(np.ones((2, 1), dtype=np.float32), got)
569+
529570

530571
if __name__ == "__main__":
531-
TestLightApi().test_add()
532572
unittest.main(verbosity=2)

onnx_array_api/light_api/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
def start(
99
opset: Optional[int] = None,
1010
opsets: Optional[Dict[str, int]] = None,
11+
ir_version: Optional[int] = None,
1112
) -> OnnxGraph:
1213
"""
1314
Starts an onnx model.
1415
1516
:param opset: main opset version
1617
:param opsets: others opsets as a dictionary
18+
:param ir_version: specify the ir_version as well
1719
:return: an instance of :class:`onnx_array_api.light_api.OnnxGraph`
1820
1921
A very simple model:
@@ -45,7 +47,7 @@ def start(
4547
)
4648
print(onx)
4749
"""
48-
return OnnxGraph(opset=opset, opsets=opsets)
50+
return OnnxGraph(opset=opset, opsets=opsets, ir_version=ir_version)
4951

5052

5153
def g() -> OnnxGraph:

onnx_array_api/light_api/_op_var.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from typing import List, Optional, Union
2+
import numpy as np
3+
from ..reference import from_array_extended
24
from ..annotations import AI_ONNX_ML, domain
35

46

@@ -69,6 +71,11 @@ def Cast(self, saturate: int = 1, to: int = 0) -> "Var":
6971
def Celu(self, alpha: float = 1.0) -> "Var":
7072
return self.make_node("Celu", self, alpha=alpha)
7173

74+
def ConstantOfShape(self, value: Optional[np.array] = None) -> "Var":
75+
if value is None:
76+
return self.make_node("ConstantOfShape", self)
77+
return self.make_node("ConstantOfShape", self, value=from_array_extended(value))
78+
7279
def DepthToSpace(self, blocksize: int = 0, mode: str = "DCR") -> "Var":
7380
return self.make_node("DepthToSpace", self, blocksize=blocksize, mode=mode)
7481

@@ -307,6 +314,13 @@ def Selu(
307314
def Shrink(self, bias: float = 0.0, lambd: float = 0.5) -> "Var":
308315
return self.make_node("Shrink", self, bias=bias, lambd=lambd)
309316

317+
def Slice(
318+
self, starts: "Var", ends: "Var", axes: "Var", steps: Optional["Var"] = None
319+
) -> "Var":
320+
if steps is None:
321+
return self.make_node("Slice", self, starts, ends, axes)
322+
return self.make_node("Slice", self, starts, ends, axes, steps)
323+
310324
def Softmax(self, axis: int = -1) -> "Var":
311325
return self.make_node("Softmax", self, axis=axis)
312326

onnx_array_api/light_api/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,15 @@ class OnnxGraph:
4242
4343
:param opset: main opset version
4444
:param opsets: other opsets as a dictionary
45+
:param ir_version: to specify an ir_version
4546
:param is_function: a :class:`onnx.ModelProto` or a :class:`onnx.FunctionProto`
4647
"""
4748

4849
def __init__(
4950
self,
5051
opset: Optional[int] = None,
5152
opsets: Optional[Dict[str, int]] = None,
53+
ir_version: Optional[int] = None,
5254
proto_type: ProtoType = ProtoType.MODEL,
5355
):
5456
if opsets is not None and "" in opsets:
@@ -65,6 +67,7 @@ def __init__(
6567
self.proto_type = proto_type
6668
self.opsets = opsets
6769
self.opset = opset
70+
self.ir_version = ir_version
6871
self.nodes: List[Union[NodeProto, TensorProto]] = []
6972
self.inputs: List[ValueInfoProto] = []
7073
self.outputs: List[ValueInfoProto] = []
@@ -402,6 +405,8 @@ def to_onnx(self) -> GRAPH_PROTO:
402405
# If no opsets, it a subgraph, not a model.
403406
return graph
404407
model = make_model(graph, opset_imports=opsets)
408+
if self.ir_version:
409+
model.ir_version = self.ir_version
405410
if not is_windows() or not is_azure():
406411
# check_model fails sometimes on Windows
407412
check_model(model)

0 commit comments

Comments
 (0)