Merged (changes from all 19 commits)
8 changes: 8 additions & 0 deletions doc/whats_new/v1.3.rst
@@ -306,6 +306,14 @@ Changelog
:pr:`18723` by :user:`Sahil Gupta <sahilgupta2105>` and
:pr:`24677` by :user:`Ashwin Mathur <awinml>`.

- |Enhancement| A new parameter `drop_intermediate` was added to
:func:`metrics.precision_recall_curve`,
:func:`metrics.PrecisionRecallDisplay.from_estimator`,
:func:`metrics.PrecisionRecallDisplay.from_predictions`,
which drops some suboptimal thresholds to create lighter precision-recall
curves.
:pr:`24668` by :user:`dberenbaum`.

- |Fix| :func:`log_loss` raises a warning if the values of the parameter `y_pred` are
not normalized, instead of actually normalizing them in the metric. Starting from
1.5 this will raise an error. :pr:`25299` by :user:`Omar Salman <OmarManzoor>`.
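As a usage note for the changelog entry above: the following is a minimal sketch of how a caller might exercise the new parameter once this change is released. It reuses the toy labels and scores from the new test added in this pull request; the asserted thresholds come from that test, everything else is illustrative.

import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = [0, 0, 0, 0, 1, 1]
y_score = [0.0, 0.2, 0.5, 0.6, 0.7, 1.0]

# Default behaviour: one threshold per distinct score value.
precision, recall, thresholds = precision_recall_curve(y_true, y_score)

# With drop_intermediate=True, points lying on vertical segments of the
# curve are dropped, so the returned arrays (and any resulting plot) are
# lighter while the shape of the plotted curve is preserved.
p_light, r_light, t_light = precision_recall_curve(
    y_true, y_score, drop_intermediate=True
)
np.testing.assert_allclose(t_light, [0.0, 0.7, 1.0])
assert len(t_light) <= len(thresholds)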
23 changes: 22 additions & 1 deletion sklearn/metrics/_plot/precision_recall_curve.py
@@ -186,6 +186,7 @@ def from_estimator(
*,
sample_weight=None,
pos_label=None,
drop_intermediate=False,
response_method="auto",
name=None,
ax=None,
@@ -213,6 +214,13 @@ def from_estimator(
precision and recall metrics. By default, `estimators.classes_[1]`
is considered as the positive class.

drop_intermediate : bool, default=False
Whether to drop some suboptimal thresholds which would not appear
on a plotted precision-recall curve. This is useful in order to
create lighter precision-recall curves.

.. versionadded:: 1.3

response_method : {'predict_proba', 'decision_function', 'auto'}, \
default='auto'
Specifies whether to use :term:`predict_proba` or
@@ -286,6 +294,7 @@ def from_estimator(
sample_weight=sample_weight,
name=name,
pos_label=pos_label,
drop_intermediate=drop_intermediate,
ax=ax,
**kwargs,
)
@@ -298,6 +307,7 @@ def from_predictions(
*,
sample_weight=None,
pos_label=None,
drop_intermediate=False,
name=None,
ax=None,
**kwargs,
@@ -319,6 +329,13 @@ def from_predictions(
The class considered as the positive class when computing the
precision and recall metrics.

drop_intermediate : bool, default=False
Whether to drop some suboptimal thresholds which would not appear
on a plotted precision-recall curve. This is useful in order to
create lighter precision-recall curves.

.. versionadded:: 1.3

name : str, default=None
Name for labeling curve. If `None`, name will be set to
`"Classifier"`.
@@ -374,7 +391,11 @@ def from_predictions(
pos_label = _check_pos_label_consistency(pos_label, y_true)

precision, recall, _ = precision_recall_curve(
y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
y_true,
y_pred,
pos_label=pos_label,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
)
average_precision = average_precision_score(
y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
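The two display constructors changed above simply forward the new keyword down to precision_recall_curve. A hedged end-to-end sketch, assuming a synthetic dataset and a LogisticRegression model chosen purely for illustration:

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import PrecisionRecallDisplay
from sklearn.model_selection import train_test_split

X, y = make_classification(n_classes=2, n_samples=500, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)

# from_estimator forwards drop_intermediate to from_predictions, which in
# turn passes it to precision_recall_curve, so the stored display.precision
# and display.recall arrays are already thinned before plotting.
display = PrecisionRecallDisplay.from_estimator(
    clf, X_test, y_test, drop_intermediate=True
)
plt.show()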
17 changes: 13 additions & 4 deletions sklearn/metrics/_plot/tests/test_precision_recall_display.py
@@ -65,7 +65,10 @@ def test_precision_recall_display_validation(pyplot):

@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
def test_precision_recall_display_plotting(pyplot, constructor_name, response_method):
@pytest.mark.parametrize("drop_intermediate", [True, False])
def test_precision_recall_display_plotting(
pyplot, constructor_name, response_method, drop_intermediate
):
"""Check the overall plotting rendering."""
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
pos_label = 1
@@ -81,14 +84,20 @@ def test_precision_recall_display_plotting(pyplot, constructor_name, response_me

if constructor_name == "from_estimator":
display = PrecisionRecallDisplay.from_estimator(
classifier, X, y, response_method=response_method
classifier,
X,
y,
response_method=response_method,
drop_intermediate=drop_intermediate,
)
else:
display = PrecisionRecallDisplay.from_predictions(
y, y_pred, pos_label=pos_label
y, y_pred, pos_label=pos_label, drop_intermediate=drop_intermediate
)

precision, recall, _ = precision_recall_curve(y, y_pred, pos_label=pos_label)
precision, recall, _ = precision_recall_curve(
y, y_pred, pos_label=pos_label, drop_intermediate=drop_intermediate
)
average_precision = average_precision_score(y, y_pred, pos_label=pos_label)

np.testing.assert_allclose(display.precision, precision)
27 changes: 26 additions & 1 deletion sklearn/metrics/_ranking.py
@@ -820,9 +820,12 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
"probas_pred": ["array-like"],
"pos_label": [Real, str, "boolean", None],
"sample_weight": ["array-like", None],
"drop_intermediate": ["boolean"],
}
)
def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight=None):
def precision_recall_curve(
y_true, probas_pred, *, pos_label=None, sample_weight=None, drop_intermediate=False
):
"""Compute precision-recall pairs for different probability thresholds.

Note: this implementation is restricted to the binary classification task.
@@ -864,6 +867,13 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.

drop_intermediate : bool, default=False
Whether to drop some suboptimal thresholds which would not appear
on a plotted precision-recall curve. This is useful in order to create
lighter precision-recall curves.

.. versionadded:: 1.3

Returns
-------
precision : ndarray of shape (n_thresholds + 1,)
@@ -907,6 +917,21 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight
y_true, probas_pred, pos_label=pos_label, sample_weight=sample_weight
)

if drop_intermediate and len(fps) > 2:
# Drop thresholds corresponding to points where true positives (tps)
# do not change from the previous or subsequent point. This will keep
# only the first and last point for each tps value. All points
# with the same tps value have the same recall and thus x coordinate.
# They appear as a vertical line on the plot.
optimal_idxs = np.where(
np.concatenate(
[[True], np.logical_or(np.diff(tps[:-1]), np.diff(tps[1:])), [True]]
)
)[0]
fps = fps[optimal_idxs]
tps = tps[optimal_idxs]
thresholds = thresholds[optimal_idxs]

ps = tps + fps
# Initialize the result array with zeros to make sure that precision[ps == 0]
# does not contain uninitialized values.
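To make the index selection above concrete, here is a standalone sketch run on the cumulative counts that _binary_clf_curve would produce for the example used in this pull request's new test (y_true = [0, 0, 0, 0, 1, 1], y_score = [0.0, 0.2, 0.5, 0.6, 0.7, 1.0]); the count arrays are written out by hand and are illustrative only.

import numpy as np

# Cumulative counts in order of decreasing threshold.
tps = np.array([1, 2, 2, 2, 2, 2])  # cumulative true positives
fps = np.array([0, 0, 1, 2, 3, 4])  # cumulative false positives
thresholds = np.array([1.0, 0.7, 0.6, 0.5, 0.2, 0.0])

# Keep a point only if tps changes on at least one side of it, i.e. keep the
# first and last point of every run of constant tps (one vertical segment of
# the plotted curve); the two endpoints are always kept.
optimal_idxs = np.where(
    np.concatenate(
        [[True], np.logical_or(np.diff(tps[:-1]), np.diff(tps[1:])), [True]]
    )
)[0]
print(optimal_idxs)              # [0 1 5]
print(thresholds[optimal_idxs])  # [1.  0.7 0. ], which precision_recall_curve
                                 # then returns reversed as [0.0, 0.7, 1.0]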
81 changes: 62 additions & 19 deletions sklearn/metrics/tests/test_ranking.py
@@ -891,35 +891,41 @@ def test_binary_clf_curve_zero_sample_weight(curve_func):
assert_allclose(arr_1, arr_2)


def test_precision_recall_curve():
@pytest.mark.parametrize("drop", [True, False])
def test_precision_recall_curve(drop):
y_true, _, y_score = make_prediction(binary=True)
_test_precision_recall_curve(y_true, y_score)
_test_precision_recall_curve(y_true, y_score, drop)

# Make sure the first point of the Precision-Recall on the right is:
# (p=1.0, r=class balance) on a non-balanced dataset [1:]
p, r, t = precision_recall_curve(y_true[1:], y_score[1:])
p, r, t = precision_recall_curve(y_true[1:], y_score[1:], drop_intermediate=drop)
assert r[0] == 1.0
assert p[0] == y_true[1:].mean()

# Use {-1, 1} for labels; make sure original labels aren't modified
y_true[np.where(y_true == 0)] = -1
y_true_copy = y_true.copy()
_test_precision_recall_curve(y_true, y_score)
_test_precision_recall_curve(y_true, y_score, drop)
assert_array_equal(y_true_copy, y_true)

labels = [1, 0, 0, 1]
predict_probas = [1, 2, 3, 4]
p, r, t = precision_recall_curve(labels, predict_probas)
assert_array_almost_equal(p, np.array([0.5, 0.33333333, 0.5, 1.0, 1.0]))
assert_array_almost_equal(r, np.array([1.0, 0.5, 0.5, 0.5, 0.0]))
assert_array_almost_equal(t, np.array([1, 2, 3, 4]))
p, r, t = precision_recall_curve(labels, predict_probas, drop_intermediate=drop)
if drop:
assert_allclose(p, [0.5, 0.33333333, 1.0, 1.0])
assert_allclose(r, [1.0, 0.5, 0.5, 0.0])
assert_allclose(t, [1, 2, 4])
else:
assert_allclose(p, [0.5, 0.33333333, 0.5, 1.0, 1.0])
assert_allclose(r, [1.0, 0.5, 0.5, 0.5, 0.0])
assert_allclose(t, [1, 2, 3, 4])
assert p.size == r.size
assert p.size == t.size + 1


def _test_precision_recall_curve(y_true, y_score):
def _test_precision_recall_curve(y_true, y_score, drop):
# Test Precision-Recall and area under PR curve
p, r, thresholds = precision_recall_curve(y_true, y_score)
p, r, thresholds = precision_recall_curve(y_true, y_score, drop_intermediate=drop)
precision_recall_auc = _average_precision_slow(y_true, y_score)
assert_array_almost_equal(precision_recall_auc, 0.859, 3)
assert_array_almost_equal(
@@ -932,25 +938,28 @@ def _test_precision_recall_curve(y_true, y_score):
assert p.size == r.size
assert p.size == thresholds.size + 1
# Smoke test in the case of proba having only one value
p, r, thresholds = precision_recall_curve(y_true, np.zeros_like(y_score))
p, r, thresholds = precision_recall_curve(
y_true, np.zeros_like(y_score), drop_intermediate=drop
)
assert p.size == r.size
assert p.size == thresholds.size + 1


def test_precision_recall_curve_toydata():
@pytest.mark.parametrize("drop", [True, False])
def test_precision_recall_curve_toydata(drop):
with np.errstate(all="raise"):
# Binary classification
y_true = [0, 1]
y_score = [0, 1]
p, r, _ = precision_recall_curve(y_true, y_score)
p, r, _ = precision_recall_curve(y_true, y_score, drop_intermediate=drop)
auc_prc = average_precision_score(y_true, y_score)
assert_array_almost_equal(p, [0.5, 1, 1])
assert_array_almost_equal(r, [1, 1, 0])
assert_almost_equal(auc_prc, 1.0)

y_true = [0, 1]
y_score = [1, 0]
p, r, _ = precision_recall_curve(y_true, y_score)
p, r, _ = precision_recall_curve(y_true, y_score, drop_intermediate=drop)
auc_prc = average_precision_score(y_true, y_score)
assert_array_almost_equal(p, [0.5, 0.0, 1.0])
assert_array_almost_equal(r, [1.0, 0.0, 0.0])
@@ -961,23 +970,23 @@ def test_precision_recall_curve_toydata():

y_true = [1, 0]
y_score = [1, 1]
p, r, _ = precision_recall_curve(y_true, y_score)
p, r, _ = precision_recall_curve(y_true, y_score, drop_intermediate=drop)
auc_prc = average_precision_score(y_true, y_score)
assert_array_almost_equal(p, [0.5, 1])
assert_array_almost_equal(r, [1.0, 0])
assert_almost_equal(auc_prc, 0.5)

y_true = [1, 0]
y_score = [1, 0]
p, r, _ = precision_recall_curve(y_true, y_score)
p, r, _ = precision_recall_curve(y_true, y_score, drop_intermediate=drop)
auc_prc = average_precision_score(y_true, y_score)
assert_array_almost_equal(p, [0.5, 1, 1])
assert_array_almost_equal(r, [1, 1, 0])
assert_almost_equal(auc_prc, 1.0)

y_true = [1, 0]
y_score = [0.5, 0.5]
p, r, _ = precision_recall_curve(y_true, y_score)
p, r, _ = precision_recall_curve(y_true, y_score, drop_intermediate=drop)
auc_prc = average_precision_score(y_true, y_score)
assert_array_almost_equal(p, [0.5, 1])
assert_array_almost_equal(r, [1, 0.0])
@@ -986,7 +995,7 @@ def test_precision_recall_curve_toydata():
y_true = [0, 0]
y_score = [0.25, 0.75]
with pytest.warns(UserWarning, match="No positive class found in y_true"):
p, r, _ = precision_recall_curve(y_true, y_score)
p, r, _ = precision_recall_curve(y_true, y_score, drop_intermediate=drop)
with pytest.warns(UserWarning, match="No positive class found in y_true"):
auc_prc = average_precision_score(y_true, y_score)
assert_allclose(p, [0, 0, 1])
@@ -995,7 +1004,7 @@ def test_precision_recall_curve_toydata():

y_true = [1, 1]
y_score = [0.25, 0.75]
p, r, _ = precision_recall_curve(y_true, y_score)
p, r, _ = precision_recall_curve(y_true, y_score, drop_intermediate=drop)
assert_almost_equal(average_precision_score(y_true, y_score), 1.0)
assert_array_almost_equal(p, [1.0, 1.0, 1.0])
assert_array_almost_equal(r, [1, 0.5, 0.0])
@@ -1100,6 +1109,40 @@ def test_precision_recall_curve_toydata():
)


def test_precision_recall_curve_drop_intermediate():
"""Check the behaviour of the `drop_intermediate` parameter."""
y_true = [0, 0, 0, 0, 1, 1]
y_score = [0.0, 0.2, 0.5, 0.6, 0.7, 1.0]
precision, recall, thresholds = precision_recall_curve(
y_true, y_score, drop_intermediate=True
)
assert_allclose(thresholds, [0.0, 0.7, 1.0])

# Test dropping thresholds with repeating scores
y_true = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
y_score = [0.0, 0.1, 0.6, 0.6, 0.7, 0.8, 0.9, 0.6, 0.7, 0.8, 0.9, 0.9, 1.0]
precision, recall, thresholds = precision_recall_curve(
y_true, y_score, drop_intermediate=True
)
assert_allclose(thresholds, [0.0, 0.6, 0.7, 0.8, 0.9, 1.0])

# Test all false keeps only endpoints
y_true = [0, 0, 0, 0]
y_score = [0.0, 0.1, 0.2, 0.3]
precision, recall, thresholds = precision_recall_curve(
y_true, y_score, drop_intermediate=True
)
assert_allclose(thresholds, [0.0, 0.3])

# Test all true keeps all thresholds
y_true = [1, 1, 1, 1]
y_score = [0.0, 0.1, 0.2, 0.3]
precision, recall, thresholds = precision_recall_curve(
y_true, y_score, drop_intermediate=True
)
assert_allclose(thresholds, [0.0, 0.1, 0.2, 0.3])


def test_average_precision_constant_values():
# Check the average_precision_score of a constant predictor is
# the TPR