diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst
index 7d8175a3b5046..2a28f009dd19a 100644
--- a/doc/whats_new/v1.0.rst
+++ b/doc/whats_new/v1.0.rst
@@ -623,6 +623,11 @@ Changelog
   amount of data.
   :pr:`20312` by :user:`Divyanshu Deoli `.
 
+- |Feature| Added :func:`precision_at_recall_k` and :func:`recall_at_precision_k`
+  to compute the maximum precision for thresholds where recall >= k and the
+  maximum recall for thresholds where precision >= k, respectively.
+  :pr:`20877` by :user:`Shubhraneel Pal `.
+
 :mod:`sklearn.mixture`
 ......................
 
diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py
index a0b06a02ad6d1..4e10d644abf2d 100644
--- a/sklearn/metrics/__init__.py
+++ b/sklearn/metrics/__init__.py
@@ -35,6 +35,8 @@
 from ._classification import zero_one_loss
 from ._classification import brier_score_loss
 from ._classification import multilabel_confusion_matrix
+from ._classification import precision_at_recall_k
+from ._classification import recall_at_precision_k
 
 from . import cluster
 from .cluster import adjusted_mutual_info_score
@@ -171,4 +173,6 @@
     "v_measure_score",
     "zero_one_loss",
     "brier_score_loss",
+    "precision_at_recall_k",
+    "recall_at_precision_k",
 ]
diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index 1a23ec01f4536..81d9138c126f4 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -42,6 +42,7 @@
 from ..exceptions import UndefinedMetricWarning
 
 from ._base import _check_pos_label_consistency
+from ._ranking import precision_recall_curve
 
 
 def _check_zero_division(zero_division):
@@ -2649,3 +2650,151 @@ def brier_score_loss(y_true, y_prob, *, sample_weight=None, pos_label=None):
             raise
         y_true = np.array(y_true == pos_label, int)
     return np.average((y_true - y_prob) ** 2, weights=sample_weight)
+
+
+def recall_at_precision_k(y_true, y_prob, k, *, pos_label=None, sample_weight=None):
+    """Compute the maximum recall over thresholds where precision is greater
+    than or equal to ``k``.
+
+    Note: this implementation is restricted to the binary classification task.
+
+    The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
+    true positives and ``fp`` the number of false positives. The precision is
+    intuitively the ability of the classifier not to label as positive a sample
+    that is negative.
+
+    The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
+    true positives and ``fn`` the number of false negatives. The recall is
+    intuitively the ability of the classifier to find all the positive samples.
+
+    Read more in the :ref:`User Guide `.
+
+    Parameters
+    ----------
+    y_true : ndarray of shape (n_samples,)
+        True binary labels. If labels are not either {-1, 1} or {0, 1}, then
+        pos_label should be explicitly given.
+
+    y_prob : ndarray of shape (n_samples,)
+        Target scores, can either be probability estimates of the positive
+        class, or non-thresholded measure of decisions (as returned by
+        `decision_function` on some classifiers).
+
+    k : float
+        Precision threshold. Only thresholds whose precision is greater than
+        or equal to ``k`` are considered.
+
+    pos_label : int or str, default=None
+        The label of the positive class.
+        When ``pos_label=None``, if y_true is in {-1, 1} or {0, 1},
+        ``pos_label`` is set to 1, otherwise an error will be raised.
+
+    sample_weight : array-like of shape (n_samples,), default=None
+        Sample weights.
+
+    Returns
+    -------
+    recall_at_precision_k : float
+        Maximum recall over thresholds where precision is greater than or
+        equal to ``k``, with thresholds applied to the ``pos_label`` (or to
+        the label 1 if ``pos_label=None``).
+
+    See Also
+    --------
+    precision_recall_curve : Compute precision-recall curve.
+    plot_precision_recall_curve : Plot Precision Recall Curve for binary
+        classifiers.
+    PrecisionRecallDisplay : Precision Recall visualization.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.metrics import recall_at_precision_k
+    >>> y_true = np.array([0, 0, 1, 1, 1, 1])
+    >>> y_prob = np.array([0.1, 0.8, 0.9, 0.3, 1.0, 0.95])
+    >>> k = 0.75
+    >>> recall_at_precision_k(y_true, y_prob, k)
+    1.0
+
+    """
+
+    precisions, recalls, _ = precision_recall_curve(
+        y_true, y_prob, pos_label=pos_label, sample_weight=sample_weight
+    )
+
+    # Keep the operating points whose precision reaches ``k`` and return the
+    # best recall among them (0.0 if no threshold qualifies).
+    valid_positions = precisions >= k
+    valid_recalls = recalls[valid_positions]
+    value = 0.0
+    if valid_recalls.shape[0] > 0:
+        value = np.max(valid_recalls)
+    return value
+
+
+def precision_at_recall_k(y_true, y_prob, k, *, pos_label=None, sample_weight=None):
+    """Compute the maximum precision over thresholds where recall is greater
+    than or equal to ``k``.
+
+    Note: this implementation is restricted to the binary classification task.
+
+    The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
+    true positives and ``fp`` the number of false positives. The precision is
+    intuitively the ability of the classifier not to label as positive a sample
+    that is negative.
+
+    The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
+    true positives and ``fn`` the number of false negatives. The recall is
+    intuitively the ability of the classifier to find all the positive samples.
+
+    Read more in the :ref:`User Guide `.
+
+    Parameters
+    ----------
+    y_true : ndarray of shape (n_samples,)
+        True binary labels. If labels are not either {-1, 1} or {0, 1}, then
+        pos_label should be explicitly given.
+
+    y_prob : ndarray of shape (n_samples,)
+        Target scores, can either be probability estimates of the positive
+        class, or non-thresholded measure of decisions (as returned by
+        `decision_function` on some classifiers).
+
+    k : float
+        Recall threshold. Only thresholds whose recall is greater than or
+        equal to ``k`` are considered.
+
+    pos_label : int or str, default=None
+        The label of the positive class.
+        When ``pos_label=None``, if y_true is in {-1, 1} or {0, 1},
+        ``pos_label`` is set to 1, otherwise an error will be raised.
+
+    sample_weight : array-like of shape (n_samples,), default=None
+        Sample weights.
+
+    Returns
+    -------
+    precision_at_recall_k : float
+        Maximum precision over thresholds where recall is greater than or
+        equal to ``k``, with thresholds applied to the ``pos_label`` (or to
+        the label 1 if ``pos_label=None``).
+
+    See Also
+    --------
+    precision_recall_curve : Compute precision-recall curve.
+    plot_precision_recall_curve : Plot Precision Recall Curve for binary
+        classifiers.
+    PrecisionRecallDisplay : Precision Recall visualization.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.metrics import precision_at_recall_k
+    >>> y_true = np.array([0, 0, 1, 1, 1, 1])
+    >>> y_prob = np.array([0.1, 0.8, 0.9, 0.3, 1.0, 0.95])
+    >>> k = 0.8
+    >>> precision_at_recall_k(y_true, y_prob, k)
+    0.8
+
+    """
+
+    precisions, recalls, _ = precision_recall_curve(
+        y_true, y_prob, pos_label=pos_label, sample_weight=sample_weight
+    )
+
+    # Keep the operating points whose recall reaches ``k`` and return the
+    # best precision among them (0.0 if no threshold qualifies).
+    valid_positions = recalls >= k
+    valid_precisions = precisions[valid_positions]
+    value = 0.0
+    if valid_precisions.shape[0] > 0:
+        value = np.max(valid_precisions)
+    return value
diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index 4f29c127defb5..cee8a7d0cdc2b 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -42,6 +42,8 @@
 from sklearn.metrics import zero_one_loss
 from sklearn.metrics import brier_score_loss
 from sklearn.metrics import multilabel_confusion_matrix
+from sklearn.metrics import precision_at_recall_k
+from sklearn.metrics import recall_at_precision_k
 
 from sklearn.metrics._classification import _check_targets
 from sklearn.exceptions import UndefinedMetricWarning
@@ -2509,3 +2511,43 @@ def test_balanced_accuracy_score(y_true, y_pred):
     adjusted = balanced_accuracy_score(y_true, y_pred, adjusted=True)
     chance = balanced_accuracy_score(y_true, np.full_like(y_true, y_true[0]))
     assert adjusted == (balanced - chance) / (1 - chance)
+
+
+def test_precision_at_recall_k():
+    y_true = np.array([0, 0, 1, 1, 1, 1])
+    y_prob = np.array([0.1, 0.8, 0.9, 0.3, 1.0, 0.95])
+    y_multi = np.array([0, 2, 1, 1, 1, 1])
+
+    assert_almost_equal(precision_at_recall_k(y_true, y_prob, 0.8), 0.8)
+    assert_almost_equal(precision_at_recall_k(y_true, y_prob, 0.6), 1)
+    assert_almost_equal(precision_at_recall_k(y_true * 2 - 1, y_prob, 0.8), 0.8)
+
+    with pytest.raises(ValueError):
+        precision_at_recall_k(y_multi, y_prob, 0.8)
+
+    assert_almost_equal(precision_at_recall_k(y_true, y_prob, 0.8, pos_label=1), 0.8)
+
+    y_true = np.array([0])
+    y_prob = np.array([0.4])
+    with ignore_warnings():
+        assert_almost_equal(precision_at_recall_k(y_true, y_prob, 0.1), 0)
+
+
+def test_recall_at_precision_k():
+    y_true = np.array([0, 0, 1, 1, 1, 1])
+    y_prob = np.array([0.1, 0.8, 0.9, 0.3, 1.0, 0.95])
+    y_multi = np.array([0, 2, 1, 1, 1, 1])
+
+    assert_almost_equal(recall_at_precision_k(y_true, y_prob, 1), 0.75)
+    assert_almost_equal(recall_at_precision_k(y_true, y_prob, 0.8), 1)
+    assert_almost_equal(recall_at_precision_k(y_true * 2 - 1, y_prob, 1), 0.75)
+
+    with pytest.raises(ValueError):
+        recall_at_precision_k(y_multi, y_prob, 1)
+
+    assert_almost_equal(recall_at_precision_k(y_true, y_prob, 1, pos_label=1), 0.75)
+
+    y_true = np.array([0])
+    y_prob = np.array([0.4])
+    with ignore_warnings():
+        assert_almost_equal(recall_at_precision_k(y_true, y_prob, 0.1), 0)
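
Reviewer note: a minimal cross-check sketch, not part of the patch. It recomputes the two docstring examples above using only the existing ``precision_recall_curve`` API, which is exactly what the new helpers filter; the variable names are illustrative.

import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 0, 1, 1, 1, 1])
y_prob = np.array([0.1, 0.8, 0.9, 0.3, 1.0, 0.95])

# Operating points of the precision-recall curve.
precisions, recalls, _ = precision_recall_curve(y_true, y_prob)

# precision_at_recall_k: best precision among points with recall >= k.
print(precisions[recalls >= 0.8].max())   # 0.8

# recall_at_precision_k: best recall among points with precision >= k.
print(recalls[precisions >= 0.75].max())  # 1.0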