Skip to content

Commit b11db3c

Browse files
committed
Export evaluator type in compare_onnx_execution
1 parent 07c3683 commit b11db3c

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

LICENSE.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Copyright (c) 2023-2024, Xavier Dupré
1+
Copyright (c) 2023-2025, Xavier Dupré
22

33
Permission is hereby granted, free of charge, to any person obtaining a copy
44
of this software and associated documentation files (the "Software"), to deal

onnx_array_api/reference/evaluator_yield.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from enum import IntEnum
44
import numpy as np
55
from onnx import ModelProto, TensorProto, ValueInfoProto, load
6+
from onnx.reference import ReferenceEvaluator
67
from onnx.helper import tensor_dtype_to_np_dtype
78
from onnx.shape_inference import infer_shapes
89
from . import to_array_extended
@@ -166,9 +167,9 @@ def enumerate_results(
166167
Returns:
167168
iterator on tuple(result kind, name, value, node.op_type or None)
168169
"""
169-
assert isinstance(self.evaluator, ExtendedReferenceEvaluator), (
170+
assert isinstance(self.evaluator, ReferenceEvaluator), (
170171
f"This implementation only works with "
171-
f"ExtendedReferenceEvaluator not {type(self.evaluator)}"
172+
f"ReferenceEvaluator not {type(self.evaluator)}"
172173
)
173174
attributes = {}
174175
if output_names is None:
@@ -595,6 +596,7 @@ def compare_onnx_execution(
595596
raise_exc: bool = True,
596597
mode: str = "execute",
597598
keep_tensor: bool = False,
599+
cls: Optional[type[ReferenceEvaluator]] = None,
598600
) -> Tuple[List[ResultExecution], List[ResultExecution], List[Tuple[int, int]]]:
599601
"""
600602
Compares the execution of two onnx models.
@@ -611,6 +613,7 @@ def compare_onnx_execution(
611613
:param mode: the model should be executed but the function can be executed
612614
but the comparison may append on nodes only
613615
:param keep_tensor: keeps the tensor in order to compute a precise distance
616+
:param cls: evaluator class to use
614617
:return: four results, a sequence of results
615618
for the first model and the second model,
616619
the alignment between the two, DistanceExecution
@@ -634,15 +637,15 @@ def compare_onnx_execution(
634637
print(f"[compare_onnx_execution] execute with {len(inputs)} inputs")
635638
print("[compare_onnx_execution] execute first model")
636639
res1 = list(
637-
YieldEvaluator(model1).enumerate_summarized(
640+
YieldEvaluator(model1, cls=cls).enumerate_summarized(
638641
None, feeds1, raise_exc=raise_exc, keep_tensor=keep_tensor
639642
)
640643
)
641644
if verbose:
642645
print(f"[compare_onnx_execution] got {len(res1)} results")
643646
print("[compare_onnx_execution] execute second model")
644647
res2 = list(
645-
YieldEvaluator(model2).enumerate_summarized(
648+
YieldEvaluator(model2, cls=cls).enumerate_summarized(
646649
None, feeds2, raise_exc=raise_exc, keep_tensor=keep_tensor
647650
)
648651
)

0 commit comments

Comments
 (0)