diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index 8bc27194a63b5..fd3ce367a6ef5 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -227,6 +227,8 @@ Scoring string name Function
 'precision' etc. :func:`metrics.precision_score` suffixes apply as with 'f1'
 'recall' etc. :func:`metrics.recall_score` suffixes apply as with 'f1'
 'jaccard' etc. :func:`metrics.jaccard_score` suffixes apply as with 'f1'
+'specificity' etc. :func:`metrics.specificity_score` suffixes apply as with 'f1'
+'npv' etc. :func:`metrics.npv_score` suffixes apply as with 'f1'
 'roc_auc' :func:`metrics.roc_auc_score`
 'roc_auc_ovr' :func:`metrics.roc_auc_score`
 'roc_auc_ovo' :func:`metrics.roc_auc_score`
@@ -536,6 +538,8 @@ Some also work in the multilabel case:
    precision_recall_fscore_support
    precision_score
    recall_score
+   specificity_score
+   npv_score
    roc_auc_score
    zero_one_loss
    d2_log_loss_score
@@ -603,7 +607,6 @@ The :func:`accuracy_score` function computes the
 `accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_, either the fraction
 (default) or the count (normalize=False) of correct predictions.
 
-
 In multilabel classification, the function returns the subset accuracy. If
 the entire set of predicted labels for a sample strictly match with the true
 set of labels, then the subset accuracy is 1.0; otherwise it is 0.0.
@@ -742,7 +745,7 @@ or *informedness*.
 * Our definition: [Mosley2013]_, [Kelleher2015]_ and [Guyon2015]_, where
   [Guyon2015]_ adopt the adjusted version to ensure that random predictions
-  have a score of :math:`0` and perfect predictions have a score of :math:`1`..
+  have a score of :math:`0` and perfect predictions have a score of :math:`1`.
 * Class balanced accuracy as described in [Mosley2013]_: the minimum between
   the precision and the recall for each class is computed. Those values are
   then averaged over the total number of classes to get the balanced accuracy.
@@ -855,6 +858,42 @@ false negatives and true positives as follows::
   for an example of using a confusion matrix to classify text documents.
 
 
+.. _tpr_fpr_tnr_fnr_score:
+
+TPR FPR TNR FNR score
+---------------------
+
+The :func:`tpr_fpr_tnr_fnr_score` function computes the true positive rate (TPR),
+false positive rate (FPR), true negative rate (TNR) and false negative rate (FNR)
+of predictions, based on the `confusion matrix
+<https://en.wikipedia.org/wiki/Confusion_matrix>`_.
+The rates are defined as
+
+.. math::
+
+   \texttt{TPR} = \frac{TP}{P} = \frac{TP}{TP + FN} = 1 - FNR
+
+   \texttt{FPR} = \frac{FP}{N} = \frac{FP}{TN + FP} = 1 - TNR
+
+   \texttt{TNR} = \frac{TN}{N} = \frac{TN}{TN + FP} = 1 - FPR
+
+   \texttt{FNR} = \frac{FN}{P} = \frac{FN}{TP + FN} = 1 - TPR
+
+Here is a small multiclass example (by default the four rates are returned per class)::
+
+  >>> from sklearn.metrics import tpr_fpr_tnr_fnr_score
+  >>> y_true = [2, 0, 2, 2, 0, 1]
+  >>> y_pred = [0, 0, 2, 2, 0, 2]
+  >>> tpr_fpr_tnr_fnr_score(y_true, y_pred)
+  (array([1.        , 0.        , 0.66666667]),
+   array([0.25      , 0.        , 0.33333333]),
+   array([0.75      , 1.        , 0.66666667]),
+   array([0.        , 1.        , 0.33333333]))
+
+.. note::
+
+   * True positive rate (TPR) is also called recall, sensitivity, or hit rate.
+   * False positive rate (FPR) is also called fall-out.
+   * True negative rate (TNR) is also called specificity or selectivity.
+   * False negative rate (FNR) is also called miss rate.
+
 .. _classification_report:
 
 Classification report
@@ -1006,6 +1045,18 @@ precision-recall curve as follows.
    :scale: 75
    :align: center
 
+Precision can also be referred to as the `positive predictive value (PPV)
+<https://en.wikipedia.org/wiki/Positive_and_negative_predictive_values>`_,
+e.g. in the biosciences. A closely related metric is the `negative predictive
+value (NPV)
+<https://en.wikipedia.org/wiki/Positive_and_negative_predictive_values>`_,
+implemented by :func:`npv_score`.
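+
+For instance, on a toy binary problem the two "predictive values" can be
+computed side by side (an illustrative sketch using the default
+``average='binary'`` setting of both functions)::
+
+  >>> from sklearn.metrics import precision_score, npv_score
+  >>> y_true = [0, 1, 1, 0, 1, 0]
+  >>> y_pred = [0, 1, 0, 0, 1, 1]
+  >>> precision_score(y_true, y_pred)  # PPV = TP / (TP + FP)
+  0.66...
+  >>> npv_score(y_true, y_pred)  # NPV = TN / (TN + FN)
+  0.66...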
+
+Recall can also be called the hit rate or true positive rate (TPR). Especially
+in biostatistics, it is also known as `sensitivity
+<https://en.wikipedia.org/wiki/Sensitivity_and_specificity>`_, which is related
+to `specificity <https://en.wikipedia.org/wiki/Sensitivity_and_specificity>`_.
+In turn, specificity is also referred to as selectivity or the true negative
+rate (TNR), and is implemented by :func:`specificity_score`.
+
 .. rubric:: Examples
 
 * See :ref:`sphx_glr_auto_examples_model_selection_plot_grid_search_digits.py`
@@ -1044,10 +1095,10 @@ following table:
 +-------------------+------------------------------------------------+
 |                   |    Actual class (observation)                  |
 +-------------------+---------------------+--------------------------+
-|   Predicted class | tp (true positive)  | fp (false positive)      |
+|   Predicted class | TP (true positive)  | FP (false positive)      |
 |   (expectation)   | Correct result      | Unexpected result        |
 |                   +---------------------+--------------------------+
-|                   | fn (false negative) | tn (true negative)       |
+|                   | FN (false negative) | TN (true negative)       |
 |                   | Missing result      | Correct absence of result|
 +-------------------+---------------------+--------------------------+
@@ -1117,10 +1168,9 @@ Here are some small examples in binary classification::
   >>> average_precision_score(y_true, y_scores)
   0.83...
 
-
-
 Multiclass and multilabel classification
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
 In a multiclass and multilabel classification task, the notions of precision,
 recall, and F-measures can be applied to each label independently.
 There are a few ways to combine results across labels,
@@ -1994,6 +2044,59 @@ the same does a lower Brier score loss always mean better calibration"
   and probability estimation." `_
   Dagstuhl Seminar Proceedings. Schloss Dagstuhl-Leibniz-Zentrum für Informatik (2008).
 
+.. _true_negatives_metrics:
+
+Specificity and negative predictive value (NPV)
+------------------------------------------------
+
+`Specificity <https://en.wikipedia.org/wiki/Sensitivity_and_specificity>`_
+(also called selectivity or true negative rate) and
+`NPV <https://en.wikipedia.org/wiki/Positive_and_negative_predictive_values>`_
+are both ratios of true negatives to, respectively, actual negatives and
+predicted negatives in a classification task.
+
+Binary classification
+^^^^^^^^^^^^^^^^^^^^^
+
+In a binary classification task, specificity and NPV are defined as
+
+.. math::
+
+   \text{specificity} = \frac{TN}{N} = \frac{TN}{TN + FP}
+
+   \text{NPV} = \frac{TN}{TN + FN}
+
+Multiclass and multilabel classification
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In a multiclass or multilabel classification task, the notions of specificity
+and NPV can be applied to each label independently. There are a few ways
+to combine results across labels, specified by the ``average`` argument
+to the :func:`specificity_score` and :func:`npv_score` functions, as described
+:ref:`above <average>`.
+
+To make this more explicit, consider the following examples::
+
+  >>> from sklearn.metrics import specificity_score
+  >>> from sklearn.metrics import npv_score
+  >>> y_true = [2, 0, 2, 2, 0, 1]
+  >>> y_pred = [0, 0, 2, 2, 0, 2]
+  >>> specificity_score(y_true, y_pred, average=None)
+  array([0.75      , 1.        , 0.66666667])
+  >>> npv_score(y_true, y_pred, average=None)
+  array([1.        , 0.83333333, 0.66666667])
+  >>> specificity_score(y_true, y_pred, average='macro')
+  0.805...
+  >>> npv_score(y_true, y_pred, average='macro')
+  0.83...
+  >>> specificity_score(y_true, y_pred, average='micro')
+  0.83...
+  >>> npv_score(y_true, y_pred, average='micro')
+  0.83...
+ >>> specificity_score(y_true, y_pred, average='weighted') + 0.75 + >>> npv_score(y_true, y_pred, average='weighted') + 0.805... + .. _class_likelihood_ratios: Class likelihood ratios diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index ce86525acc368..ba8dc2084be9c 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -21,9 +21,12 @@ log_loss, matthews_corrcoef, multilabel_confusion_matrix, + npv_score, precision_recall_fscore_support, precision_score, recall_score, + specificity_score, + tpr_fpr_tnr_fnr_score, zero_one_loss, ) from ._dist_metrics import DistanceMetric @@ -157,6 +160,7 @@ "nan_euclidean_distances", "ndcg_score", "normalized_mutual_info_score", + "npv_score", "pair_confusion_matrix", "pairwise_distances", "pairwise_distances_argmin", @@ -175,7 +179,9 @@ "root_mean_squared_log_error", "silhouette_samples", "silhouette_score", + "specificity_score", "top_k_accuracy_score", + "tpr_fpr_tnr_fnr_score", "v_measure_score", "zero_one_loss", ] diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 2e23c251af58a..f78b5fece5bd3 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -3552,3 +3552,594 @@ def d2_log_loss_score(y_true, y_pred, *, sample_weight=None, labels=None): ) return float(1 - (numerator / denominator)) + + +@validate_params( + { + "y_true": ["array-like", "sparse matrix"], + "y_pred": ["array-like", "sparse matrix"], + "labels": ["array-like", None], + "pos_label": [Real, str, "boolean", None], + "average": [ + StrOptions({"binary", "macro", "micro", "samples", "weighted"}), + None, + ], + "warn_for": [list, tuple, set], + "sample_weight": ["array-like", None], + "zero_division": [ + Options(Real, {0, 1}), + StrOptions({"warn"}), + ], + }, + prefer_skip_nested_validation=True, +) +def tpr_fpr_tnr_fnr_score( + y_true, + y_pred, + *, + labels=None, + pos_label=1, + average=None, + warn_for=("TPR", "FPR", "TNR", "FNR"), + sample_weight=None, + zero_division="warn", +): + """Compute the TPR, FPR, TNR, FNR for each class. + + The true positive rate (TPR) is the ratio `TP / (TP + FN)` where `TP` + is the number of true positives and `FN` the number of false negatives. + + The false positive rate (FPR) is the ratio `FP / (TN + FP)` where `TN` + is the number of true negatives and `FP` the number of false positives. + + The true negative rate (TNR) is the ratio `TN / (TN + FP)` where `TN` + is the number of true negatives and `FP` the number of false positives. + + The false negative rate (FNR) is the ratio `FN / (TP + FN)` where `TP` + is the number of true positives and `FN` the number of false negatives. + + If `pos_label is None` and in binary classification, this function + returns the true positive rate, false positive rate, true negative rate + and false negative rate if `average` is one of `"micro"`, `"macro"`, + `"weighted"` or `"samples"`. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.1 + + Parameters + ---------- + y_true : {array-like, label indicator array, sparse matrix} \ + of shape (n_samples,) + Ground truth (correct) target values. + + y_pred : {array-like, label indicator array, sparse matrix} \ + of shape (n_samples,) + Estimated targets as returned by a classifier. + + labels : list, default=None + The set of labels to include when `average != "binary"`, and their + order if `average is None`. 
Labels present in the data can be + excluded, for example to calculate a multiclass average ignoring a + majority negative class, while labels not present in the data will + result in 0 components in a macro average. For multilabel targets, + labels are column indices. By default, all labels in `y_true` and + `y_pred` are used in sorted order. + + pos_label : int, float, bool or str, default=1 + The class to report if `average='binary'` and the data is binary, + otherwise this parameter is ignored. + For multiclass or multilabel targets, set `labels=[pos_label]` and + `average != 'binary'` to report metrics for one label only. + + average : {"binary", "macro", "micro", "samples", "weighted"} or None, \ + default=None + If `None`, the scores for each class are returned. Otherwise, this + determines the type of averaging performed on the data: + + `"binary"`: + Only report results for the class specified by `pos_label`. + This is applicable only if targets (`y_{true,pred}`) are binary. + `"macro"`: + Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + `"micro"`: + Calculate metrics globally by counting the total true positives, + false negatives and false positives. + `"samples"`: + Calculate metrics for each instance, and find their average (only + meaningful for multilabel classification where this differs from + :func:`accuracy_score`). + `"weighted"`: + Calculate metrics for each label, and find their average weighted + by support (the number of true instances for each label). This + alters 'macro' to account for label imbalance. + + warn_for : list, tuple or set, for internal use + This determines which warnings will be made in the case that this + function is being used to return only one of its metrics. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. + + zero_division : str or int, {"warn", 0, 1}, default="warn" + Sets the value to return when there is a zero division: + - TPR, FNR: when there are no positive labels + - FPR, TNR: when there are no negative labels + + If set to "warn", this acts as 0, but warnings are also raised. + + Returns + ------- + tpr : float or ndarray of shape (n_unique_labels,), dtype=np.float64 + The true positive rate (TPR) is the ratio `TP / (TP + FN)` where `TP` + is the number of true positives and `FN` the number of false negatives. + + fpr : float or ndarray of shape (n_unique_labels,), dtype=np.float64 + The false positive rate (FPR) is the ratio `FP / (TN + FP)` where `TN` + is the number of true negatives and `FP` the number of false positives. + + tnr : float or ndarray of shape (n_unique_labels,), dtype=np.float64 + The true negative rate (TNR) is the ratio `TN / (TN + FP)` where `TN` + is the number of true negatives and `FP` the number of false positives. + + fnr : float or ndarray of shape (n_unique_labels,), dtype=np.float64 + The false negative rate (FNR) is the ratio `FN / (TP + FN)` where `TP` + is the number of true positives and `FN` the number of false negatives. + + See Also + -------- + classification_report : A text report showing the key classification metrics. + precision_recall_fscore_support : The key classification metrics. + precision_score : Precision or positive predictive value (PPV). + recall_score : Recall, sensitivity, hit rate, or true positive rate (TPR). + specificity_score : Specificity, selectivity or true negative rate (TNR). + multilabel_confusion_matrix : Confusion matrices for each class or sample. 
+ balanced_accuracy_score : Accuracy metric for imbalanced datasets. + npv_score : Negative predictive value (NPV). + + Notes + ----- + When `true positive + false negative == 0`, TPR, FNR are undefined; + When `true negative + false positive == 0`, FPR, TNR are undefined. + In such cases, by default the metric will be set to 0, + and `UndefinedMetricWarning` will be raised. This behavior can be + modified with `zero_division`. + + References + ---------- + .. [1] `Wikipedia entry for confusion matrix + `_ + + .. [2] `Discriminative Methods for Multi-labeled Classification Advances + in Knowledge Discovery and Data Mining (2004), pp. 22-30 by Shantanu + Godbole, Sunita Sarawagi + `_ + + Examples + -------- + >>> import numpy as np + >>> from sklearn.metrics import tpr_fpr_tnr_fnr_score + >>> y_true = np.array(['cat', 'dog', 'pig', 'cat', 'dog', 'pig']) + >>> y_pred = np.array(['cat', 'pig', 'dog', 'cat', 'cat', 'dog']) + >>> tpr_fpr_tnr_fnr_score(y_true, y_pred, average='macro') + (0.33..., 0.33..., 0.66..., 0.66...) + >>> tpr_fpr_tnr_fnr_score(y_true, y_pred, average='micro') + (0.33..., 0.33..., 0.66..., 0.66...) + >>> tpr_fpr_tnr_fnr_score(y_true, y_pred, average='weighted') + (0.33..., 0.33..., 0.66..., 0.66...) + + It is possible to compute per-label FPR, FNR, TNR, TPR and + supports instead of averaging: + + >>> tpr_fpr_tnr_fnr_score(y_true, y_pred, average=None, + ... labels=['pig', 'dog', 'cat']) + (array([0., 0., 1.]), array([0.25, 0.5 , 0.25]), + array([0.75, 0.5 , 0.75]), array([1., 1., 0.])) + """ + _check_zero_division(zero_division) + + labels = _check_set_wise_labels(y_true, y_pred, average, labels, pos_label) + + samplewise = average == "samples" + MCM = multilabel_confusion_matrix( + y_true, + y_pred, + sample_weight=sample_weight, + labels=labels, + samplewise=samplewise, + ) + tp_sum = MCM[:, 1, 1] + fp_sum = MCM[:, 0, 1] + tn_sum = MCM[:, 0, 0] + fn_sum = MCM[:, 1, 0] + pos_sum = tp_sum + fn_sum + neg_sum = tn_sum + fp_sum + + if average == "micro": + tp_sum = np.array([tp_sum.sum()]) + fp_sum = np.array([fp_sum.sum()]) + tn_sum = np.array([tn_sum.sum()]) + fn_sum = np.array([fn_sum.sum()]) + pos_sum = np.array([pos_sum.sum()]) + neg_sum = np.array([neg_sum.sum()]) + + # Divide, and on zero-division, set scores and/or warn according to + # zero_division: + tpr = _prf_divide( + tp_sum, pos_sum, "TPR", "positives", average, warn_for, zero_division + ) + fpr = _prf_divide( + fp_sum, neg_sum, "FPR", "negatives", average, warn_for, zero_division + ) + tnr = _prf_divide( + tn_sum, neg_sum, "TNR", "negatives", average, warn_for, zero_division + ) + fnr = _prf_divide( + fn_sum, pos_sum, "FNR", "positives", average, warn_for, zero_division + ) + if average is None: + return tpr, fpr, tnr, fnr + + # Average the results + elif average == "weighted": + weights = pos_sum + if weights.sum() == 0: + zero_division_value = 0.0 if zero_division in ["warn", 0] else 1.0 + # TPR and FNR is zero_division if there are no positive labels + # FPR and TNR is zero_division if there are no negative labels + return ( + zero_division_value if pos_sum.sum() == 0 else 0, + zero_division_value if neg_sum.sum() == 0 else 0, + zero_division_value if neg_sum.sum() == 0 else 0, + zero_division_value if pos_sum.sum() == 0 else 0, + ) + elif average == "samples" and sample_weight is not None: + weights = sample_weight + else: + weights = None + assert average != "binary" or len(fpr) == 1, "Non-binary target." 
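+    # `weights` is None for "macro" and "binary", the per-label support
+    # (tp + fn) for "weighted", and the user-provided sample weights (if any)
+    # for "samples"; np.average then reduces each per-label rate to one float.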
+ tpr = float(np.average(tpr, weights=weights)) + fpr = float(np.average(fpr, weights=weights)) + tnr = float(np.average(tnr, weights=weights)) + fnr = float(np.average(fnr, weights=weights)) + return tpr, fpr, tnr, fnr + + +@validate_params( + { + "y_true": ["array-like", "sparse matrix"], + "y_pred": ["array-like", "sparse matrix"], + "labels": ["array-like", None], + "pos_label": [Real, str, "boolean", None], + "average": [ + StrOptions({"binary", "macro", "micro", "samples", "weighted"}), + None, + ], + "sample_weight": ["array-like", None], + "zero_division": [ + Options(Real, {0, 1}), + StrOptions({"warn"}), + ], + }, + prefer_skip_nested_validation=True, +) +def specificity_score( + y_true, + y_pred, + *, + labels=None, + pos_label=1, + average="binary", + sample_weight=None, + zero_division="warn", +): + """Compute the specificity, also known as the true negative rate (TNR). + + The specificity is the ratio `TN / (TN + FP)` where `TN` is the number + of true negatives and `FP` is the number of false positives. + The specificity is intuitively the ability of the classifier to find + all the negative samples. It is also called selectivity. + + The best value is 1 and the worst value is 0. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.1 + + Parameters + ---------- + y_true : {array-like, label indicator array, sparse matrix} \ + of shape (n_samples,) + Ground truth (correct) target values. + + y_pred : {array-like, label indicator array, sparse matrix} \ + of shape (n_samples,) + Estimated targets as returned by a classifier. + + labels : array-like, default=None + The set of labels to include when `average != "binary"`, and their + order if `average is None`. Labels present in the data can be + excluded, for example to calculate a multiclass average ignoring a + majority negative class, while labels not present in the data will + result in 0 components in a macro average. For multilabel targets, + labels are column indices. By default, all labels in `y_true` and + `y_pred` are used in sorted order. + + pos_label : int, float, bool or str, default=1 + The class to report if `average='binary'` and the data is binary, + otherwise this parameter is ignored. + For multiclass or multilabel targets, set `labels=[pos_label]` and + `average != 'binary'` to report metrics for one label only. + + average : {"binary", "macro", "micro", "samples", "weighted"} or None \ + default="binary" + This parameter is required for multiclass/multilabel targets. + If `None`, the scores for each class are returned. Otherwise, this + determines the type of averaging performed on the data: + + `"binary"`: + Only report results for the class specified by `pos_label`. + This is applicable only if targets (`y_{true,pred}`) are binary. + `"macro"`: + Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + `"micro"`: + Calculate metrics globally by counting the total true positives, + false negatives and false positives. + `"samples"`: + Calculate metrics for each instance, and find their average (only + meaningful for multilabel classification where this differs from + :func:`accuracy_score`). + `"weighted"`: + Calculate metrics for each label, and find their average weighted + by support (the number of true instances for each label). This + alters 'macro' to account for label imbalance; it can result in an + F-score that is not between precision and recall. 
+ + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. + + zero_division : "warn", 0 or 1, default="warn" + Sets the value to return when there is a zero division. If set to + "warn", this acts as 0, but warnings are also raised. + + Returns + ------- + specificity : float or ndarray of shape (n_unique_labels,), dtype=np.float64 + The specificity of the positive class in binary classification or + weighted average of the specificity of each class for the multiclass + task. Scalar is returned if averaging (i.e., when `average` is not `None`), + array - otherwise. + + See Also + -------- + precision_score : Precision or positive predictive value (PPV). + recall_score : Recall, sensitivity, hit rate, or true positive rate (TPR). + + Notes + ----- + When `true negative + false positive == 0`, specificity returns 0 and + raises `UndefinedMetricWarning`. This behavior can be modified with + `zero_division`. + + References + ---------- + .. [1] `Wikipedia entry for sensitivity and specificity + `_ + + Examples + -------- + >>> from sklearn.metrics import specificity_score + >>> y_true = [0, 1, 2, 0, 1, 2] + >>> y_pred = [0, 2, 1, 0, 0, 1] + >>> specificity_score(y_true, y_pred, average='macro') + 0.66... + >>> specificity_score(y_true, y_pred, average='micro') + 0.66... + >>> specificity_score(y_true, y_pred, average='weighted') + 0.66... + >>> specificity_score(y_true, y_pred, average=None) + array([0.75, 0.5 , 0.75]) + >>> y_true = [0, 0, 0, 0, 0, 0] + >>> specificity_score(y_true, y_pred, average=None) + array([0. , 0.66..., 0.83...]) + >>> specificity_score(y_true, y_pred, average=None, zero_division=1) + array([1. , 0.66..., 0.83...]) + """ + _, _, tnr, _ = tpr_fpr_tnr_fnr_score( + y_true, + y_pred, + labels=labels, + pos_label=pos_label, + average=average, + warn_for=("TNR",), + sample_weight=sample_weight, + zero_division=zero_division, + ) + return tnr + + +@validate_params( + { + "y_true": ["array-like", "sparse matrix"], + "y_pred": ["array-like", "sparse matrix"], + "labels": ["array-like", None], + "pos_label": [Real, str, "boolean", None], + "average": [ + StrOptions({"binary", "macro", "micro", "samples", "weighted"}), + None, + ], + "sample_weight": ["array-like", None], + "zero_division": [ + Options(Real, {0, 1}), + StrOptions({"warn"}), + ], + }, + prefer_skip_nested_validation=True, +) +def npv_score( + y_true, + y_pred, + labels=None, + pos_label=1, + average="binary", + sample_weight=None, + zero_division="warn", +): + """Compute the negative predictive value (NPV). + + The NPV is the ratio `TN / (TN + FN)` where `TN` is the number of true + negatives and `FN` is the number of false negatives. The NPV is intuitively + the ability of the classifier to mark the negative samples correctly. + + The best value is 1 and the worst value is 0. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.1 + + Parameters + ---------- + y_true : {array-like, label indicator array, sparse matrix} \ + of shape (n_samples,) + Ground truth (correct) target values. + + y_pred : {array-like, label indicator array, sparse matrix} \ + of shape (n_samples,) + Estimated targets as returned by a classifier. + + labels : array-like, default=None + The set of labels to include when `average != "binary"`, and their + order if `average is None`. Labels present in the data can be + excluded, for example to calculate a multiclass average ignoring a + majority negative class, while labels not present in the data will + result in 0 components in a macro average. 
For multilabel targets, + labels are column indices. By default, all labels in `y_true` and + `y_pred` are used in sorted order. + + pos_label : int, float, bool or str, default=1 + The class to report if `average='binary'` and the data is binary, + otherwise this parameter is ignored. + For multiclass or multilabel targets, set `labels=[pos_label]` and + `average != 'binary'` to report metrics for one label only. + + average : {"binary", "macro", "micro", "samples", "weighted"}, None \ + default="binary" + This parameter is required for multiclass/multilabel targets. + If `None`, the scores for each class are returned. Otherwise, this + determines the type of averaging performed on the data: + + `"binary"`: + Only report results for the class specified by `pos_label`. + This is applicable only if targets (`y_{true,pred}`) are binary. + `"macro"`: + Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + `"micro"`: + Calculate metrics globally by counting the total true positives, + false negatives and false positives. + `"samples"`: + Calculate metrics for each instance, and find their average (only + meaningful for multilabel classification where this differs from + :func:`accuracy_score`). + `"weighted"`: + Calculate metrics for each label, and find their average weighted + by support (the number of true instances for each label). This + alters 'macro' to account for label imbalance; it can result in an + F-score that is not between precision and recall. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. + + zero_division : "warn", 0 or 1, default="warn" + Sets the value to return when there is a zero division. If set to + "warn", this acts as 0, but warnings are also raised. + + Returns + ------- + NPV : float or ndarray of shape (n_unique_labels,), dtype=np.float64 + The negative predictive value of the positive class in binary + classification or weighted average of the NPV of each class for + the multiclass task. Scalar is returned if averaging (i.e., when + `average` is not `None`), array - otherwise. + + See Also + -------- + precision_score : Precision or positive predictive value (PPV). + recall_score : Recall, sensitivity, hit rate, or true positive rate (TPR). + + Notes + ----- + When `true negative + false negative == 0`, npv_score returns 0 and + raises `UndefinedMetricWarning`. This behavior can be modified with + `zero_division`. + + References + ---------- + .. [1] `Wikipedia entry for positive and negative predictive values + (PPV and NPV respectively) + `_ + + Examples + -------- + >>> from sklearn.metrics import npv_score + >>> y_true = [0, 1, 2, 0, 1, 2] + >>> y_pred = [0, 2, 1, 0, 0, 1] + >>> npv_score(y_true, y_pred, average='macro') + 0.70... + >>> npv_score(y_true, y_pred, average='micro') + 0.66... + >>> npv_score(y_true, y_pred, average='weighted') + 0.70... + >>> npv_score(y_true, y_pred, average=None) + array([1. , 0.5, 0.6]) + >>> y_pred = [0, 0, 0, 0, 0, 0] + >>> npv_score(y_true, y_pred, average=None) + array([0. , 0.66..., 0.66...]) + >>> npv_score(y_true, y_pred, average=None, zero_division=1) + array([1. 
, 0.66..., 0.66...]) + """ + _check_zero_division(zero_division) + + labels = _check_set_wise_labels(y_true, y_pred, average, labels, pos_label) + + # Calculate tn_sum, fn_sum, neg_calls_sum + samplewise = average == "samples" + MCM = multilabel_confusion_matrix( + y_true, + y_pred, + sample_weight=sample_weight, + labels=labels, + samplewise=samplewise, + ) + tp_sum = MCM[:, 1, 1] + tn_sum = MCM[:, 0, 0] + fn_sum = MCM[:, 1, 0] + pos_sum = tp_sum + fn_sum + neg_calls_sum = tn_sum + fn_sum + + if average == "micro": + tn_sum = np.array([tn_sum.sum()]) + neg_calls_sum = np.array([neg_calls_sum.sum()]) + + # Divide, and on zero-division, set scores and/or warn according to + # zero_division: + NPV = _prf_divide( + tn_sum, neg_calls_sum, "NPV", "negative call", average, "NPV", zero_division + ) + if average is None: + return NPV + # Average the results + elif average == "weighted": + weights = pos_sum + if weights.sum() == 0: + zero_division_value = 0.0 if zero_division in ["warn", 0] else 1.0 + # NPV is zero_division if there are no negative calls + return zero_division_value if neg_calls_sum.sum() == 0 else 0 + elif average == "samples": + weights = sample_weight + else: + weights = None + assert average != "binary" or len(NPV) == 1, "Non-binary target." + return float(np.average(NPV, weights=weights)) diff --git a/sklearn/metrics/_scorer.py b/sklearn/metrics/_scorer.py index 08e5a20187de7..2bd346a89959b 100644 --- a/sklearn/metrics/_scorer.py +++ b/sklearn/metrics/_scorer.py @@ -61,12 +61,14 @@ mean_squared_error, mean_squared_log_error, median_absolute_error, + npv_score, precision_score, r2_score, recall_score, roc_auc_score, root_mean_squared_error, root_mean_squared_log_error, + specificity_score, top_k_accuracy_score, ) from .cluster import ( @@ -918,6 +920,8 @@ def get_scorer_names(): ("recall", recall_score), ("f1", f1_score), ("jaccard", jaccard_score), + ("specificity", specificity_score), + ("npv", npv_score), ]: _SCORERS[name] = make_scorer(metric, average="binary") for average in ["macro", "micro", "samples", "weighted"]: diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index b67c91737960c..d088190c08c0a 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -30,9 +30,12 @@ make_scorer, matthews_corrcoef, multilabel_confusion_matrix, + npv_score, precision_recall_fscore_support, precision_score, recall_score, + specificity_score, + tpr_fpr_tnr_fnr_score, zero_one_loss, ) from sklearn.metrics._classification import _check_targets, d2_log_loss_score @@ -373,6 +376,274 @@ def test_precision_recall_f_ignored_labels(): assert recall_13(average=average) != recall_all(average=average) +def test_tpr_fpr_tnr_fnr_score_binary_averaged(): + # Test TPR, FPR, TNR, FNR scores for binary classification task + y_true, y_pred, _ = make_prediction(binary=True) + + # compute scores with default labels introspection + tprs, fprs, tnrs, fnrs = tpr_fpr_tnr_fnr_score(y_true, y_pred, average=None) + assert_array_almost_equal(tprs, [0.88, 0.68], 2) + assert_array_almost_equal(fprs, [0.32, 0.12], 2) + assert_array_almost_equal(tnrs, [0.68, 0.88], 2) + assert_array_almost_equal(fnrs, [0.12, 0.32], 2) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel() + + assert_array_almost_equal(tp / (tp + fn), 0.68, 2) + assert_array_almost_equal(fp / (tn + fp), 0.12, 2) + assert_array_almost_equal(tn / (tn + fp), 0.88, 2) + 
assert_array_almost_equal(fn / (tp + fn), 0.32, 2) + + tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_score(y_true, y_pred, average="macro") + assert tpr == np.mean(tprs) + assert fpr == np.mean(fprs) + assert tnr == np.mean(tnrs) + assert fnr == np.mean(fnrs) + + tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_score(y_true, y_pred, average="weighted") + support = np.bincount(y_true) + assert tpr == np.average(tprs, weights=support) + assert fpr == np.average(fprs, weights=support) + assert tnr == np.average(tnrs, weights=support) + assert fnr == np.average(fnrs, weights=support) + + +@ignore_warnings +def test_tpr_fpr_tnr_fnr_score_binary_single_class(): + # Test how the scores behave with a single positive or + # negative class + # Such a case may occur with non-stratified cross-validation + tprs, fprs, tnrs, fnrs = tpr_fpr_tnr_fnr_score([1, 1], [1, 1]) + assert 1.0 == tprs[0] + assert 0.0 == fprs[0] + assert 0.0 == tnrs[0] + assert 0.0 == fnrs[0] + + tprs, fprs, tnrs, fnrs = tpr_fpr_tnr_fnr_score([-1, -1], [-1, -1]) + assert 1.0 == tprs[0] + assert 0.0 == fprs[0] + assert 0.0 == tnrs[0] + assert 0.0 == fnrs[0] + + +@ignore_warnings +def test_tpr_fpr_tnr_fnr_score_extra_labels(): + # Test TPR, FPR, TNR, FNR handling of explicit additional (not in input) + # labels + y_true = [1, 3, 3, 2] + y_pred = [1, 1, 3, 2] + y_true_bin = label_binarize(y_true, classes=np.arange(5)) + y_pred_bin = label_binarize(y_pred, classes=np.arange(5)) + data = [(y_true, y_pred), (y_true_bin, y_pred_bin)] + + for i, (y_true, y_pred) in enumerate(data): + # No averaging + tprs, fprs, tnrs, fnrs = tpr_fpr_tnr_fnr_score( + y_true, y_pred, labels=[0, 1, 2, 3, 4], average=None + ) + assert_array_almost_equal(tprs, [0.0, 1.0, 1.0, 0.5, 0.0], 2) + assert_array_almost_equal(fprs, [0.0, 0.33, 0.0, 0.0, 0.0], 2) + assert_array_almost_equal(tnrs, [1.0, 0.67, 1.0, 1.0, 1.0], 2) + assert_array_almost_equal(fnrs, [0.0, 0.0, 0.0, 0.5, 0.0], 2) + + # Macro average + scores = tpr_fpr_tnr_fnr_score( + y_true, y_pred, labels=[0, 1, 2, 3, 4], average="macro" + ) + assert_array_almost_equal(scores, [0.5, 0.07, 0.93, 0.1], 2) + + # Micro average + scores = tpr_fpr_tnr_fnr_score( + y_true, y_pred, labels=[0, 1, 2, 3, 4], average="micro" + ) + assert_array_almost_equal(scores, [0.75, 0.0625, 0.9375, 0.25], 4) + + # Further tests + for average in ["macro", "micro", "weighted", "samples"]: + if average in ["macro", "micro", "samples"] and i == 0: + continue + assert_almost_equal( + tpr_fpr_tnr_fnr_score( + y_true, y_pred, labels=[0, 1, 2, 3, 4], average=average + ), + tpr_fpr_tnr_fnr_score(y_true, y_pred, labels=None, average=average), + ) + + # Error when introducing invalid label in multilabel case + for average in [None, "macro", "micro", "samples"]: + err_msg = ( + r"All labels must be in \[0, n labels\) for multilabel targets\." + " Got 5 > 4" + ) + with pytest.raises(ValueError, match=err_msg): + tpr_fpr_tnr_fnr_score( + y_true_bin, y_pred_bin, labels=np.arange(6), average=average + ) + err_msg = ( + r"All labels must be in \[0, n labels\) for multilabel targets\." 
+ " Got -1 < 0" + ) + with pytest.raises(ValueError, match=err_msg): + tpr_fpr_tnr_fnr_score( + y_true_bin, y_pred_bin, labels=np.arange(-1, 4), average=average + ) + + +@ignore_warnings +def test_tpr_fpr_tnr_fnr_score_ignored_labels(): + # Test TPR, FPR, TNR, FNR handling of a subset of labels + y_true = [1, 1, 2, 3] + y_pred = [1, 3, 3, 3] + y_true_bin = label_binarize(y_true, classes=np.arange(5)) + y_pred_bin = label_binarize(y_pred, classes=np.arange(5)) + data = [(y_true, y_pred), (y_true_bin, y_pred_bin)] + + for i, (y_true, y_pred) in enumerate(data): + scores_13 = partial(tpr_fpr_tnr_fnr_score, y_true, y_pred, labels=[1, 3]) + scores_all = partial(tpr_fpr_tnr_fnr_score, y_true, y_pred, labels=None) + + assert_array_almost_equal( + ([0.5, 1.0], [0.0, 0.67], [1.0, 0.33], [0.5, 0.0]), + scores_13(average=None), + 2, + ) + assert_almost_equal([0.75, 0.33, 0.67, 0.25], scores_13(average="macro"), 2) + assert_almost_equal([0.67, 0.4, 0.6, 0.33], scores_13(average="micro"), 2) + assert_almost_equal([0.67, 0.22, 0.78, 0.33], scores_13(average="weighted"), 2) + + # ensure the above were meaningful tests: + for average in ["macro", "weighted", "micro"]: + assert scores_13(average=average) != scores_all(average=average) + + +def test_tpr_fpr_tnr_fnr_score_multiclass(): + # Test TPR, FPR, TNR, FNR scores for multiclass classification task + y_true, y_pred, _ = make_prediction(binary=False) + + # compute scores with default labels introspection + tprs, fprs, tnrs, fnrs = tpr_fpr_tnr_fnr_score(y_true, y_pred, average=None) + assert_array_almost_equal(tprs, [0.79, 0.1, 0.9], 2) + assert_array_almost_equal(fprs, [0.08, 0.14, 0.45], 2) + assert_array_almost_equal(tnrs, [0.92, 0.86, 0.55], 2) + assert_array_almost_equal(fnrs, [0.21, 0.9, 0.1], 2) + + # averaging tests + tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_score(y_true, y_pred, average="micro") + assert_almost_equal(tpr, 0.53, 2) + assert_almost_equal(fpr, 0.23, 2) + assert_almost_equal(tnr, 0.77, 2) + assert_almost_equal(fnr, 0.47, 2) + + tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_score(y_true, y_pred, average="macro") + assert_almost_equal(tpr, 0.6, 2) + assert_almost_equal(fpr, 0.22, 2) + assert_almost_equal(tnr, 0.78, 2) + assert_almost_equal(fnr, 0.4, 2) + + tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_score(y_true, y_pred, average="weighted") + assert_almost_equal(tpr, 0.53, 2) + assert_almost_equal(fpr, 0.2, 2) + assert_almost_equal(tnr, 0.8, 2) + assert_almost_equal(fnr, 0.47, 2) + + err_msg = ( + "Samplewise metrics are not available outside of multilabel" + r" classification\." + ) + with pytest.raises(ValueError, match=err_msg): + tpr_fpr_tnr_fnr_score(y_true, y_pred, average="samples") + + # same prediction but with explicit label ordering + tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_score( + y_true, y_pred, labels=[0, 2, 1], average=None + ) + assert_array_almost_equal(tpr, [0.79, 0.9, 0.1], 2) + assert_array_almost_equal(fpr, [0.08, 0.45, 0.14], 2) + assert_array_almost_equal(tnr, [0.92, 0.55, 0.86], 2) + assert_array_almost_equal(fnr, [0.21, 0.1, 0.9], 2) + + +@pytest.mark.parametrize("zero_division", ["warn", 0, 1]) +def test_tpr_fpr_tnr_fnr_score_with_an_empty_prediction(zero_division): + y_true = np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 1, 1, 0]]) + y_pred = np.array([[0, 0, 0, 0], [0, 0, 0, 1], [0, 1, 1, 0]]) + + msg = ( + "Fnr is ill-defined and being set to 0.0 in labels with no positives samples." + " Use `zero_division` parameter to control this behavior." 
+ ) + + zero_division_value = 1.0 if zero_division == 1.0 else 0.0 + + with warnings.catch_warnings(record=True) as record: + warnings.simplefilter("always") + tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_score( + y_true, y_pred, average=None, zero_division=zero_division + ) + if zero_division == "warn": + assert str(record.pop().message) == msg + else: + assert len(record) == 0 + assert_array_almost_equal(tpr, [0.0, 0.5, 1.0, zero_division_value], 2) + assert_array_almost_equal(fpr, [0.0, 0.0, 0.0, 1 / 3.0], 2) + assert_array_almost_equal(tnr, [1.0, 1.0, 1.0, 2 / 3.0], 2) + assert_array_almost_equal(fnr, [1.0, 0.5, 0.0, zero_division_value], 2) + + with warnings.catch_warnings(record=True) as record: + warnings.simplefilter("always") + tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_score( + y_true, y_pred, average="macro", zero_division=zero_division + ) + if zero_division == "warn": + assert str(record.pop().message) == msg + else: + assert len(record) == 0 + assert_almost_equal(tpr, 0.625 if zero_division_value else 0.375) + assert_almost_equal(fpr, 1 / 3.0 / 4.0) + assert_almost_equal(tnr, 0.91666, 5) + assert_almost_equal(fnr, 0.625 if zero_division_value else 0.375) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_score( + y_true, y_pred, average="micro", zero_division=zero_division + ) + assert_almost_equal(tpr, 0.5) + assert_almost_equal(fpr, 0.125) + assert_almost_equal(tnr, 0.875) + assert_almost_equal(fnr, 0.5) + + with warnings.catch_warnings(record=True) as record: + warnings.simplefilter("always") + tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_score( + y_true, y_pred, average="weighted", zero_division=zero_division + ) + if zero_division == "warn": + assert str(record.pop().message) == msg + else: + assert len(record) == 0 + assert_almost_equal(tpr, 0.5) + assert_almost_equal(fpr, 0) + assert_almost_equal(tnr, 1.0) + assert_almost_equal(fnr, 0.5) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + tpr, fpr, tnr, fnr = tpr_fpr_tnr_fnr_score( + y_true, + y_pred, + average="samples", + sample_weight=[1, 1, 2], + zero_division=zero_division, + ) + assert_almost_equal(tpr, 0.5) + assert_almost_equal(fpr, 0.08333, 5) + assert_almost_equal(tnr, 0.91666, 5) + assert_almost_equal(fnr, 0.5) + + def test_average_precision_score_non_binary_class(): """Test multiclass-multiouptut for `average_precision_score`.""" y_true = np.array( @@ -1246,6 +1517,7 @@ def test_zero_precision_recall(): assert_almost_equal(precision_score(y_true, y_pred, average="macro"), 0.0, 2) assert_almost_equal(recall_score(y_true, y_pred, average="macro"), 0.0, 2) assert_almost_equal(f1_score(y_true, y_pred, average="macro"), 0.0, 2) + assert_almost_equal(specificity_score(y_true, y_pred, average="macro"), 0.5, 2) finally: np.seterr(**old_error_settings) @@ -2444,6 +2716,203 @@ def test_fscore_warnings(zero_division): assert len(record) == 0 +@pytest.mark.parametrize("zero_division", ["warn", 0, 1]) +def test_specificity_warnings(zero_division): + with warnings.catch_warnings(): + warnings.simplefilter("error") + specificity_score( + np.array([[0, 0], [0, 0]]), + np.array([[1, 1], [1, 1]]), + average="micro", + zero_division=zero_division, + ) + with warnings.catch_warnings(record=True) as record: + warnings.simplefilter("always") + specificity_score( + np.array([[1, 1], [1, 1]]), + np.array([[0, 0], [0, 0]]), + average="micro", + zero_division=zero_division, + ) + if zero_division == "warn": + assert ( + str(record.pop().message) == "Tnr is ill-defined and " + 
"being set to 0.0 due to no negatives samples." + " Use `zero_division` parameter to control" + " this behavior." + ) + else: + assert len(record) == 0 + + specificity_score([1, 1], [1, 1]) + assert ( + str(record.pop().message) == "Tnr is ill-defined and " + "being set to 0.0 due to no negatives samples." + " Use `zero_division` parameter to control" + " this behavior." + ) + + +def test_npv_binary_averaged(): + # Test NPV score for binary classification task + y_true, y_pred, _ = make_prediction(binary=True) + + # compute scores with default labels + npv_none = npv_score(y_true, y_pred, average=None) + assert_array_almost_equal(npv_none, [0.85, 0.73], 2) + + npv_macro = npv_score(y_true, y_pred, average="macro") + assert npv_macro == np.mean(npv_none) + + npw_weighted = npv_score(y_true, y_pred, average="weighted") + support = np.bincount(y_true) + assert npw_weighted == np.average(npv_none, weights=support) + + +@ignore_warnings +def test_npv_binary_single_class(): + # Test how the NPV score behaves with a single positive or + # negative class + # Such a case may occur with non-stratified cross-validation + assert 0.0 == npv_score([1, 1], [1, 1]) + assert 1.0 == npv_score([-1, -1], [-1, -1]) + + +@ignore_warnings +def test_npv_extra_labels(): + # Test NPV handling of explicit additional (not in input) labels + y_true = [1, 3, 3, 2] + y_pred = [1, 1, 3, 2] + y_true_bin = label_binarize(y_true, classes=np.arange(5)) + y_pred_bin = label_binarize(y_pred, classes=np.arange(5)) + data = [(y_true, y_pred), (y_true_bin, y_pred_bin)] + + for i, (y_true, y_pred) in enumerate(data): + print(i) + # No averaging + npvs = npv_score(y_true, y_pred, labels=[0, 1, 2, 3, 4], average=None) + assert_array_almost_equal(npvs, [1.0, 1.0, 1.0, 0.67, 1.0], 2) + + # Macro average + npv = npv_score(y_true, y_pred, labels=[0, 1, 2, 3, 4], average="macro") + assert_almost_equal(npv, 0.93, 2) + + # Micro average + npv = npv_score(y_true, y_pred, labels=[0, 1, 2, 3, 4], average="micro") + assert_almost_equal(npv, 0.9375, 4) + + # Further tests + for average in ["macro", "micro", "weighted", "samples"]: + if average in ["macro", "micro", "samples"] and i == 0: + continue + assert_almost_equal( + npv_score(y_true, y_pred, labels=[0, 1, 2, 3, 4], average=average), + npv_score(y_true, y_pred, labels=None, average=average), + ) + + # Error when introducing invalid label in multilabel case + for average in [None, "macro", "micro", "samples"]: + err_msg = ( + r"All labels must be in \[0, n labels\) for multilabel targets\." + " Got 5 > 4" + ) + with pytest.raises(ValueError, match=err_msg): + npv_score(y_true_bin, y_pred_bin, labels=np.arange(6), average=average) + err_msg = ( + r"All labels must be in \[0, n labels\) for multilabel targets\." 
+ " Got -1 < 0" + ) + with pytest.raises(ValueError, match=err_msg): + npv_score(y_true_bin, y_pred_bin, labels=np.arange(-1, 4), average=average) + + +@ignore_warnings +def test_npv_ignored_labels(): + # Test NPV handling of a subset of labels + y_true = [1, 1, 2, 3] + y_pred = [1, 3, 3, 3] + y_true_bin = label_binarize(y_true, classes=np.arange(5)) + y_pred_bin = label_binarize(y_pred, classes=np.arange(5)) + data = [(y_true, y_pred), (y_true_bin, y_pred_bin)] + + for i, (y_true, y_pred) in enumerate(data): + npv_13 = partial(npv_score, y_true, y_pred, labels=[1, 3]) + npv_all = partial(npv_score, y_true, y_pred, labels=None) + + assert_almost_equal([0.67, 1.0], npv_13(average=None), 2) + assert_almost_equal(0.83, npv_13(average="macro"), 2) + assert_almost_equal(0.75, npv_13(average="micro"), 2) + assert_almost_equal(0.78, npv_13(average="weighted"), 2) + + # ensure the above were meaningful tests: + for average in ["macro", "weighted", "micro"]: + if average == "micro" and i == 0: + continue + assert npv_13(average=average) != npv_all(average=average) + + +def test_npv_multiclass(): + # Test NPV score for multiclass classification task + y_true, y_pred, _ = make_prediction(binary=False) + + # compute scores with default labels + assert_array_almost_equal( + npv_score(y_true, y_pred, average=None), [0.9, 0.58, 0.94], 2 + ) + + # averaging tests + assert_array_almost_equal(npv_score(y_true, y_pred, average="micro"), 0.77, 2) + + assert_array_almost_equal(npv_score(y_true, y_pred, average="macro"), 0.81, 2) + + assert_array_almost_equal(npv_score(y_true, y_pred, average="weighted"), 0.78, 2) + + err_msg = ( + "Samplewise metrics are not available outside of multilabel" + r" classification\." + ) + with pytest.raises(ValueError, match=err_msg): + npv_score(y_true, y_pred, average="samples") + + # same prediction but with explicit label ordering + assert_array_almost_equal( + npv_score(y_true, y_pred, labels=[0, 2, 1], average=None), [0.9, 0.94, 0.58], 2 + ) + + +@pytest.mark.parametrize("zero_division", ["warn", 0, 1]) +def test_npv_warnings(zero_division): + with warnings.catch_warnings(): + warnings.simplefilter("error") + npv_score( + np.array([[1, 1], [1, 1]]), + np.array([[0, 0], [0, 0]]), + average="micro", + zero_division=zero_division, + ) + + msg = ( + "Npv is ill-defined and being set to 0.0 due to no negative call samples." + " Use `zero_division` parameter to control this behavior." 
+ ) + + with warnings.catch_warnings(record=True) as record: + warnings.simplefilter("always") + npv_score( + np.array([[0, 0], [0, 0]]), + np.array([[1, 1], [1, 1]]), + average="micro", + zero_division=zero_division, + ) + if zero_division == "warn": + assert str(record[-1].message) == msg + else: + assert len(record) == 0 + + with pytest.warns(UndefinedMetricWarning, match=msg): + npv_score([1, 1], [1, 1]) + + def test_prf_average_binary_data_non_binary(): # Error if user does not explicitly set non-binary average mode y_true_mc = [1, 2, 3, 3] diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 6e6950b1d2eff..e1c85a3fc84f9 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -44,6 +44,7 @@ median_absolute_error, multilabel_confusion_matrix, ndcg_score, + npv_score, precision_recall_curve, precision_score, r2_score, @@ -52,7 +53,9 @@ roc_curve, root_mean_squared_error, root_mean_squared_log_error, + specificity_score, top_k_accuracy_score, + tpr_fpr_tnr_fnr_score, zero_one_loss, ) from sklearn.metrics._base import _average_binary_score @@ -171,30 +174,50 @@ "f2_score": partial(fbeta_score, beta=2), "f0.5_score": partial(fbeta_score, beta=0.5), "matthews_corrcoef_score": matthews_corrcoef, + "tpr_fpr_tnr_fnr_score": tpr_fpr_tnr_fnr_score, + "binary_tpr_fpr_tnr_fnr_score": partial(tpr_fpr_tnr_fnr_score, average="binary"), + "specificity_score": specificity_score, + "binary_specificity_score": partial(specificity_score, average="binary"), + "npv_score": npv_score, + "binary_npv_score": partial(npv_score, average="binary"), "weighted_f0.5_score": partial(fbeta_score, average="weighted", beta=0.5), "weighted_f1_score": partial(f1_score, average="weighted"), "weighted_f2_score": partial(fbeta_score, average="weighted", beta=2), "weighted_precision_score": partial(precision_score, average="weighted"), "weighted_recall_score": partial(recall_score, average="weighted"), "weighted_jaccard_score": partial(jaccard_score, average="weighted"), + "weighted_tpr_fpr_tnr_fnr_score": partial( + tpr_fpr_tnr_fnr_score, average="weighted" + ), + "weighted_specificity_score": partial(specificity_score, average="weighted"), + "weighted_npv_score": partial(npv_score, average="weighted"), "micro_f0.5_score": partial(fbeta_score, average="micro", beta=0.5), "micro_f1_score": partial(f1_score, average="micro"), "micro_f2_score": partial(fbeta_score, average="micro", beta=2), "micro_precision_score": partial(precision_score, average="micro"), "micro_recall_score": partial(recall_score, average="micro"), "micro_jaccard_score": partial(jaccard_score, average="micro"), + "micro_tpr_fpr_tnr_fnr_score": partial(tpr_fpr_tnr_fnr_score, average="micro"), + "micro_specificity_score": partial(specificity_score, average="micro"), + "micro_npv_score": partial(npv_score, average="micro"), "macro_f0.5_score": partial(fbeta_score, average="macro", beta=0.5), "macro_f1_score": partial(f1_score, average="macro"), "macro_f2_score": partial(fbeta_score, average="macro", beta=2), "macro_precision_score": partial(precision_score, average="macro"), "macro_recall_score": partial(recall_score, average="macro"), "macro_jaccard_score": partial(jaccard_score, average="macro"), + "macro_tpr_fpr_tnr_fnr_score": partial(tpr_fpr_tnr_fnr_score, average="macro"), + "macro_specificity_score": partial(specificity_score, average="macro"), + "macro_npv_score": partial(npv_score, average="macro"), "samples_f0.5_score": partial(fbeta_score, average="samples", beta=0.5), 
"samples_f1_score": partial(f1_score, average="samples"), "samples_f2_score": partial(fbeta_score, average="samples", beta=2), "samples_precision_score": partial(precision_score, average="samples"), "samples_recall_score": partial(recall_score, average="samples"), "samples_jaccard_score": partial(jaccard_score, average="samples"), + "samples_tpr_fpr_tnr_fnr_score": partial(tpr_fpr_tnr_fnr_score, average="samples"), + "samples_specificity_score": partial(specificity_score, average="samples"), + "samples_npv_score": partial(npv_score, average="samples"), "cohen_kappa_score": cohen_kappa_score, } @@ -293,6 +316,9 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "samples_precision_score", "samples_recall_score", "samples_jaccard_score", + "samples_tpr_fpr_tnr_fnr_score", + "samples_specificity_score", + "samples_npv_score", "coverage_error", "unnormalized_multilabel_confusion_matrix_sample", "label_ranking_loss", @@ -316,6 +342,12 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "f1_score", "f2_score", "f0.5_score", + "tpr_fpr_tnr_fnr_score", + "specificity_score", + "npv_score", + "binary_tpr_fpr_tnr_fnr_score", + "binary_specificity_score", + "binary_npv_score", # curves "roc_curve", "precision_recall_curve", @@ -335,6 +367,8 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "f2_score", "f0.5_score", "jaccard_score", + "specificity_score", + "npv_score", } # Threshold-based metrics with an "average" argument @@ -360,6 +394,9 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "weighted_average_precision_score", "micro_average_precision_score", "samples_average_precision_score", + "tpr_fpr_tnr_fnr_score", + "specificity_score", + "npv_score", } # Metrics with a "labels" argument @@ -377,24 +414,36 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "f2_score", "f0.5_score", "jaccard_score", + "tpr_fpr_tnr_fnr_score", + "specificity_score", + "npv_score", "weighted_f0.5_score", "weighted_f1_score", "weighted_f2_score", "weighted_precision_score", "weighted_recall_score", "weighted_jaccard_score", + "weighted_tpr_fpr_tnr_fnr_score", + "weighted_specificity_score", + "weighted_npv_score", "micro_f0.5_score", "micro_f1_score", "micro_f2_score", "micro_precision_score", "micro_recall_score", "micro_jaccard_score", + "micro_tpr_fpr_tnr_fnr_score", + "micro_specificity_score", + "micro_npv_score", "macro_f0.5_score", "macro_f1_score", "macro_f2_score", "macro_precision_score", "macro_recall_score", "macro_jaccard_score", + "macro_tpr_fpr_tnr_fnr_score", + "macro_specificity_score", + "macro_npv_score", "unnormalized_multilabel_confusion_matrix", "unnormalized_multilabel_confusion_matrix_sample", "cohen_kappa_score", @@ -440,18 +489,27 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "weighted_precision_score", "weighted_recall_score", "weighted_jaccard_score", + "weighted_tpr_fpr_tnr_fnr_score", + "weighted_specificity_score", + "weighted_npv_score", "macro_f0.5_score", "macro_f1_score", "macro_f2_score", "macro_precision_score", "macro_recall_score", "macro_jaccard_score", + "macro_tpr_fpr_tnr_fnr_score", + "macro_specificity_score", + "macro_npv_score", "micro_f0.5_score", "micro_f1_score", "micro_f2_score", "micro_precision_score", "micro_recall_score", "micro_jaccard_score", + "micro_tpr_fpr_tnr_fnr_score", + "micro_specificity_score", + "micro_npv_score", "unnormalized_multilabel_confusion_matrix", "samples_f0.5_score", "samples_f1_score", @@ -459,6 +517,9 @@ def 
precision_recall_curve_padded_thresholds(*args, **kwargs): "samples_precision_score", "samples_recall_score", "samples_jaccard_score", + "samples_tpr_fpr_tnr_fnr_score", + "samples_specificity_score", + "samples_npv_score", } # Regression metrics with "multioutput-continuous" format support @@ -490,7 +551,6 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "jaccard_score", "samples_jaccard_score", "f1_score", - "micro_f1_score", "macro_f1_score", "weighted_recall_score", "mean_squared_log_error", @@ -502,6 +562,9 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "micro_f2_score", "micro_precision_score", "micro_recall_score", + "micro_tpr_fpr_tnr_fnr_score", + "micro_specificity_score", + "micro_npv_score", "matthews_corrcoef_score", "mean_absolute_error", "mean_squared_error", @@ -529,16 +592,25 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "recall_score", "f2_score", "f0.5_score", + "tpr_fpr_tnr_fnr_score", + "specificity_score", + "npv_score", "weighted_f0.5_score", "weighted_f1_score", "weighted_f2_score", "weighted_precision_score", "weighted_jaccard_score", + "weighted_tpr_fpr_tnr_fnr_score", + "weighted_specificity_score", + "weighted_npv_score", "unnormalized_multilabel_confusion_matrix", "macro_f0.5_score", "macro_f2_score", "macro_precision_score", "macro_recall_score", + "macro_tpr_fpr_tnr_fnr_score", + "macro_specificity_score", + "macro_npv_score", "hinge_loss", "mean_gamma_deviance", "mean_poisson_deviance", diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 672ed8ae7eecc..1d9702d50bc2c 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -34,10 +34,12 @@ log_loss, make_scorer, matthews_corrcoef, + npv_score, precision_score, r2_score, recall_score, roc_auc_score, + specificity_score, top_k_accuracy_score, ) from sklearn.metrics import cluster as cluster_module @@ -115,6 +117,14 @@ "matthews_corrcoef", "positive_likelihood_ratio", "neg_negative_likelihood_ratio", + "specificity", + "specificity_weighted", + "specificity_macro", + "specificity_micro", + "npv", + "npv_weighted", + "npv_macro", + "npv_micro", ] # All supervised cluster scorers (They behave like classification metric) @@ -135,6 +145,8 @@ "recall_samples", "f1_samples", "jaccard_samples", + "specificity_samples", + "npv_samples", ] REQUIRE_POSITIVE_Y_SCORERS = ["neg_mean_poisson_deviance", "neg_mean_gamma_deviance"] @@ -378,6 +390,8 @@ def test_check_scoring_gridsearchcv(): ("jaccard_micro", partial(jaccard_score, average="micro")), ("top_k_accuracy", top_k_accuracy_score), ("matthews_corrcoef", matthews_corrcoef), + ("specificity", specificity_score), + ("npv", npv_score), ], ) def test_classification_binary_scores(scorer_name, metric): @@ -410,6 +424,12 @@ def test_classification_binary_scores(scorer_name, metric): ("jaccard_weighted", partial(jaccard_score, average="weighted")), ("jaccard_macro", partial(jaccard_score, average="macro")), ("jaccard_micro", partial(jaccard_score, average="micro")), + ("specificity_weighted", partial(specificity_score, average="weighted")), + ("specificity_macro", partial(specificity_score, average="macro")), + ("specificity_micro", partial(specificity_score, average="micro")), + ("npv_weighted", partial(npv_score, average="weighted")), + ("npv_macro", partial(npv_score, average="macro")), + ("npv_micro", partial(npv_score, average="micro")), ], ) def test_classification_multiclass_scores(scorer_name, metric): @@ 
-1130,7 +1150,15 @@ def test_brier_score_loss_pos_label(string_labeled_classification_problem): @pytest.mark.parametrize( - "score_func", [f1_score, precision_score, recall_score, jaccard_score] + "score_func", + [ + f1_score, + precision_score, + recall_score, + jaccard_score, + specificity_score, + npv_score, + ], ) def test_non_symmetric_metric_pos_label( score_func, string_labeled_classification_problem diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py index 566b8f535c9cb..af350d6b3ae8c 100644 --- a/sklearn/tests/test_multiclass.py +++ b/sklearn/tests/test_multiclass.py @@ -18,7 +18,12 @@ Ridge, SGDClassifier, ) -from sklearn.metrics import precision_score, recall_score +from sklearn.metrics import ( + npv_score, + precision_score, + recall_score, + specificity_score, +) from sklearn.model_selection import GridSearchCV, cross_val_score from sklearn.multiclass import ( OneVsOneClassifier, @@ -339,7 +344,13 @@ def test_ovr_fit_predict_svc(): def test_ovr_multilabel_dataset(): base_clf = MultinomialNB(alpha=1) - for au, prec, recall in zip((True, False), (0.51, 0.66), (0.51, 0.80)): + for au, prec, recall, specificity, npv in zip( + (True, False), + (0.51, 0.66), + (0.51, 0.80), + (0.66, 0.71), + (0.66, 0.84), + ): X, Y = datasets.make_multilabel_classification( n_samples=100, n_features=20, @@ -361,6 +372,10 @@ def test_ovr_multilabel_dataset(): assert_almost_equal( recall_score(Y_test, Y_pred, average="micro"), recall, decimal=2 ) + assert_almost_equal( + specificity_score(Y_test, Y_pred, average="micro"), specificity, decimal=2 + ) + assert_almost_equal(npv_score(Y_test, Y_pred, average="micro"), npv, decimal=2) def test_ovr_multilabel_predict_proba():