Skip to content

Commit 432fa69

Browse files
committed
ir
1 parent 722fd2a commit 432fa69

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

_unittests/ut_translate_api/test_translate_builder.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def setUp(self):
2020
self.maxDiff = None
2121

2222
def test_exp(self):
23-
onx = start(opset=19, ir_version=11).vin("X").Exp().rename("Y").vout().to_onnx()
23+
onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx()
2424
self.assertIsInstance(onx, ModelProto)
2525
self.assertIn("Exp", str(onx))
2626
ref = ReferenceEvaluator(onx)
@@ -39,7 +39,7 @@ def light_api(
3939
op.Identity(Y, outputs=["Y"])
4040
return Y
4141
42-
g = GraphBuilder({'': 19}, ir_version=11)
42+
g = GraphBuilder({'': 19}, ir_version=10)
4343
g.make_tensor_input("X", TensorProto.FLOAT, ())
4444
light_api(g.op, "X")
4545
g.make_tensor_output("Y", TensorProto.FLOAT, ())
@@ -69,7 +69,7 @@ def light_api(
6969

7070
def test_zdoc(self):
7171
onx = (
72-
start(opset=19, ir_version=11)
72+
start(opset=19, ir_version=10)
7373
.vin("X")
7474
.reshape((-1, 1))
7575
.Transpose(perm=[1, 0])
@@ -90,7 +90,7 @@ def light_api(
9090
op.Identity(Y, outputs=["Y"])
9191
return Y
9292
93-
g = GraphBuilder({'': 19}, ir_version=11)
93+
g = GraphBuilder({'': 19}, ir_version=10)
9494
g.make_tensor_input("X", TensorProto.FLOAT, ())
9595
light_api(g.op, "X")
9696
g.make_tensor_output("Y", TensorProto.FLOAT, ())
@@ -119,7 +119,7 @@ def light_api(
119119
check_model(model)
120120

121121
def test_exp_f(self):
122-
onx = start(opset=19, ir_version=11).vin("X").Exp().rename("Y").vout().to_onnx()
122+
onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx()
123123
self.assertIsInstance(onx, ModelProto)
124124
self.assertIn("Exp", str(onx))
125125
ref = ReferenceEvaluator(onx)
@@ -142,7 +142,7 @@ def light_api(
142142
143143
144144
def mm() -> "ModelProto":
145-
g = GraphBuilder({'': 19}, ir_version=11)
145+
g = GraphBuilder({'': 19}, ir_version=10)
146146
g.make_tensor_input("X", TensorProto.FLOAT, ())
147147
light_api(g.op, "X")
148148
g.make_tensor_output("Y", TensorProto.FLOAT, ())

onnx_array_api/translate_api/builder_emitter.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
148148
if kwargs.get("domain", "") != "":
149149
domain = kwargs["domain"]
150150
op_type = f"{domain}.{op_type}"
151+
else:
152+
domain = ""
151153
atts = kwargs.get("atts", {})
152154
args = []
153155
for k, v in atts.items():
@@ -158,9 +160,14 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
158160

159161
outs = ", ".join(outputs)
160162
inps = ", ".join(inputs)
163+
op_type = self._emit_node_type(op_type, domain)
164+
sdomain = "" if not domain else f", domain={domain!r}"
161165
if args:
162166
sargs = ", ".join(args)
163-
row = f" {outs} = op.{op_type}({inps}, {sargs})"
167+
row = f" {outs} = op.{op_type}({inps}, {sargs}{sdomain})"
164168
else:
165-
row = f" {outs} = op.{op_type}({inps})"
169+
row = f" {outs} = op.{op_type}({inps}{sdomain})"
166170
return [row]
171+
172+
def _emit_node_type(self, op_type, domain):
173+
return op_type

0 commit comments

Comments
 (0)