From 75929c91a3b72f82c05422fb1172cc43d7d9c37b Mon Sep 17 00:00:00 2001
From: adrinjalali
Date: Tue, 11 Jul 2023 21:54:45 +0200
Subject: [PATCH] PERF speed up confusion matrix calculation

---
 sklearn/metrics/_classification.py | 93 +++++++++++++++++++++++++-----
 sklearn/utils/multiclass.py        | 10 +++-
 2 files changed, 85 insertions(+), 18 deletions(-)

diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index e0ef359aa1a85..487a8e9e4e72a 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -38,6 +38,7 @@
     check_consistent_length,
     column_or_1d,
 )
+from ..utils._array_api import get_namespace
 from ..utils._param_validation import Interval, Options, StrOptions, validate_params
 from ..utils.extmath import _nanaverage
 from ..utils.multiclass import type_of_target, unique_labels
@@ -54,7 +55,7 @@ def _check_zero_division(zero_division):
     return np.nan
 
 
-def _check_targets(y_true, y_pred):
+def _check_targets(y_true, y_pred, unique_y_true=None, unique_y_pred=None):
     """Check that y_true and y_pred belong to the same classification task.
 
     This converts multiclass or binary types to a common shape, and raises a
@@ -71,6 +72,12 @@ def _check_targets(y_true, y_pred):
 
     y_pred : array-like
 
+    unique_y_true : array-like, default=None
+        Inferred from y_true if None.
+
+    unique_y_pred : array-like, default=None
+        Inferred from y_pred if None.
+
     Returns
     -------
     type_true : one of {'multilabel-indicator', 'multiclass', 'binary'}
@@ -82,8 +89,13 @@ def _check_targets(y_true, y_pred):
     y_pred : array or indicator matrix
     """
     check_consistent_length(y_true, y_pred)
-    type_true = type_of_target(y_true, input_name="y_true")
-    type_pred = type_of_target(y_pred, input_name="y_pred")
+    xp, _ = get_namespace(y_true, y_pred)
+    unique_y_true = xp.unique_values(y_true) if unique_y_true is None else unique_y_true
+    unique_y_pred = xp.unique_values(y_pred) if unique_y_pred is None else unique_y_pred
+    n_classes = len(unique_labels(unique_y_true, unique_y_pred))
+
+    type_true = type_of_target(y_true, input_name="y_true", n_classes=n_classes)
+    type_pred = type_of_target(y_pred, input_name="y_pred", n_classes=n_classes)
 
     y_type = {type_true, type_pred}
     if y_type == {"binary", "multiclass"}:
@@ -108,7 +120,7 @@
         y_pred = column_or_1d(y_pred)
         if y_type == "binary":
             try:
-                unique_values = np.union1d(y_true, y_pred)
+                unique_values = np.union1d(unique_y_true, unique_y_pred)
             except TypeError as e:
                 # We expect y_true and y_pred to be of the same data type.
                 # If `y_true` was provided to the classifier as strings,
@@ -116,8 +128,8 @@
                 # strings. So we raise a meaningful error
                 raise TypeError(
                     "Labels in y_true and y_pred should be of the same type. "
-                    f"Got y_true={np.unique(y_true)} and "
-                    f"y_pred={np.unique(y_pred)}. Make sure that the "
+                    f"Got y_true={unique_y_true} and "
+                    f"y_pred={unique_y_pred}. Make sure that the "
                     "predictions provided by the classifier coincides with "
                     "the true labels."
                 ) from e
@@ -403,7 +415,14 @@ def confusion_matrix(
     prefer_skip_nested_validation=True,
 )
 def multilabel_confusion_matrix(
-    y_true, y_pred, *, sample_weight=None, labels=None, samplewise=False
+    y_true,
+    y_pred,
+    *,
+    sample_weight=None,
+    labels=None,
+    samplewise=False,
+    unique_y_true=None,
+    unique_y_pred=None,
 ):
     """Compute a confusion matrix for each class or sample.
@@ -444,6 +463,12 @@ def multilabel_confusion_matrix( samplewise : bool, default=False In the multilabel case, this calculates a confusion matrix per sample. + unique_y_true : array-like, default=None + The unique values in y_true. If None, it will be inferred from y_true. + + unique_y_pred : array-like, default=None + The unique values in y_pred. If None, it will be inferred from y_pred. + Returns ------- multi_confusion : ndarray of shape (n_outputs, 2, 2) @@ -502,7 +527,13 @@ def multilabel_confusion_matrix( [[2, 1], [1, 2]]]) """ - y_type, y_true, y_pred = _check_targets(y_true, y_pred) + xp, _ = get_namespace(y_true, y_pred) + unique_y_true = xp.unique_values(y_true) if unique_y_true is None else unique_y_true + unique_y_pred = xp.unique_values(y_pred) if unique_y_pred is None else unique_y_pred + + y_type, y_true, y_pred = _check_targets( + y_true, y_pred, unique_y_true, unique_y_pred + ) if sample_weight is not None: sample_weight = column_or_1d(sample_weight) check_consistent_length(y_true, y_pred, sample_weight) @@ -510,7 +541,7 @@ def multilabel_confusion_matrix( if y_type not in ("binary", "multiclass", "multilabel-indicator"): raise ValueError("%s is not supported" % y_type) - present_labels = unique_labels(y_true, y_pred) + present_labels = unique_labels(unique_y_true, unique_y_pred) if labels is None: labels = present_labels n_labels = None @@ -1487,7 +1518,9 @@ def _warn_prf(average, modifier, msg_start, result_size): warnings.warn(msg, UndefinedMetricWarning, stacklevel=2) -def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label): +def _check_set_wise_labels( + y_true, y_pred, average, labels, pos_label, unique_y_true, unique_y_pred +): """Validation associated with set-wise metrics. Returns identified labels. @@ -1496,10 +1529,12 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label): if average not in average_options and average != "binary": raise ValueError("average has to be one of " + str(average_options)) - y_type, y_true, y_pred = _check_targets(y_true, y_pred) + y_type, y_true, y_pred = _check_targets( + y_true, y_pred, unique_y_true, unique_y_pred + ) # Convert to Python primitive type to avoid NumPy type / Python str # comparison. See https://github.com/numpy/numpy/issues/6784 - present_labels = unique_labels(y_true, y_pred).tolist() + present_labels = unique_labels(unique_y_true, unique_y_pred).tolist() if average == "binary": if y_type == "binary": if pos_label not in present_labels: @@ -1559,6 +1594,8 @@ def precision_recall_fscore_support( warn_for=("precision", "recall", "f-score"), sample_weight=None, zero_division="warn", + unique_y_true=None, + unique_y_pred=None, ): """Compute precision, recall, F-measure and support for each class. @@ -1656,6 +1693,14 @@ def precision_recall_fscore_support( .. versionadded:: 1.3 `np.nan` option was added. + unique_y_true : array-like, default=None + The unique values in ``y_true``. If ``None``, the unique values are + determined from the input. + + unique_y_pred : array-like, default=None + The unique values in ``y_pred``. If ``None``, the unique values are + determined from the input. 
+
     Returns
     -------
     precision : float (if average is not None) or array of float, shape =\
@@ -1718,7 +1763,13 @@
        array([2, 2, 2]))
     """
     zero_division_value = _check_zero_division(zero_division)
-    labels = _check_set_wise_labels(y_true, y_pred, average, labels, pos_label)
+    xp, _ = get_namespace(y_true, y_pred)
+    unique_y_true = xp.unique_values(y_true) if unique_y_true is None else unique_y_true
+    unique_y_pred = xp.unique_values(y_pred) if unique_y_pred is None else unique_y_pred
+
+    labels = _check_set_wise_labels(
+        y_true, y_pred, average, labels, pos_label, unique_y_true, unique_y_pred
+    )
 
     # Calculate tp_sum, pred_sum, true_sum ###
     samplewise = average == "samples"
@@ -2535,11 +2586,17 @@ class 2 1.00 0.67 0.80 3
 weighted avg 1.00 0.67 0.80 3
     """
+    xp, _ = get_namespace(y_true, y_pred)
+    unique_y_true = xp.unique_values(y_true)
+    unique_y_pred = xp.unique_values(y_pred)
+    unique_all = unique_labels(unique_y_true, unique_y_pred)
 
-    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    y_type, y_true, y_pred = _check_targets(
+        y_true, y_pred, unique_y_true=unique_y_true, unique_y_pred=unique_y_pred
+    )
 
     if labels is None:
-        labels = unique_labels(y_true, y_pred)
+        labels = unique_all
         labels_given = False
     else:
         labels = np.asarray(labels)
@@ -2547,7 +2604,7 @@ class 2 1.00 0.67 0.80 3
 
     # labelled micro average
     micro_is_accuracy = (y_type == "multiclass" or y_type == "binary") and (
-        not labels_given or (set(labels) == set(unique_labels(y_true, y_pred)))
+        not labels_given or (set(labels) == set(unique_all))
    )
 
     if target_names is not None and len(labels) != len(target_names):
@@ -2575,6 +2632,8 @@ class 2 1.00 0.67 0.80 3
             average=None,
             sample_weight=sample_weight,
             zero_division=zero_division,
+            unique_y_true=unique_y_true,
+            unique_y_pred=unique_y_pred,
         )
         rows = zip(target_names, p, r, f1, s)
@@ -2614,6 +2673,8 @@ class 2 1.00 0.67 0.80 3
                 average=average,
                 sample_weight=sample_weight,
                 zero_division=zero_division,
+                unique_y_true=unique_y_true,
+                unique_y_pred=unique_y_pred,
             )
             avg = [avg_p, avg_r, avg_f1, np.sum(s)]
diff --git a/sklearn/utils/multiclass.py b/sklearn/utils/multiclass.py
index 83492c852f745..c4d207feec4ac 100644
--- a/sklearn/utils/multiclass.py
+++ b/sklearn/utils/multiclass.py
@@ -219,7 +219,7 @@ def check_classification_targets(y):
     )
 
 
-def type_of_target(y, input_name=""):
+def type_of_target(y, input_name="", n_classes=None):
     """Determine the type of data indicated by the target.
 
     Note that this type is the most specific type that can be inferred.
@@ -242,6 +242,11 @@ def type_of_target(y, input_name=""):
 
         .. versionadded:: 1.1.0
 
+    n_classes : int, default=None
+        Number of classes. Will be inferred from the data if not provided.
+
+        .. versionadded:: 1.4
+
     Returns
     -------
     target_type : str
@@ -383,8 +388,9 @@ def type_of_target(y, input_name=""):
         return "continuous" + suffix
 
     # Check multiclass
+    n_classes = xp.unique_values(y).shape[0] if n_classes is None else n_classes
     first_row = y[0] if not issparse(y) else y.getrow(0).data
-    if xp.unique_values(y).shape[0] > 2 or (y.ndim == 2 and len(first_row) > 1):
+    if n_classes > 2 or (y.ndim == 2 and len(first_row) > 1):
         # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
         return "multiclass" + suffix
     else:
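
Below is a minimal usage sketch, not part of the patch, of the calling pattern this change enables: compute the unique labels once and forward them, so repeated metric calls on the same targets skip their own unique() pass. The unique_y_true/unique_y_pred keyword arguments exist only with this patch applied, and the data shape (1,000,000 samples, 100 classes) is purely illustrative.

# Sketch (assumes this patch is applied): one unique() pass per array,
# shared across several metric computations on the same predictions.
import numpy as np

from sklearn.metrics import precision_recall_fscore_support

rng = np.random.RandomState(0)
y_true = rng.randint(0, 100, size=1_000_000)  # illustrative data
y_pred = rng.randint(0, 100, size=1_000_000)

# Computed once here instead of once inside every metric call below.
unique_y_true = np.unique(y_true)
unique_y_pred = np.unique(y_pred)

for average in ("micro", "macro", "weighted"):
    p, r, f, s = precision_recall_fscore_support(
        y_true,
        y_pred,
        average=average,
        unique_y_true=unique_y_true,
        unique_y_pred=unique_y_pred,
    )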
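
Similarly, a small sketch of the new type_of_target fast path: a caller that already knows the class count can pass n_classes so the unique-values scan inside the multiclass check is skipped. Again this assumes the patched signature above; the toy target is illustrative.

import numpy as np

from sklearn.utils.multiclass import type_of_target

y = np.array([0, 1, 2, 2, 1, 0])

# Without the hint, type_of_target computes xp.unique_values(y) itself.
assert type_of_target(y) == "multiclass"

# With the hint, the unique scan in the multiclass check is skipped.
assert type_of_target(y, n_classes=3) == "multiclass"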