From b11db3c1475bb999dffb49cbff852215ead5dd1e Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 6 Jan 2025 16:28:14 +0100 Subject: [PATCH 1/3] Export evaluator type in compare_onnx_execution --- LICENSE.txt | 2 +- onnx_array_api/reference/evaluator_yield.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/LICENSE.txt b/LICENSE.txt index e027853..1a46a8e 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,4 +1,4 @@ -Copyright (c) 2023-2024, Xavier Dupré +Copyright (c) 2023-2025, Xavier Dupré Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/onnx_array_api/reference/evaluator_yield.py b/onnx_array_api/reference/evaluator_yield.py index 5b77e8b..7d16be3 100644 --- a/onnx_array_api/reference/evaluator_yield.py +++ b/onnx_array_api/reference/evaluator_yield.py @@ -3,6 +3,7 @@ from enum import IntEnum import numpy as np from onnx import ModelProto, TensorProto, ValueInfoProto, load +from onnx.reference import ReferenceEvaluator from onnx.helper import tensor_dtype_to_np_dtype from onnx.shape_inference import infer_shapes from . import to_array_extended @@ -166,9 +167,9 @@ def enumerate_results( Returns: iterator on tuple(result kind, name, value, node.op_type or None) """ - assert isinstance(self.evaluator, ExtendedReferenceEvaluator), ( + assert isinstance(self.evaluator, ReferenceEvaluator), ( f"This implementation only works with " - f"ExtendedReferenceEvaluator not {type(self.evaluator)}" + f"ReferenceEvaluator not {type(self.evaluator)}" ) attributes = {} if output_names is None: @@ -595,6 +596,7 @@ def compare_onnx_execution( raise_exc: bool = True, mode: str = "execute", keep_tensor: bool = False, + cls: Optional[type[ReferenceEvaluator]] = None, ) -> Tuple[List[ResultExecution], List[ResultExecution], List[Tuple[int, int]]]: """ Compares the execution of two onnx models. @@ -611,6 +613,7 @@ def compare_onnx_execution( :param mode: the model should be executed but the function can be executed but the comparison may append on nodes only :param keep_tensor: keeps the tensor in order to compute a precise distance + :param cls: evaluator class to use :return: four results, a sequence of results for the first model and the second model, the alignment between the two, DistanceExecution @@ -634,7 +637,7 @@ def compare_onnx_execution( print(f"[compare_onnx_execution] execute with {len(inputs)} inputs") print("[compare_onnx_execution] execute first model") res1 = list( - YieldEvaluator(model1).enumerate_summarized( + YieldEvaluator(model1, cls=cls).enumerate_summarized( None, feeds1, raise_exc=raise_exc, keep_tensor=keep_tensor ) ) @@ -642,7 +645,7 @@ def compare_onnx_execution( print(f"[compare_onnx_execution] got {len(res1)} results") print("[compare_onnx_execution] execute second model") res2 = list( - YieldEvaluator(model2).enumerate_summarized( + YieldEvaluator(model2, cls=cls).enumerate_summarized( None, feeds2, raise_exc=raise_exc, keep_tensor=keep_tensor ) ) From 5e73e8d8395945c5432c44bdc44f606063794463 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 6 Jan 2025 17:54:19 +0100 Subject: [PATCH 2/3] doc --- onnx_array_api/reference/evaluator_yield.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/onnx_array_api/reference/evaluator_yield.py b/onnx_array_api/reference/evaluator_yield.py index 7d16be3..82da956 100644 --- a/onnx_array_api/reference/evaluator_yield.py +++ b/onnx_array_api/reference/evaluator_yield.py @@ -139,17 +139,22 @@ class YieldEvaluator: :param onnx_model: model to run :param recursive: dig into subgraph and functions as well + :param cls: evaluator to use, default value is :class:`ExtendedReferenceEvaluator` """ def __init__( self, onnx_model: ModelProto, recursive: bool = False, - cls=ExtendedReferenceEvaluator, + cls: Optional[type[ExtendedReferenceEvaluator]] = None, ): assert not recursive, "recursive=True is not yet implemented" self.onnx_model = onnx_model - self.evaluator = cls(onnx_model) if cls is not None else None + self.evaluator = ( + cls(onnx_model) + if cls is not None + else ExtendedReferenceEvaluator(onnx_model) + ) def enumerate_results( self, From f3c75299a167022e26b1a390961ec5034e9d2779 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 6 Jan 2025 17:55:06 +0100 Subject: [PATCH 3/3] doc --- onnx_array_api/reference/evaluator_yield.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnx_array_api/reference/evaluator_yield.py b/onnx_array_api/reference/evaluator_yield.py index 82da956..6ae005c 100644 --- a/onnx_array_api/reference/evaluator_yield.py +++ b/onnx_array_api/reference/evaluator_yield.py @@ -139,7 +139,8 @@ class YieldEvaluator: :param onnx_model: model to run :param recursive: dig into subgraph and functions as well - :param cls: evaluator to use, default value is :class:`ExtendedReferenceEvaluator` + :param cls: evaluator to use, default value is :class:`ExtendedReferenceEvaluator + ` """ def __init__(