3
3
from enum import IntEnum
4
4
import numpy as np
5
5
from onnx import ModelProto , TensorProto , ValueInfoProto , load
6
+ from onnx .reference import ReferenceEvaluator
6
7
from onnx .helper import tensor_dtype_to_np_dtype
7
8
from onnx .shape_inference import infer_shapes
8
9
from . import to_array_extended
@@ -166,9 +167,9 @@ def enumerate_results(
166
167
Returns:
167
168
iterator on tuple(result kind, name, value, node.op_type or None)
168
169
"""
169
- assert isinstance (self .evaluator , ExtendedReferenceEvaluator ), (
170
+ assert isinstance (self .evaluator , ReferenceEvaluator ), (
170
171
f"This implementation only works with "
171
- f"ExtendedReferenceEvaluator not { type (self .evaluator )} "
172
+ f"ReferenceEvaluator not { type (self .evaluator )} "
172
173
)
173
174
attributes = {}
174
175
if output_names is None :
@@ -595,6 +596,7 @@ def compare_onnx_execution(
595
596
raise_exc : bool = True ,
596
597
mode : str = "execute" ,
597
598
keep_tensor : bool = False ,
599
+ cls : Optional [type [ReferenceEvaluator ]] = None ,
598
600
) -> Tuple [List [ResultExecution ], List [ResultExecution ], List [Tuple [int , int ]]]:
599
601
"""
600
602
Compares the execution of two onnx models.
@@ -611,6 +613,7 @@ def compare_onnx_execution(
611
613
:param mode: the model should be executed but the function can be executed
612
614
but the comparison may append on nodes only
613
615
:param keep_tensor: keeps the tensor in order to compute a precise distance
616
+ :param cls: evaluator class to use
614
617
:return: four results, a sequence of results
615
618
for the first model and the second model,
616
619
the alignment between the two, DistanceExecution
@@ -634,15 +637,15 @@ def compare_onnx_execution(
634
637
print (f"[compare_onnx_execution] execute with { len (inputs )} inputs" )
635
638
print ("[compare_onnx_execution] execute first model" )
636
639
res1 = list (
637
- YieldEvaluator (model1 ).enumerate_summarized (
640
+ YieldEvaluator (model1 , cls = cls ).enumerate_summarized (
638
641
None , feeds1 , raise_exc = raise_exc , keep_tensor = keep_tensor
639
642
)
640
643
)
641
644
if verbose :
642
645
print (f"[compare_onnx_execution] got { len (res1 )} results" )
643
646
print ("[compare_onnx_execution] execute second model" )
644
647
res2 = list (
645
- YieldEvaluator (model2 ).enumerate_summarized (
648
+ YieldEvaluator (model2 , cls = cls ).enumerate_summarized (
646
649
None , feeds2 , raise_exc = raise_exc , keep_tensor = keep_tensor
647
650
)
648
651
)
0 commit comments