Implemented "precision at recall k" and "recall at precision k" #20877
Changes from all commits
@@ -42,6 +42,7 @@
from ..exceptions import UndefinedMetricWarning

from ._base import _check_pos_label_consistency
from ._ranking import precision_recall_curve


def _check_zero_division(zero_division):

@@ -2649,3 +2650,151 @@ def brier_score_loss(y_true, y_prob, *, sample_weight=None, pos_label=None):
            raise
    y_true = np.array(y_true == pos_label, int)
    return np.average((y_true - y_prob) ** 2, weights=sample_weight)


def recall_at_precision_k(y_true, y_prob, k, *, pos_label=None, sample_weight=None):
    """Computes maximum recall for the thresholds when precision is greater
    than or equal to ``k``

[Review comment] We should make it fit on a single line.

    Note: this implementation is restricted to the binary classification task.

    The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
    true positives and ``fp`` the number of false positives. The precision is
    intuitively the ability of the classifier not to label as positive a sample
    that is negative.

    The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
    true positives and ``fn`` the number of false negatives. The recall is
    intuitively the ability of the classifier to find all the positive samples.

[Review comment on lines +2661 to +2668] I am thinking that we could avoid repeating this description.

    Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`.

[Review comment] We will need to add a section in the user guide documentation.

    Parameters
    ----------
    y_true : ndarray of shape (n_samples,)
        True binary labels. If labels are not either {-1, 1} or {0, 1}, then
        pos_label should be explicitly given.

    y_prob : ndarray of shape (n_samples,)
        Target scores, can either be probability estimates of the positive
        class, or non-thresholded measure of decisions (as returned by
        `decision_function` on some classifiers).

    pos_label : int or str, default=None
        The label of the positive class.
        When ``pos_label=None``, if y_true is in {-1, 1} or {0, 1},
        ``pos_label`` is set to 1, otherwise an error will be raised.

[Review comment on lines +2685 to +2686] Suggested change.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    Returns
    -------
    recall_at_precision_k : float
        Maximum recall when for the thresholds when precision is greater
        than or equal to ``k`` for thresholds applied to the ``pos_label`` or
        to the label 1 if ``pos_label=None``

[Review comment] There is something wrong with this sentence due to the twice "when".

    See Also
    --------
    precision_recall_curve : Compute precision-recall curve.
    plot_precision_recall_curve : Plot Precision Recall Curve for binary
        classifiers.
    PrecisionRecallDisplay : Precision Recall visualization.

[Review comment] We should not link to `plot_precision_recall_curve`.

[Review comment] In addition, we should add both the …

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.metrics import recall_at_precision_k
    >>> y_true = np.array([0, 0, 1, 1, 1, 1])
    >>> y_prob = np.array([0.1, 0.8, 0.9, 0.3, 1.0, 0.95])
    >>> k = 0.75
    >>> recall_at_precision_k(y_true, y_prob, k)
    1.0

[Review comment] It might be better to take a threshold for which the score is not 1.0.

    """

[Review comment] You should remove this blank line (the blank line left before the closing quotes of the docstring).

    precisions, recalls, _ = precision_recall_curve(
        y_true, y_prob, pos_label=pos_label, sample_weight=sample_weight
    )

    valid_positions = precisions >= k
    valid_recalls = recalls[valid_positions]
    value = 0.0
    if valid_recalls.shape[0] > 0:
        value = np.max(valid_recalls)
    return value
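
Picking up the reviewer's suggestion above: on the same toy data, a stricter precision floor (k = 0.9, chosen here purely for illustration) gives a score below 1.0. Since recall_at_precision_k only exists on this branch, the sketch below reproduces the same computation with the released precision_recall_curve API instead:

import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 0, 1, 1, 1, 1])
y_prob = np.array([0.1, 0.8, 0.9, 0.3, 1.0, 0.95])

# Same logic as the PR: keep the operating points whose precision is at
# least k, then report the best recall among them.
precisions, recalls, _ = precision_recall_curve(y_true, y_prob)
print(recalls[precisions >= 0.9].max())  # 0.75 rather than 1.0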


def precision_at_recall_k(y_true, y_prob, k, *, pos_label=None, sample_weight=None):
    """Computes maximum precision for the thresholds when recall is greater
    than or equal to ``k``

    Note: this implementation is restricted to the binary classification task.

    The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
    true positives and ``fp`` the number of false positives. The precision is
    intuitively the ability of the classifier not to label as positive a sample
    that is negative.

    The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
    true positives and ``fn`` the number of false negatives. The recall is
    intuitively the ability of the classifier to find all the positive samples.

    Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`.

    Parameters
    ----------
    y_true : ndarray of shape (n_samples,)
        True binary labels. If labels are not either {-1, 1} or {0, 1}, then
        pos_label should be explicitly given.

    y_prob : ndarray of shape (n_samples,)
        Target scores, can either be probability estimates of the positive
        class, or non-thresholded measure of decisions (as returned by
        `decision_function` on some classifiers).

    pos_label : int or str, default=None
        The label of the positive class.
        When ``pos_label=None``, if y_true is in {-1, 1} or {0, 1},
        ``pos_label`` is set to 1, otherwise an error will be raised.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    Returns
    -------
    precision_at_recall_k : float
        Maximum precision when for the thresholds when recall is greater
        than or equal to ``k`` for thresholds applied to the ``pos_label`` or
        to the label 1 if ``pos_label=None``

    See Also
    --------
    precision_recall_curve : Compute precision-recall curve.
    plot_precision_recall_curve : Plot Precision Recall Curve for binary
        classifiers.
    PrecisionRecallDisplay : Precision Recall visualization.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.metrics import precision_at_recall_k
    >>> y_true = np.array([0, 0, 1, 1, 1, 1])
    >>> y_prob = np.array([0.1, 0.8, 0.9, 0.3, 1.0, 0.95])
    >>> k = 0.8
    >>> precision_at_recall_k(y_true, y_prob, k)
    0.8

    """

    precisions, recalls, _ = precision_recall_curve(
        y_true, y_prob, pos_label=pos_label, sample_weight=sample_weight
    )

    valid_positions = recalls >= k
    valid_precisions = precisions[valid_positions]
    value = 0.0
    if valid_precisions.shape[0] > 0:
        value = np.max(valid_precisions)
    return value
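
As a sanity check of the docstring example, the 0.8 above can be reproduced against released scikit-learn with precision_recall_curve alone, again because precision_at_recall_k is branch-only (a minimal sketch mirroring the implementation):

import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 0, 1, 1, 1, 1])
y_prob = np.array([0.1, 0.8, 0.9, 0.3, 1.0, 0.95])

# Mirror of the PR's logic: keep the operating points whose recall is at
# least k, then report the best precision among them.
precisions, recalls, _ = precision_recall_curve(y_true, y_prob)
print(precisions[recalls >= 0.8].max())  # 0.8, matching the docstring example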
[Review comment] I think that it would be better to call these max_precision_at_recall_k and max_recall_at_precision_k, to make it obvious that it is the maximum that is taken.

[Review comment] When I hear "precision_at_recall_k", I think of a single number singled out from the precision-recall curve (given a line, if I constrain X = x_i, then Y = y_i). If I agree with that logic, then I think the original name is better to use.
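
To make the naming question concrete: the implementation takes a maximum over a whole region of the curve rather than reading off a single operating point, which is what the proposed max_* names emphasise. A small illustration on the docstring data (again via precision_recall_curve, since the helpers are branch-only):

import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 0, 1, 1, 1, 1])
y_prob = np.array([0.1, 0.8, 0.9, 0.3, 1.0, 0.95])
precisions, recalls, _ = precision_recall_curve(y_true, y_prob)

# Several operating points satisfy precision >= 0.75, each with its own
# recall; the PR returns the maximum of these, not a single curve value.
mask = precisions >= 0.75
print(recalls[mask])        # 1.0, 0.75, 0.75, 0.5, 0.25, 0.0
print(recalls[mask].max())  # 1.0 -- the value recall_at_precision_k returns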