diff --git a/doc/modules/clustering.rst b/doc/modules/clustering.rst index 3925c0cdedc6f..53e09829c1d41 100644 --- a/doc/modules/clustering.rst +++ b/doc/modules/clustering.rst @@ -1305,7 +1305,7 @@ ignoring permutations:: >>> labels_true = [0, 0, 0, 1, 1, 1] >>> labels_pred = [0, 0, 1, 1, 2, 2] >>> metrics.rand_score(labels_true, labels_pred) - np.float64(0.66...) + 0.66... The Rand index does not ensure to obtain a value close to 0.0 for a random labelling. The adjusted Rand index **corrects for chance** and @@ -1319,7 +1319,7 @@ labels, rename 2 to 3, and get the same score:: >>> labels_pred = [1, 1, 0, 0, 3, 3] >>> metrics.rand_score(labels_true, labels_pred) - np.float64(0.66...) + 0.66... >>> metrics.adjusted_rand_score(labels_true, labels_pred) 0.24... @@ -1328,7 +1328,7 @@ Furthermore, both :func:`rand_score` :func:`adjusted_rand_score` are thus be used as **consensus measures**:: >>> metrics.rand_score(labels_pred, labels_true) - np.float64(0.66...) + 0.66... >>> metrics.adjusted_rand_score(labels_pred, labels_true) 0.24... @@ -1348,7 +1348,7 @@ will not necessarily be close to zero.:: >>> labels_true = [0, 0, 0, 0, 0, 0, 1, 1] >>> labels_pred = [0, 1, 2, 3, 4, 5, 5, 6] >>> metrics.rand_score(labels_true, labels_pred) - np.float64(0.39...) + 0.39... >>> metrics.adjusted_rand_score(labels_true, labels_pred) -0.07... @@ -1644,16 +1644,16 @@ We can turn those concept as scores :func:`homogeneity_score` and >>> labels_pred = [0, 0, 1, 1, 2, 2] >>> metrics.homogeneity_score(labels_true, labels_pred) - np.float64(0.66...) + 0.66... >>> metrics.completeness_score(labels_true, labels_pred) - np.float64(0.42...) + 0.42... Their harmonic mean called **V-measure** is computed by :func:`v_measure_score`:: >>> metrics.v_measure_score(labels_true, labels_pred) - np.float64(0.51...) + 0.51... This function's formula is as follows: @@ -1662,12 +1662,12 @@ This function's formula is as follows: `beta` defaults to a value of 1.0, but for using a value less than 1 for beta:: >>> metrics.v_measure_score(labels_true, labels_pred, beta=0.6) - np.float64(0.54...) + 0.54... more weight will be attributed to homogeneity, and using a value greater than 1:: >>> metrics.v_measure_score(labels_true, labels_pred, beta=1.8) - np.float64(0.48...) + 0.48... more weight will be attributed to completeness. @@ -1678,14 +1678,14 @@ Homogeneity, completeness and V-measure can be computed at once using :func:`homogeneity_completeness_v_measure` as follows:: >>> metrics.homogeneity_completeness_v_measure(labels_true, labels_pred) - (np.float64(0.66...), np.float64(0.42...), np.float64(0.51...)) + (0.66..., 0.42..., 0.51...) The following clustering assignment is slightly better, since it is homogeneous but not complete:: >>> labels_pred = [0, 0, 0, 1, 2, 2] >>> metrics.homogeneity_completeness_v_measure(labels_true, labels_pred) - (np.float64(1.0), np.float64(0.68...), np.float64(0.81...)) + (1.0, 0.68..., 0.81...) .. note:: @@ -1815,7 +1815,7 @@ between two clusters. >>> labels_pred = [0, 0, 1, 1, 2, 2] >>> metrics.fowlkes_mallows_score(labels_true, labels_pred) - np.float64(0.47140...) + 0.47140... One can permute 0 and 1 in the predicted labels, rename 2 to 3 and get the same score:: @@ -1823,13 +1823,13 @@ the same score:: >>> labels_pred = [1, 1, 0, 0, 3, 3] >>> metrics.fowlkes_mallows_score(labels_true, labels_pred) - np.float64(0.47140...) + 0.47140... Perfect labeling is scored 1.0:: >>> labels_pred = labels_true[:] >>> metrics.fowlkes_mallows_score(labels_true, labels_pred) - np.float64(1.0) + 1.0 Bad (e.g. independent labelings) have zero scores:: @@ -1912,7 +1912,7 @@ cluster analysis. >>> kmeans_model = KMeans(n_clusters=3, random_state=1).fit(X) >>> labels = kmeans_model.labels_ >>> metrics.silhouette_score(X, labels, metric='euclidean') - np.float64(0.55...) + 0.55... .. topic:: Advantages: @@ -1969,7 +1969,7 @@ cluster analysis: >>> kmeans_model = KMeans(n_clusters=3, random_state=1).fit(X) >>> labels = kmeans_model.labels_ >>> metrics.calinski_harabasz_score(X, labels) - np.float64(561.59...) + 561.59... .. topic:: Advantages: @@ -2043,7 +2043,7 @@ cluster analysis as follows: >>> kmeans = KMeans(n_clusters=3, random_state=1).fit(X) >>> labels = kmeans.labels_ >>> davies_bouldin_score(X, labels) - np.float64(0.666...) + 0.666... .. topic:: Advantages: diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index f4ff3199d16e3..fd9b58c4ff33b 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -377,7 +377,7 @@ You can create your own custom scorer object using >>> import numpy as np >>> def my_custom_loss_func(y_true, y_pred): ... diff = np.abs(y_true - y_pred).max() - ... return np.log1p(diff) + ... return float(np.log1p(diff)) ... >>> # score will negate the return value of my_custom_loss_func, >>> # which will be np.log(2), 0.693, given the values for X @@ -389,9 +389,9 @@ You can create your own custom scorer object using >>> clf = DummyClassifier(strategy='most_frequent', random_state=0) >>> clf = clf.fit(X, y) >>> my_custom_loss_func(y, clf.predict(X)) - np.float64(0.69...) + 0.69... >>> score(clf, X, y) - np.float64(-0.69...) + -0.69... .. dropdown:: Custom scorer objects from scratch @@ -673,10 +673,10 @@ where :math:`k` is the number of guesses allowed and :math:`1(x)` is the ... [0.2, 0.4, 0.3], ... [0.7, 0.2, 0.1]]) >>> top_k_accuracy_score(y_true, y_score, k=2) - np.float64(0.75) + 0.75 >>> # Not normalizing gives the number of "correctly" classified samples >>> top_k_accuracy_score(y_true, y_score, k=2, normalize=False) - np.int64(3) + 3.0 .. _balanced_accuracy_score: @@ -786,7 +786,7 @@ and not for more than two annotators. >>> labeling1 = [2, 0, 2, 2, 0, 1] >>> labeling2 = [0, 0, 2, 2, 0, 2] >>> cohen_kappa_score(labeling1, labeling2) - np.float64(0.4285714285714286) + 0.4285714285714286 .. _confusion_matrix: @@ -837,9 +837,9 @@ false negatives and true positives as follows:: >>> y_true = [0, 0, 0, 1, 1, 1, 1, 1] >>> y_pred = [0, 1, 0, 1, 0, 1, 0, 1] - >>> tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel() + >>> tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel().tolist() >>> tn, fp, fn, tp - (np.int64(2), np.int64(1), np.int64(2), np.int64(3)) + (2, 1, 2, 3) .. rubric:: Examples @@ -1115,7 +1115,7 @@ Here are some small examples in binary classification:: >>> threshold array([0.1 , 0.35, 0.4 , 0.8 ]) >>> average_precision_score(y_true, y_scores) - np.float64(0.83...) + 0.83... @@ -1234,19 +1234,19 @@ In the binary case:: >>> y_pred = np.array([[1, 1, 1], ... [1, 0, 0]]) >>> jaccard_score(y_true[0], y_pred[0]) - np.float64(0.6666...) + 0.6666... In the 2D comparison case (e.g. image similarity): >>> jaccard_score(y_true, y_pred, average="micro") - np.float64(0.6) + 0.6 In the multilabel case with binary label indicators:: >>> jaccard_score(y_true, y_pred, average='samples') - np.float64(0.5833...) + 0.5833... >>> jaccard_score(y_true, y_pred, average='macro') - np.float64(0.6666...) + 0.6666... >>> jaccard_score(y_true, y_pred, average=None) array([0.5, 0.5, 1. ]) @@ -1258,9 +1258,9 @@ multilabel problem:: >>> jaccard_score(y_true, y_pred, average=None) array([1. , 0. , 0.33...]) >>> jaccard_score(y_true, y_pred, average='macro') - np.float64(0.44...) + 0.44... >>> jaccard_score(y_true, y_pred, average='micro') - np.float64(0.33...) + 0.33... .. _hinge_loss: @@ -1315,7 +1315,7 @@ with a svm classifier in a binary class problem:: >>> pred_decision array([-2.18..., 2.36..., 0.09...]) >>> hinge_loss([-1, 1, 1], pred_decision) - np.float64(0.3...) + 0.3... Here is an example demonstrating the use of the :func:`hinge_loss` function with a svm classifier in a multiclass problem:: @@ -1329,7 +1329,7 @@ with a svm classifier in a multiclass problem:: >>> pred_decision = est.decision_function([[-1], [2], [3]]) >>> y_true = [0, 2, 3] >>> hinge_loss(y_true, pred_decision, labels=labels) - np.float64(0.56...) + 0.56... .. _log_loss: @@ -1445,7 +1445,7 @@ function: >>> y_true = [+1, +1, +1, -1] >>> y_pred = [+1, -1, +1, +1] >>> matthews_corrcoef(y_true, y_pred) - np.float64(-0.33...) + -0.33... .. rubric:: References @@ -1640,12 +1640,12 @@ We can use the probability estimates corresponding to `clf.classes_[1]`. >>> y_score = clf.predict_proba(X)[:, 1] >>> roc_auc_score(y, y_score) - np.float64(0.99...) + 0.99... Otherwise, we can use the non-thresholded decision values >>> roc_auc_score(y, clf.decision_function(X)) - np.float64(0.99...) + 0.99... .. _roc_auc_multiclass: @@ -1951,13 +1951,13 @@ Here is a small example of usage of this function:: >>> y_prob = np.array([0.1, 0.9, 0.8, 0.4]) >>> y_pred = np.array([0, 1, 1, 0]) >>> brier_score_loss(y_true, y_prob) - np.float64(0.055) + 0.055 >>> brier_score_loss(y_true, 1 - y_prob, pos_label=0) - np.float64(0.055) + 0.055 >>> brier_score_loss(y_true_categorical, y_prob, pos_label="ham") - np.float64(0.055) + 0.055 >>> brier_score_loss(y_true, y_prob > 0.5) - np.float64(0.0) + 0.0 The Brier score can be used to assess how well a classifier is calibrated. However, a lower Brier score loss does not always mean a better calibration. @@ -2236,7 +2236,7 @@ Here is a small example of usage of this function:: >>> y_true = np.array([[1, 0, 0], [0, 0, 1]]) >>> y_score = np.array([[0.75, 0.5, 1], [1, 0.2, 0.1]]) >>> coverage_error(y_true, y_score) - np.float64(2.5) + 2.5 .. _label_ranking_average_precision: @@ -2283,7 +2283,7 @@ Here is a small example of usage of this function:: >>> y_true = np.array([[1, 0, 0], [0, 0, 1]]) >>> y_score = np.array([[0.75, 0.5, 1], [1, 0.2, 0.1]]) >>> label_ranking_average_precision_score(y_true, y_score) - np.float64(0.416...) + 0.416... .. _label_ranking_loss: @@ -2318,11 +2318,11 @@ Here is a small example of usage of this function:: >>> y_true = np.array([[1, 0, 0], [0, 0, 1]]) >>> y_score = np.array([[0.75, 0.5, 1], [1, 0.2, 0.1]]) >>> label_ranking_loss(y_true, y_score) - np.float64(0.75...) + 0.75... >>> # With the following prediction, we have perfect and minimal loss >>> y_score = np.array([[1.0, 0.1, 0.2], [0.1, 0.2, 0.9]]) >>> label_ranking_loss(y_true, y_score) - np.float64(0.0) + 0.0 .. dropdown:: References @@ -2700,7 +2700,7 @@ function:: >>> y_true = [3, -0.5, 2, 7] >>> y_pred = [2.5, 0.0, 2, 8] >>> median_absolute_error(y_true, y_pred) - np.float64(0.5) + 0.5 @@ -2732,7 +2732,7 @@ Here is a small example of usage of the :func:`max_error` function:: >>> y_true = [3, 2, 7, 1] >>> y_pred = [9, 2, 7, 1] >>> max_error(y_true, y_pred) - np.int64(6) + 6.0 The :func:`max_error` does not support multioutput. @@ -3011,15 +3011,15 @@ of 0.0. >>> y_true = [3, -0.5, 2, 7] >>> y_pred = [2.5, 0.0, 2, 8] >>> d2_absolute_error_score(y_true, y_pred) - np.float64(0.764...) + 0.764... >>> y_true = [1, 2, 3] >>> y_pred = [1, 2, 3] >>> d2_absolute_error_score(y_true, y_pred) - np.float64(1.0) + 1.0 >>> y_true = [1, 2, 3] >>> y_pred = [2, 2, 2] >>> d2_absolute_error_score(y_true, y_pred) - np.float64(0.0) + 0.0 .. _visualization_regression_evaluation: diff --git a/sklearn/metrics/_base.py b/sklearn/metrics/_base.py index ee797e1bc4030..aa4150c88a978 100644 --- a/sklearn/metrics/_base.py +++ b/sklearn/metrics/_base.py @@ -118,7 +118,7 @@ def _average_binary_score(binary_metric, y_true, y_score, average, sample_weight # score from being affected by 0-weighted NaN elements. average_weight = np.asarray(average_weight) score[average_weight == 0] = 0 - return np.average(score, weights=average_weight) + return float(np.average(score, weights=average_weight)) else: return score diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index a010f602d274c..2a08a1893766e 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -333,9 +333,9 @@ def confusion_matrix( In the binary case, we can extract true positives, etc. as follows: - >>> tn, fp, fn, tp = confusion_matrix([0, 1, 0, 1], [1, 1, 1, 0]).ravel() + >>> tn, fp, fn, tp = confusion_matrix([0, 1, 0, 1], [1, 1, 1, 0]).ravel().tolist() >>> (tn, fp, fn, tp) - (np.int64(0), np.int64(2), np.int64(1), np.int64(1)) + (0, 2, 1, 1) """ y_true, y_pred = attach_unique(y_true, y_pred) y_type, y_true, y_pred = _check_targets(y_true, y_pred) @@ -737,7 +737,7 @@ class labels [2]_. >>> y1 = ["negative", "positive", "negative", "neutral", "positive"] >>> y2 = ["negative", "positive", "negative", "neutral", "negative"] >>> cohen_kappa_score(y1, y2) - np.float64(0.6875) + 0.6875 """ confusion = confusion_matrix(y1, y2, labels=labels, sample_weight=sample_weight) n_classes = confusion.shape[0] @@ -757,7 +757,7 @@ class labels [2]_. w_mat = (w_mat - w_mat.T) ** 2 k = np.sum(w_mat * confusion) / np.sum(w_mat * expected) - return 1 - k + return float(1 - k) @validate_params( @@ -898,19 +898,19 @@ def jaccard_score( In the binary case: >>> jaccard_score(y_true[0], y_pred[0]) - np.float64(0.6666...) + 0.6666... In the 2D comparison case (e.g. image similarity): >>> jaccard_score(y_true, y_pred, average="micro") - np.float64(0.6) + 0.6 In the multilabel case: >>> jaccard_score(y_true, y_pred, average='samples') - np.float64(0.5833...) + 0.5833... >>> jaccard_score(y_true, y_pred, average='macro') - np.float64(0.6666...) + 0.6666... >>> jaccard_score(y_true, y_pred, average=None) array([0.5, 0.5, 1. ]) @@ -957,7 +957,7 @@ def jaccard_score( weights = sample_weight else: weights = None - return np.average(jaccard, weights=weights) + return float(np.average(jaccard, weights=weights)) @validate_params( @@ -1029,7 +1029,7 @@ def matthews_corrcoef(y_true, y_pred, *, sample_weight=None): >>> y_true = [+1, +1, +1, -1] >>> y_pred = [+1, -1, +1, +1] >>> matthews_corrcoef(y_true, y_pred) - np.float64(-0.33...) + -0.33... """ y_true, y_pred = attach_unique(y_true, y_pred) y_type, y_true, y_pred = _check_targets(y_true, y_pred) @@ -1054,7 +1054,7 @@ def matthews_corrcoef(y_true, y_pred, *, sample_weight=None): if cov_ypyp * cov_ytyt == 0: return 0.0 else: - return cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp) + return float(cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)) @validate_params( @@ -2041,15 +2041,15 @@ class are present in `y_true`): both likelihood ratios are undefined. >>> from sklearn.metrics import class_likelihood_ratios >>> class_likelihood_ratios([0, 1, 0, 1, 0], [1, 1, 0, 0, 0], ... replace_undefined_by=1.0) - (np.float64(1.5), np.float64(0.75)) + (1.5, 0.75) >>> y_true = np.array(["non-cat", "cat", "non-cat", "cat", "non-cat"]) >>> y_pred = np.array(["cat", "cat", "non-cat", "non-cat", "non-cat"]) >>> class_likelihood_ratios(y_true, y_pred, replace_undefined_by=1.0) - (np.float64(1.33...), np.float64(0.66...)) + (1.33..., 0.66...) >>> y_true = np.array(["non-zebra", "zebra", "non-zebra", "zebra", "non-zebra"]) >>> y_pred = np.array(["zebra", "zebra", "non-zebra", "non-zebra", "non-zebra"]) >>> class_likelihood_ratios(y_true, y_pred, replace_undefined_by=1.0) - (np.float64(1.5), np.float64(0.75)) + (1.5, 0.75) To avoid ambiguities, use the notation `labels=[negative_class, positive_class]` @@ -2058,7 +2058,7 @@ class are present in `y_true`): both likelihood ratios are undefined. >>> y_pred = np.array(["cat", "cat", "non-cat", "non-cat", "non-cat"]) >>> class_likelihood_ratios(y_true, y_pred, labels=["non-cat", "cat"], ... replace_undefined_by=1.0) - (np.float64(1.5), np.float64(0.75)) + (1.5, 0.75) """ # TODO(1.9): When `raise_warning` is removed, the following changes need to be made: # The checks for `raise_warning==True` need to be removed and we will always warn, @@ -2210,7 +2210,7 @@ class are present in `y_true`): both likelihood ratios are undefined. else: negative_likelihood_ratio = neg_num / neg_denom - return positive_likelihood_ratio, negative_likelihood_ratio + return float(positive_likelihood_ratio), float(negative_likelihood_ratio) @validate_params( @@ -2652,7 +2652,7 @@ def balanced_accuracy_score(y_true, y_pred, *, sample_weight=None, adjusted=Fals >>> y_true = [0, 1, 0, 0, 1, 0] >>> y_pred = [0, 1, 0, 0, 0, 1] >>> balanced_accuracy_score(y_true, y_pred) - np.float64(0.625) + 0.625 """ C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight) with np.errstate(divide="ignore", invalid="ignore"): @@ -2666,7 +2666,7 @@ def balanced_accuracy_score(y_true, y_pred, *, sample_weight=None, adjusted=Fals chance = 1 / n_classes score -= chance score /= 1 - chance - return score + return float(score) @validate_params( @@ -3004,7 +3004,9 @@ def hamming_loss(y_true, y_pred, *, sample_weight=None): if y_type.startswith("multilabel"): n_differences = count_nonzero(y_true - y_pred, sample_weight=sample_weight) - return n_differences / (y_true.shape[0] * y_true.shape[1] * weight_average) + return float( + n_differences / (y_true.shape[0] * y_true.shape[1] * weight_average) + ) elif y_type in ["binary", "multiclass"]: return float(_average(y_true != y_pred, weights=sample_weight, normalize=True)) @@ -3241,7 +3243,7 @@ def hinge_loss(y_true, pred_decision, *, labels=None, sample_weight=None): >>> pred_decision array([-2.18..., 2.36..., 0.09...]) >>> hinge_loss([-1, 1, 1], pred_decision) - np.float64(0.30...) + 0.30... In the multiclass case: @@ -3255,7 +3257,7 @@ def hinge_loss(y_true, pred_decision, *, labels=None, sample_weight=None): >>> pred_decision = est.decision_function([[-1], [2], [3]]) >>> y_true = [0, 2, 3] >>> hinge_loss(y_true, pred_decision, labels=labels) - np.float64(0.56...) + 0.56... """ check_consistent_length(y_true, pred_decision, sample_weight) pred_decision = check_array(pred_decision, ensure_2d=False) @@ -3317,7 +3319,7 @@ def hinge_loss(y_true, pred_decision, *, labels=None, sample_weight=None): losses = 1 - margin # The hinge_loss doesn't penalize good enough predictions. np.clip(losses, 0, None, out=losses) - return np.average(losses, weights=sample_weight) + return float(np.average(losses, weights=sample_weight)) @validate_params( @@ -3401,13 +3403,13 @@ def brier_score_loss( >>> y_true_categorical = np.array(["spam", "ham", "ham", "spam"]) >>> y_prob = np.array([0.1, 0.9, 0.8, 0.3]) >>> brier_score_loss(y_true, y_prob) - np.float64(0.037...) + 0.037... >>> brier_score_loss(y_true, 1-y_prob, pos_label=0) - np.float64(0.037...) + 0.037... >>> brier_score_loss(y_true_categorical, y_prob, pos_label="ham") - np.float64(0.037...) + 0.037... >>> brier_score_loss(y_true, np.array(y_prob) > 0.5) - np.float64(0.0) + 0.0 """ # TODO(1.7): remove in 1.7 and reset y_proba to be required # Note: validate params will raise an error if y_prob is not array-like, @@ -3456,7 +3458,7 @@ def brier_score_loss( else: raise y_true = np.array(y_true == pos_label, int) - return np.average((y_true - y_proba) ** 2, weights=sample_weight) + return float(np.average((y_true - y_proba) ** 2, weights=sample_weight)) @validate_params( @@ -3549,4 +3551,4 @@ def d2_log_loss_score(y_true, y_pred, *, sample_weight=None, labels=None): labels=labels, ) - return 1 - (numerator / denominator) + return float(1 - (numerator / denominator)) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 0303eece69573..f12052867a781 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -77,7 +77,7 @@ def auc(x, y): >>> pred = np.array([0.1, 0.4, 0.35, 0.8]) >>> fpr, tpr, thresholds = metrics.roc_curve(y, pred, pos_label=2) >>> metrics.auc(fpr, tpr) - np.float64(0.75) + 0.75 """ check_consistent_length(x, y) x = column_or_1d(x) @@ -103,7 +103,7 @@ def auc(x, y): # scalar by default for numpy.memmap instances contrary to # regular numpy.ndarray instances. area = area.dtype.type(area) - return area + return float(area) @validate_params( @@ -204,7 +204,7 @@ def average_precision_score( >>> y_true = np.array([0, 0, 1, 1]) >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8]) >>> average_precision_score(y_true, y_scores) - np.float64(0.83...) + 0.83... >>> y_true = np.array([0, 0, 1, 1, 2, 2]) >>> y_scores = np.array([ ... [0.7, 0.2, 0.1], @@ -215,7 +215,7 @@ def average_precision_score( ... [0.1, 0.2, 0.7], ... ]) >>> average_precision_score(y_true, y_scores) - np.float64(0.77...) + 0.77... """ def _binary_uninterpolated_average_precision( @@ -228,7 +228,7 @@ def _binary_uninterpolated_average_precision( # The following works because the last entry of precision is # guaranteed to be 1, as returned by precision_recall_curve. # Due to numerical error, we can get `-0.0` and we therefore clip it. - return max(0.0, -np.sum(np.diff(recall) * np.array(precision)[:-1])) + return float(max(0.0, -np.sum(np.diff(recall) * np.array(precision)[:-1]))) y_type = type_of_target(y_true, input_name="y_true") @@ -583,9 +583,9 @@ class scores must correspond to the order of ``labels``, >>> X, y = load_breast_cancer(return_X_y=True) >>> clf = LogisticRegression(solver="liblinear", random_state=0).fit(X, y) >>> roc_auc_score(y, clf.predict_proba(X)[:, 1]) - np.float64(0.99...) + 0.99... >>> roc_auc_score(y, clf.decision_function(X)) - np.float64(0.99...) + 0.99... Multiclass case: @@ -593,7 +593,7 @@ class scores must correspond to the order of ``labels``, >>> X, y = load_iris(return_X_y=True) >>> clf = LogisticRegression(solver="liblinear").fit(X, y) >>> roc_auc_score(y, clf.predict_proba(X), multi_class='ovr') - np.float64(0.99...) + 0.99... Multilabel case: @@ -1248,7 +1248,7 @@ def label_ranking_average_precision_score(y_true, y_score, *, sample_weight=None >>> y_true = np.array([[1, 0, 0], [0, 0, 1]]) >>> y_score = np.array([[0.75, 0.5, 1], [1, 0.2, 0.1]]) >>> label_ranking_average_precision_score(y_true, y_score) - np.float64(0.416...) + 0.416... """ check_consistent_length(y_true, y_score, sample_weight) y_true = check_array(y_true, ensure_2d=False, accept_sparse="csr") @@ -1294,7 +1294,7 @@ def label_ranking_average_precision_score(y_true, y_score, *, sample_weight=None else: out /= np.sum(sample_weight) - return out + return float(out) @validate_params( @@ -1353,7 +1353,7 @@ def coverage_error(y_true, y_score, *, sample_weight=None): >>> y_true = [[1, 0, 0], [0, 1, 1]] >>> y_score = [[1, 0, 0], [0, 1, 1]] >>> coverage_error(y_true, y_score) - np.float64(1.5) + 1.5 """ y_true = check_array(y_true, ensure_2d=True) y_score = check_array(y_score, ensure_2d=True) @@ -1371,7 +1371,7 @@ def coverage_error(y_true, y_score, *, sample_weight=None): coverage = (y_score >= y_min_relevant).sum(axis=1) coverage = coverage.filled(0) - return np.average(coverage, weights=sample_weight) + return float(np.average(coverage, weights=sample_weight)) @validate_params( @@ -1432,7 +1432,7 @@ def label_ranking_loss(y_true, y_score, *, sample_weight=None): >>> y_true = [[1, 0, 0], [0, 0, 1]] >>> y_score = [[0.75, 0.5, 1], [1, 0.2, 0.1]] >>> label_ranking_loss(y_true, y_score) - np.float64(0.75...) + 0.75... """ y_true = check_array(y_true, ensure_2d=False, accept_sparse="csr") y_score = check_array(y_score, ensure_2d=False) @@ -1473,7 +1473,7 @@ def label_ranking_loss(y_true, y_score, *, sample_weight=None): # be consider as correct, i.e. the ranking doesn't matter. loss[np.logical_or(n_positives == 0, n_positives == n_labels)] = 0.0 - return np.average(loss, weights=sample_weight) + return float(np.average(loss, weights=sample_weight)) def _dcg_sample_scores(y_true, y_score, k=None, log_base=2, ignore_ties=False): @@ -1688,32 +1688,34 @@ def dcg_score( >>> # we predict scores for the answers >>> scores = np.asarray([[.1, .2, .3, 4, 70]]) >>> dcg_score(true_relevance, scores) - np.float64(9.49...) + 9.49... >>> # we can set k to truncate the sum; only top k answers contribute >>> dcg_score(true_relevance, scores, k=2) - np.float64(5.63...) + 5.63... >>> # now we have some ties in our prediction >>> scores = np.asarray([[1, 0, 0, 0, 1]]) >>> # by default ties are averaged, so here we get the average true >>> # relevance of our top predictions: (10 + 5) / 2 = 7.5 >>> dcg_score(true_relevance, scores, k=1) - np.float64(7.5) + 7.5 >>> # we can choose to ignore ties for faster results, but only >>> # if we know there aren't ties in our scores, otherwise we get >>> # wrong results: >>> dcg_score(true_relevance, ... scores, k=1, ignore_ties=True) - np.float64(5.0) + 5.0 """ y_true = check_array(y_true, ensure_2d=False) y_score = check_array(y_score, ensure_2d=False) check_consistent_length(y_true, y_score, sample_weight) _check_dcg_target_type(y_true) - return np.average( - _dcg_sample_scores( - y_true, y_score, k=k, log_base=log_base, ignore_ties=ignore_ties - ), - weights=sample_weight, + return float( + np.average( + _dcg_sample_scores( + y_true, y_score, k=k, log_base=log_base, ignore_ties=ignore_ties + ), + weights=sample_weight, + ) ) @@ -1848,29 +1850,29 @@ def ndcg_score(y_true, y_score, *, k=None, sample_weight=None, ignore_ties=False >>> # we predict some scores (relevance) for the answers >>> scores = np.asarray([[.1, .2, .3, 4, 70]]) >>> ndcg_score(true_relevance, scores) - np.float64(0.69...) + 0.69... >>> scores = np.asarray([[.05, 1.1, 1., .5, .0]]) >>> ndcg_score(true_relevance, scores) - np.float64(0.49...) + 0.49... >>> # we can set k to truncate the sum; only top k answers contribute. >>> ndcg_score(true_relevance, scores, k=4) - np.float64(0.35...) + 0.35... >>> # the normalization takes k into account so a perfect answer >>> # would still get 1.0 >>> ndcg_score(true_relevance, true_relevance, k=4) - np.float64(1.0...) + 1.0... >>> # now we have some ties in our prediction >>> scores = np.asarray([[1, 0, 0, 0, 1]]) >>> # by default ties are averaged, so here we get the average (normalized) >>> # true relevance of our top predictions: (10 / 10 + 5 / 10) / 2 = .75 >>> ndcg_score(true_relevance, scores, k=1) - np.float64(0.75...) + 0.75... >>> # we can choose to ignore ties for faster results, but only >>> # if we know there aren't ties in our scores, otherwise we get >>> # wrong results: >>> ndcg_score(true_relevance, ... scores, k=1, ignore_ties=True) - np.float64(0.5...) + 0.5... """ y_true = check_array(y_true, ensure_2d=False) y_score = check_array(y_score, ensure_2d=False) @@ -1885,7 +1887,7 @@ def ndcg_score(y_true, y_score, *, k=None, sample_weight=None, ignore_ties=False ) _check_dcg_target_type(y_true) gain = _ndcg_sample_scores(y_true, y_score, k=k, ignore_ties=ignore_ties) - return np.average(gain, weights=sample_weight) + return float(np.average(gain, weights=sample_weight)) @validate_params( @@ -1973,10 +1975,10 @@ def top_k_accuracy_score( ... [0.2, 0.4, 0.3], # 2 is in top 2 ... [0.7, 0.2, 0.1]]) # 2 isn't in top 2 >>> top_k_accuracy_score(y_true, y_score, k=2) - np.float64(0.75) + 0.75 >>> # Not normalizing gives the number of "correctly" classified samples >>> top_k_accuracy_score(y_true, y_score, k=2, normalize=False) - np.int64(3) + 3.0 """ y_true = check_array(y_true, ensure_2d=False, dtype=None) y_true = column_or_1d(y_true) @@ -2055,8 +2057,8 @@ def top_k_accuracy_score( hits = (y_true_encoded == sorted_pred[:, :k].T).any(axis=0) if normalize: - return np.average(hits, weights=sample_weight) + return float(np.average(hits, weights=sample_weight)) elif sample_weight is None: - return np.sum(hits) + return float(np.sum(hits)) else: - return np.dot(hits, sample_weight) + return float(np.dot(hits, sample_weight)) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index feab48e482c5b..65a3073f3691c 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -897,15 +897,15 @@ def median_absolute_error( >>> y_true = [3, -0.5, 2, 7] >>> y_pred = [2.5, 0.0, 2, 8] >>> median_absolute_error(y_true, y_pred) - np.float64(0.5) + 0.5 >>> y_true = [[0.5, 1], [-1, 1], [7, -6]] >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> median_absolute_error(y_true, y_pred) - np.float64(0.75) + 0.75 >>> median_absolute_error(y_true, y_pred, multioutput='raw_values') array([0.5, 1. ]) >>> median_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7]) - np.float64(0.85) + 0.85 """ y_type, y_true, y_pred, multioutput = _check_reg_targets( y_true, y_pred, multioutput @@ -924,7 +924,7 @@ def median_absolute_error( # pass None as weights to np.average: uniform mean multioutput = None - return np.average(output_errors, weights=multioutput) + return float(np.average(output_errors, weights=multioutput)) def _assemble_r2_explained_variance( @@ -1335,13 +1335,13 @@ def max_error(y_true, y_pred): >>> y_true = [3, 2, 7, 1] >>> y_pred = [4, 2, 7, 1] >>> max_error(y_true, y_pred) - np.int64(1) + 1.0 """ xp, _ = get_namespace(y_true, y_pred) y_type, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, None, xp=xp) if y_type == "continuous-multioutput": raise ValueError("Multioutput not supported in max_error") - return xp.max(xp.abs(y_true - y_pred)) + return float(xp.max(xp.abs(y_true - y_pred))) def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power): @@ -1758,13 +1758,13 @@ def d2_pinball_score( >>> y_true = [1, 2, 3] >>> y_pred = [1, 3, 3] >>> d2_pinball_score(y_true, y_pred) - np.float64(0.5) + 0.5 >>> d2_pinball_score(y_true, y_pred, alpha=0.9) - np.float64(0.772...) + 0.772... >>> d2_pinball_score(y_true, y_pred, alpha=0.1) - np.float64(-1.045...) + -1.045... >>> d2_pinball_score(y_true, y_true, alpha=0.1) - np.float64(1.0) + 1.0 """ y_type, y_true, y_pred, multioutput = _check_reg_targets( y_true, y_pred, multioutput @@ -1823,7 +1823,7 @@ def d2_pinball_score( else: avg_weights = multioutput - return np.average(output_scores, weights=avg_weights) + return float(np.average(output_scores, weights=avg_weights)) @validate_params( @@ -1901,25 +1901,25 @@ def d2_absolute_error_score( >>> y_true = [3, -0.5, 2, 7] >>> y_pred = [2.5, 0.0, 2, 8] >>> d2_absolute_error_score(y_true, y_pred) - np.float64(0.764...) + 0.764... >>> y_true = [[0.5, 1], [-1, 1], [7, -6]] >>> y_pred = [[0, 2], [-1, 2], [8, -5]] >>> d2_absolute_error_score(y_true, y_pred, multioutput='uniform_average') - np.float64(0.691...) + 0.691... >>> d2_absolute_error_score(y_true, y_pred, multioutput='raw_values') array([0.8125 , 0.57142857]) >>> y_true = [1, 2, 3] >>> y_pred = [1, 2, 3] >>> d2_absolute_error_score(y_true, y_pred) - np.float64(1.0) + 1.0 >>> y_true = [1, 2, 3] >>> y_pred = [2, 2, 2] >>> d2_absolute_error_score(y_true, y_pred) - np.float64(0.0) + 0.0 >>> y_true = [1, 2, 3] >>> y_pred = [3, 2, 1] >>> d2_absolute_error_score(y_true, y_pred) - np.float64(-1.0) + -1.0 """ return d2_pinball_score( y_true, y_pred, sample_weight=sample_weight, alpha=0.5, multioutput=multioutput diff --git a/sklearn/metrics/cluster/_bicluster.py b/sklearn/metrics/cluster/_bicluster.py index 49aa8a37be21b..bb306c025b694 100644 --- a/sklearn/metrics/cluster/_bicluster.py +++ b/sklearn/metrics/cluster/_bicluster.py @@ -103,7 +103,7 @@ def consensus_score(a, b, *, similarity="jaccard"): >>> a = ([[True, False], [False, True]], [[False, True], [True, False]]) >>> b = ([[False, True], [True, False]], [[True, False], [False, True]]) >>> consensus_score(a, b, similarity='jaccard') - np.float64(1.0) + 1.0 """ if similarity == "jaccard": similarity = _jaccard @@ -111,4 +111,4 @@ def consensus_score(a, b, *, similarity="jaccard"): row_indices, col_indices = linear_sum_assignment(1.0 - matrix) n_a = len(a[0]) n_b = len(b[0]) - return matrix[row_indices, col_indices].sum() / max(n_a, n_b) + return float(matrix[row_indices, col_indices].sum() / max(n_a, n_b)) diff --git a/sklearn/metrics/cluster/_supervised.py b/sklearn/metrics/cluster/_supervised.py index e9ee22056cb5e..88a8206f9c734 100644 --- a/sklearn/metrics/cluster/_supervised.py +++ b/sklearn/metrics/cluster/_supervised.py @@ -323,7 +323,7 @@ def rand_score(labels_true, labels_pred): are complete but may not always be pure, hence penalized: >>> rand_score([0, 0, 1, 2], [0, 0, 1, 1]) - np.float64(0.83...) + 0.83... """ contingency = pair_confusion_matrix(labels_true, labels_pred) numerator = contingency.diagonal().sum() @@ -335,7 +335,7 @@ def rand_score(labels_true, labels_pred): # cluster. These are perfect matches hence return 1.0. return 1.0 - return numerator / denominator + return float(numerator / denominator) @validate_params( @@ -522,7 +522,7 @@ def homogeneity_completeness_v_measure(labels_true, labels_pred, *, beta=1.0): >>> from sklearn.metrics import homogeneity_completeness_v_measure >>> y_true, y_pred = [0, 0, 1, 1, 2, 2], [0, 0, 1, 2, 2, 2] >>> homogeneity_completeness_v_measure(y_true, y_pred) - (np.float64(0.71...), np.float64(0.77...), np.float64(0.73...)) + (0.71..., 0.77..., 0.73...) """ labels_true, labels_pred = check_clusterings(labels_true, labels_pred) @@ -548,7 +548,7 @@ def homogeneity_completeness_v_measure(labels_true, labels_pred, *, beta=1.0): / (beta * homogeneity + completeness) ) - return homogeneity, completeness, v_measure_score + return float(homogeneity), float(completeness), float(v_measure_score) @validate_params( @@ -606,7 +606,7 @@ def homogeneity_score(labels_true, labels_pred): >>> from sklearn.metrics.cluster import homogeneity_score >>> homogeneity_score([0, 0, 1, 1], [1, 1, 0, 0]) - np.float64(1.0) + 1.0 Non-perfect labelings that further split classes into more clusters can be perfectly homogeneous:: @@ -682,7 +682,7 @@ def completeness_score(labels_true, labels_pred): >>> from sklearn.metrics.cluster import completeness_score >>> completeness_score([0, 0, 1, 1], [1, 1, 0, 0]) - np.float64(1.0) + 1.0 Non-perfect labelings that assign all classes members to the same clusters are still complete:: @@ -771,9 +771,9 @@ def v_measure_score(labels_true, labels_pred, *, beta=1.0): >>> from sklearn.metrics.cluster import v_measure_score >>> v_measure_score([0, 0, 1, 1], [0, 0, 1, 1]) - np.float64(1.0) + 1.0 >>> v_measure_score([0, 0, 1, 1], [1, 1, 0, 0]) - np.float64(1.0) + 1.0 Labelings that assign all classes members to the same clusters are complete but not homogeneous, hence penalized:: @@ -879,7 +879,7 @@ def mutual_info_score(labels_true, labels_pred, *, contingency=None): >>> labels_true = [0, 1, 1, 0, 1, 0] >>> labels_pred = [0, 1, 0, 0, 1, 1] >>> mutual_info_score(labels_true, labels_pred) - np.float64(0.056...) + 0.056... """ if contingency is None: labels_true, labels_pred = check_clusterings(labels_true, labels_pred) @@ -920,7 +920,7 @@ def mutual_info_score(labels_true, labels_pred, *, contingency=None): + contingency_nm * log_outer ) mi = np.where(np.abs(mi) < np.finfo(mi.dtype).eps, 0.0, mi) - return np.clip(mi.sum(), 0.0, None) + return float(np.clip(mi.sum(), 0.0, None)) @validate_params( @@ -1008,17 +1008,14 @@ def adjusted_mutual_info_score( >>> from sklearn.metrics.cluster import adjusted_mutual_info_score >>> adjusted_mutual_info_score([0, 0, 1, 1], [0, 0, 1, 1]) - ... # doctest: +SKIP 1.0 >>> adjusted_mutual_info_score([0, 0, 1, 1], [1, 1, 0, 0]) - ... # doctest: +SKIP 1.0 If classes members are completely split across different clusters, the assignment is totally in-complete, hence the AMI is null:: >>> adjusted_mutual_info_score([0, 0, 0, 0], [0, 1, 2, 3]) - ... # doctest: +SKIP 0.0 """ labels_true, labels_pred = check_clusterings(labels_true, labels_pred) @@ -1053,7 +1050,7 @@ def adjusted_mutual_info_score( else: denominator = max(denominator, np.finfo("float64").eps) ami = (mi - emi) / denominator - return ami + return float(ami) @validate_params( @@ -1127,17 +1124,14 @@ def normalized_mutual_info_score( >>> from sklearn.metrics.cluster import normalized_mutual_info_score >>> normalized_mutual_info_score([0, 0, 1, 1], [0, 0, 1, 1]) - ... # doctest: +SKIP 1.0 >>> normalized_mutual_info_score([0, 0, 1, 1], [1, 1, 0, 0]) - ... # doctest: +SKIP 1.0 If classes members are completely split across different clusters, the assignment is totally in-complete, hence the NMI is null:: >>> normalized_mutual_info_score([0, 0, 0, 0], [0, 1, 2, 3]) - ... # doctest: +SKIP 0.0 """ labels_true, labels_pred = check_clusterings(labels_true, labels_pred) @@ -1168,7 +1162,7 @@ def normalized_mutual_info_score( h_true, h_pred = entropy(labels_true), entropy(labels_pred) normalizer = _generalized_average(h_true, h_pred, average_method) - return mi / normalizer + return float(mi / normalizer) @validate_params( @@ -1236,9 +1230,9 @@ def fowlkes_mallows_score(labels_true, labels_pred, *, sparse=False): >>> from sklearn.metrics.cluster import fowlkes_mallows_score >>> fowlkes_mallows_score([0, 0, 1, 1], [0, 0, 1, 1]) - np.float64(1.0) + 1.0 >>> fowlkes_mallows_score([0, 0, 1, 1], [1, 1, 0, 0]) - np.float64(1.0) + 1.0 If classes members are completely split across different clusters, the assignment is totally random, hence the FMI is null:: @@ -1254,7 +1248,7 @@ def fowlkes_mallows_score(labels_true, labels_pred, *, sparse=False): tk = np.dot(c.data, c.data) - n_samples pk = np.sum(np.asarray(c.sum(axis=0)).ravel() ** 2) - n_samples qk = np.sum(np.asarray(c.sum(axis=1)).ravel() ** 2) - n_samples - return np.sqrt(tk / pk) * np.sqrt(tk / qk) if tk != 0.0 else 0.0 + return float(np.sqrt(tk / pk) * np.sqrt(tk / qk)) if tk != 0.0 else 0.0 @validate_params( diff --git a/sklearn/metrics/cluster/_unsupervised.py b/sklearn/metrics/cluster/_unsupervised.py index ac6caf1e2382b..21dd22bc17a93 100644 --- a/sklearn/metrics/cluster/_unsupervised.py +++ b/sklearn/metrics/cluster/_unsupervised.py @@ -126,7 +126,7 @@ def silhouette_score( >>> X, y = make_blobs(random_state=42) >>> kmeans = KMeans(n_clusters=2, random_state=42) >>> silhouette_score(X, kmeans.fit_predict(X)) - np.float64(0.49...) + 0.49... """ if sample_size is not None: X, labels = check_X_y(X, labels, accept_sparse=["csc", "csr"]) @@ -136,7 +136,7 @@ def silhouette_score( X, labels = X[indices].T[indices].T, labels[indices] else: X, labels = X[indices], labels[indices] - return np.mean(silhouette_samples(X, labels, metric=metric, **kwds)) + return float(np.mean(silhouette_samples(X, labels, metric=metric, **kwds))) def _silhouette_reduce(D_chunk, start, labels, label_freqs): @@ -361,7 +361,7 @@ def calinski_harabasz_score(X, labels): >>> X, _ = make_blobs(random_state=0) >>> kmeans = KMeans(n_clusters=3, random_state=0,).fit(X) >>> calinski_harabasz_score(X, kmeans.labels_) - np.float64(114.8...) + 114.8... """ X, labels = check_X_y(X, labels) le = LabelEncoder() @@ -380,7 +380,7 @@ def calinski_harabasz_score(X, labels): extra_disp += len(cluster_k) * np.sum((mean_k - mean) ** 2) intra_disp += np.sum((cluster_k - mean_k) ** 2) - return ( + return float( 1.0 if intra_disp == 0.0 else extra_disp * (n_samples - n_labels) / (intra_disp * (n_labels - 1.0)) @@ -436,7 +436,7 @@ def davies_bouldin_score(X, labels): >>> X = [[0, 1], [1, 1], [3, 4]] >>> labels = [0, 0, 1] >>> davies_bouldin_score(X, labels) - np.float64(0.12...) + 0.12... """ X, labels = check_X_y(X, labels) le = LabelEncoder() @@ -461,4 +461,4 @@ def davies_bouldin_score(X, labels): centroid_distances[centroid_distances == 0] = np.inf combined_intra_dists = intra_dists[:, None] + intra_dists scores = np.max(combined_intra_dists / centroid_distances, axis=1) - return np.mean(scores) + return float(np.mean(scores)) diff --git a/sklearn/metrics/cluster/tests/test_common.py b/sklearn/metrics/cluster/tests/test_common.py index 0570f0ac2a0f1..a73670fbffce4 100644 --- a/sklearn/metrics/cluster/tests/test_common.py +++ b/sklearn/metrics/cluster/tests/test_common.py @@ -209,3 +209,26 @@ def test_inf_nan_input(metric_name, metric_func): with pytest.raises(ValueError, match=r"contains (NaN|infinity)"): for args in invalids: metric_func(*args) + + +@pytest.mark.parametrize("name", chain(SUPERVISED_METRICS, UNSUPERVISED_METRICS)) +def test_returned_value_consistency(name): + """Ensure that the returned values of all metrics are consistent. + + It can only be a float. It should not be a numpy float64 or float32. + """ + + rng = np.random.RandomState(0) + X = rng.randint(10, size=(20, 10)) + labels_true = rng.randint(0, 3, size=(20,)) + labels_pred = rng.randint(0, 3, size=(20,)) + + if name in SUPERVISED_METRICS: + metric = SUPERVISED_METRICS[name] + score = metric(labels_true, labels_pred) + else: + metric = UNSUPERVISED_METRICS[name] + score = metric(X, labels_pred) + + assert isinstance(score, float) + assert not isinstance(score, (np.float64, np.float32)) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 7e3758cd76654..9e8d0ce116394 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -2235,3 +2235,34 @@ def _get_metric_kwargs_for_array_api_testing(metric, params): metric_kwargs_combinations = new_combinations return metric_kwargs_combinations + + +@pytest.mark.parametrize("name", sorted(ALL_METRICS)) +def test_returned_value_consistency(name): + """Ensure that the returned values of all metrics are consistent. + + It can either be a float, a numpy array, or a tuple of floats or numpy arrays. + It should not be a numpy float64 or float32. + """ + + rng = np.random.RandomState(0) + y_true = rng.randint(0, 2, size=(20,)) + y_pred = rng.randint(0, 2, size=(20,)) + + if name in METRICS_REQUIRE_POSITIVE_Y: + y_true, y_pred = _require_positive_targets(y_true, y_pred) + + if name in METRIC_UNDEFINED_BINARY: + y_true = rng.randint(0, 2, size=(20, 3)) + y_pred = rng.randint(0, 2, size=(20, 3)) + + metric = ALL_METRICS[name] + score = metric(y_true, y_pred) + + assert isinstance(score, (float, np.ndarray, tuple)) + assert not isinstance(score, (np.float64, np.float32)) + + if isinstance(score, tuple): + assert all(isinstance(v, float) for v in score) or all( + isinstance(v, np.ndarray) for v in score + )