
DEP expose y_score instead of y_pred RocCurveDisplay.from_predictions #29865


Merged · 17 commits · Apr 12, 2025
4 changes: 4 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.metrics/29865.api.rst
@@ -0,0 +1,4 @@
- In :meth:`sklearn.metrics.RocCurveDisplay.from_predictions`,
the argument `y_pred` has been renamed to `y_score` to better reflect its purpose.
`y_pred` will be removed in 1.9.
By :user:`Bagus Tris Atmaja <bagustris>` in
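
For reference, a minimal sketch of the caller-side change this entails (illustrative data; assumes a scikit-learn build that includes this PR):

import numpy as np
from sklearn.metrics import RocCurveDisplay

y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

# Preferred from 1.7 on: pass the scores as `y_score`.
RocCurveDisplay.from_predictions(y_true, y_score=y_score)

# The old keyword still works until 1.9 but emits a FutureWarning.
RocCurveDisplay.from_predictions(y_true, y_pred=y_score)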
30 changes: 15 additions & 15 deletions examples/miscellaneous/plot_outlier_detection_bench.py
@@ -88,12 +88,12 @@ def fit_predict(estimator, X):
tic = perf_counter()
if estimator[-1].__class__.__name__ == "LocalOutlierFactor":
estimator.fit(X)
y_pred = estimator[-1].negative_outlier_factor_
y_score = estimator[-1].negative_outlier_factor_
else: # "IsolationForest"
y_pred = estimator.fit(X).decision_function(X)
y_score = estimator.fit(X).decision_function(X)
toc = perf_counter()
print(f"Duration for {model_name}: {toc - tic:.2f} s")
return y_pred
return y_score


# %%
@@ -138,7 +138,7 @@ def fit_predict(estimator, X):

# %%
y_true = {}
y_pred = {"LOF": {}, "IForest": {}}
y_score = {"LOF": {}, "IForest": {}}
model_names = ["LOF", "IForest"]
cat_columns = ["protocol_type", "service", "flag"]

@@ -150,7 +150,7 @@ def fit_predict(estimator, X):
lof_kw={"n_neighbors": int(n_samples * anomaly_frac)},
iforest_kw={"random_state": 42},
)
y_pred[model_name]["KDDCup99 - SA"] = fit_predict(model, X)
y_score[model_name]["KDDCup99 - SA"] = fit_predict(model, X)

# %%
# Forest covertypes dataset
@@ -185,7 +185,7 @@ def fit_predict(estimator, X):
lof_kw={"n_neighbors": int(n_samples * anomaly_frac)},
iforest_kw={"random_state": 42},
)
y_pred[model_name]["forestcover"] = fit_predict(model, X)
y_score[model_name]["forestcover"] = fit_predict(model, X)

# %%
# Ames Housing dataset
@@ -242,7 +242,7 @@ def fit_predict(estimator, X):
lof_kw={"n_neighbors": int(n_samples * anomaly_frac)},
iforest_kw={"random_state": 42},
)
y_pred[model_name]["ames_housing"] = fit_predict(model, X)
y_score[model_name]["ames_housing"] = fit_predict(model, X)

# %%
# Cardiotocography dataset
@@ -271,7 +271,7 @@ def fit_predict(estimator, X):
lof_kw={"n_neighbors": int(n_samples * anomaly_frac)},
iforest_kw={"random_state": 42},
)
y_pred[model_name]["cardiotocography"] = fit_predict(model, X)
y_score[model_name]["cardiotocography"] = fit_predict(model, X)

# %%
# Plot and interpret results
@@ -299,7 +299,7 @@ def fit_predict(estimator, X):
for model_idx, model_name in enumerate(model_names):
display = RocCurveDisplay.from_predictions(
y_true[dataset_name],
y_pred[model_name][dataset_name],
y_score[model_name][dataset_name],
pos_label=pos_label,
name=model_name,
ax=ax,
@@ -346,10 +346,10 @@ def fit_predict(estimator, X):
for model_idx, (linestyle, n_neighbors) in enumerate(zip(linestyles, n_neighbors_list)):
model.set_params(localoutlierfactor__n_neighbors=n_neighbors)
model.fit(X)
y_pred = model[-1].negative_outlier_factor_
y_score = model[-1].negative_outlier_factor_
display = RocCurveDisplay.from_predictions(
y,
y_pred,
y_score,
pos_label=pos_label,
name=f"n_neighbors = {n_neighbors}",
ax=ax,
@@ -386,10 +386,10 @@ def fit_predict(estimator, X):
):
model = make_pipeline(preprocessor, lof)
model.fit(X)
y_pred = model[-1].negative_outlier_factor_
y_score = model[-1].negative_outlier_factor_
display = RocCurveDisplay.from_predictions(
y,
y_pred,
y_score,
pos_label=pos_label,
name=str(preprocessor).split("(")[0],
ax=ax,
@@ -438,10 +438,10 @@ def fit_predict(estimator, X):
):
model = make_pipeline(preprocessor, lof)
model.fit(X)
y_pred = model[-1].negative_outlier_factor_
y_score = model[-1].negative_outlier_factor_
display = RocCurveDisplay.from_predictions(
y,
y_pred,
y_score,
pos_label=pos_label,
name=str(preprocessor).split("(")[0],
ax=ax,
56 changes: 44 additions & 12 deletions sklearn/metrics/_plot/roc_curve.py
@@ -1,6 +1,8 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import warnings

from ...utils._plotting import (
_BinaryClassifierCurveDisplayMixin,
_despine,
@@ -71,9 +73,9 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin):
>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> from sklearn import metrics
>>> y = np.array([0, 0, 1, 1])
>>> pred = np.array([0.1, 0.4, 0.35, 0.8])
>>> fpr, tpr, thresholds = metrics.roc_curve(y, pred)
>>> y_true = np.array([0, 0, 1, 1])
>>> y_score = np.array([0.1, 0.4, 0.35, 0.8])
>>> fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score)
>>> roc_auc = metrics.auc(fpr, tpr)
>>> display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc,
... estimator_name='example estimator')
@@ -299,7 +301,7 @@ def from_estimator(
<...>
>>> plt.show()
"""
y_pred, pos_label, name = cls._validate_and_get_response_values(
y_score, pos_label, name = cls._validate_and_get_response_values(
estimator,
X,
y,
@@ -310,7 +312,7 @@ def from_estimator(

return cls.from_predictions(
y_true=y,
y_pred=y_pred,
y_score=y_score,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
name=name,
@@ -326,7 +328,7 @@ def from_estimator(
def from_predictions(
cls,
y_true,
y_pred,
y_score=None,
*,
sample_weight=None,
drop_intermediate=True,
@@ -336,6 +338,7 @@ def from_estimator(
plot_chance_level=False,
chance_level_kw=None,
despine=False,
y_pred="deprecated",
**kwargs,
):
"""Plot ROC curve given the true and predicted values.
@@ -349,11 +352,14 @@ def from_predictions(
y_true : array-like of shape (n_samples,)
True labels.

y_pred : array-like of shape (n_samples,)
y_score : array-like of shape (n_samples,)
Target scores, can either be probability estimates of the positive
class, confidence values, or non-thresholded measure of decisions
(as returned by “decision_function” on some classifiers).

.. versionadded:: 1.7
`y_pred` has been renamed to `y_score`.

sample_weight : array-like of shape (n_samples,), default=None
Sample weights.

@@ -391,6 +397,15 @@ def from_predictions(

.. versionadded:: 1.6

y_pred : array-like of shape (n_samples,)
Target scores, can either be probability estimates of the positive
class, confidence values, or non-thresholded measure of decisions
(as returned by “decision_function” on some classifiers).

.. deprecated:: 1.7
`y_pred` is deprecated and will be removed in 1.9. Use
`y_score` instead.

**kwargs : dict
Additional keyword arguments passed to matplotlib `plot` function.

@@ -417,19 +432,36 @@ def from_predictions(
>>> X_train, X_test, y_train, y_test = train_test_split(
... X, y, random_state=0)
>>> clf = SVC(random_state=0).fit(X_train, y_train)
>>> y_pred = clf.decision_function(X_test)
>>> RocCurveDisplay.from_predictions(
... y_test, y_pred)
>>> y_score = clf.decision_function(X_test)
>>> RocCurveDisplay.from_predictions(y_test, y_score)
<...>
>>> plt.show()
"""
# TODO(1.9): remove after the end of the deprecation period of `y_pred`
if y_score is not None and not (
isinstance(y_pred, str) and y_pred == "deprecated"
):
raise ValueError(
Review comment (Member):
We need to add new tests to be sure that we properly raise this error.
Basically, the same thing applies to the code not covered in the line below.

Reply from @bagustris (Contributor, Author), Oct 11, 2024:
I added unit tests, namely test_y_score_and_y_pred_deprecation and test_y_pred_deprecation_warning, in test_roc_curve_display.py. I also changed y_pred to y_score in that test file following the changes proposed in this PR.

"`y_pred` and `y_score` cannot be both specified. Please use `y_score`"
" only as `y_pred` is deprecated in 1.7 and will be removed in 1.9."
)
if not (isinstance(y_pred, str) and y_pred == "deprecated"):
warnings.warn(
(
"y_pred is deprecated in 1.7 and will be removed in 1.9. "
"Please use `y_score` instead."
),
FutureWarning,
)
y_score = y_pred

pos_label_validated, name = cls._validate_from_predictions_params(
y_true, y_pred, sample_weight=sample_weight, pos_label=pos_label, name=name
y_true, y_score, sample_weight=sample_weight, pos_label=pos_label, name=name
)

fpr, tpr, _ = roc_curve(
y_true,
y_pred,
y_score,
pos_label=pos_label,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
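
As an aside, the handling above follows a common sentinel-default pattern for renaming a keyword argument. A minimal standalone sketch of the same idea (hypothetical function name, not part of scikit-learn):

import warnings

def plot_curve(y_true, y_score=None, *, y_pred="deprecated"):
    # Reject ambiguous calls that pass both the old and the new keyword.
    if y_score is not None and not (isinstance(y_pred, str) and y_pred == "deprecated"):
        raise ValueError("`y_pred` and `y_score` cannot be both specified.")
    # Accept the old keyword for now, but warn that it is going away.
    if not (isinstance(y_pred, str) and y_pred == "deprecated"):
        warnings.warn("y_pred is deprecated; use y_score instead.", FutureWarning)
        y_score = y_pred
    # ... the real method would validate inputs and plot the ROC curve here ...
    return y_score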
54 changes: 42 additions & 12 deletions sklearn/metrics/_plot/tests/test_roc_curve_display.py
@@ -68,8 +68,8 @@ def test_roc_curve_display_plotting(
lr = LogisticRegression()
lr.fit(X, y)

y_pred = getattr(lr, response_method)(X)
y_pred = y_pred if y_pred.ndim == 1 else y_pred[:, 1]
y_score = getattr(lr, response_method)(X)
y_score = y_score if y_score.ndim == 1 else y_score[:, 1]

if constructor_name == "from_estimator":
display = RocCurveDisplay.from_estimator(
@@ -84,7 +84,7 @@
else:
display = RocCurveDisplay.from_predictions(
y,
y_pred,
y_score,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
pos_label=pos_label,
@@ -93,7 +93,7 @@

fpr, tpr, _ = roc_curve(
y,
y_pred,
y_score,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
pos_label=pos_label,
@@ -155,8 +155,8 @@ def test_roc_curve_chance_level_line(
lr = LogisticRegression()
lr.fit(X, y)

y_pred = getattr(lr, "predict_proba")(X)
y_pred = y_pred if y_pred.ndim == 1 else y_pred[:, 1]
y_score = getattr(lr, "predict_proba")(X)
y_score = y_score if y_score.ndim == 1 else y_score[:, 1]

if constructor_name == "from_estimator":
display = RocCurveDisplay.from_estimator(
@@ -171,7 +171,7 @@
else:
display = RocCurveDisplay.from_predictions(
y,
y_pred,
y_score,
label=label,
alpha=0.8,
plot_chance_level=plot_chance_level,
@@ -306,11 +306,11 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name):
# are betrayed by the class imbalance
assert classifier.classes_.tolist() == ["cancer", "not cancer"]

y_pred = getattr(classifier, response_method)(X_test)
y_score = getattr(classifier, response_method)(X_test)
# we select the corresponding probability columns or reverse the decision
# function otherwise
y_pred_cancer = -1 * y_pred if y_pred.ndim == 1 else y_pred[:, 0]
y_pred_not_cancer = y_pred if y_pred.ndim == 1 else y_pred[:, 1]
y_score_cancer = -1 * y_score if y_score.ndim == 1 else y_score[:, 0]
y_score_not_cancer = y_score if y_score.ndim == 1 else y_score[:, 1]

if constructor_name == "from_estimator":
display = RocCurveDisplay.from_estimator(
@@ -323,7 +323,7 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name):
else:
display = RocCurveDisplay.from_predictions(
y_test,
y_pred_cancer,
y_score_cancer,
pos_label="cancer",
)

@@ -343,14 +343,44 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name):
else:
display = RocCurveDisplay.from_predictions(
y_test,
y_pred_not_cancer,
y_score_not_cancer,
pos_label="not cancer",
)

assert display.roc_auc == pytest.approx(roc_auc_limit)
assert trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit)


# TODO(1.9): remove
def test_y_score_and_y_pred_specified_error():
"""Check that an error is raised when both y_score and y_pred are specified."""
y_true = np.array([0, 1, 1, 0])
y_score = np.array([0.1, 0.4, 0.35, 0.8])
y_pred = np.array([0.2, 0.3, 0.5, 0.1])

with pytest.raises(
ValueError, match="`y_pred` and `y_score` cannot be both specified"
):
RocCurveDisplay.from_predictions(y_true, y_score=y_score, y_pred=y_pred)


# TODO(1.9): remove
def test_y_pred_deprecation_warning(pyplot):
"""Check that a warning is raised when y_pred is specified."""
y_true = np.array([0, 1, 1, 0])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

with pytest.warns(FutureWarning, match="y_pred is deprecated in 1.7"):
display_y_pred = RocCurveDisplay.from_predictions(y_true, y_pred=y_score)

assert_allclose(display_y_pred.fpr, [0, 0.5, 0.5, 1])
assert_allclose(display_y_pred.tpr, [0, 0, 1, 1])

display_y_score = RocCurveDisplay.from_predictions(y_true, y_score)
assert_allclose(display_y_score.fpr, [0, 0.5, 0.5, 1])
assert_allclose(display_y_score.tpr, [0, 0, 1, 1])


@pytest.mark.parametrize("despine", [True, False])
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_plot_roc_curve_despine(pyplot, data_binary, despine, constructor_name):
12 changes: 6 additions & 6 deletions sklearn/metrics/_ranking.py
@@ -73,9 +73,9 @@ def auc(x, y):
--------
>>> import numpy as np
>>> from sklearn import metrics
>>> y = np.array([1, 1, 2, 2])
>>> pred = np.array([0.1, 0.4, 0.35, 0.8])
>>> fpr, tpr, thresholds = metrics.roc_curve(y, pred, pos_label=2)
>>> y_true = np.array([1, 1, 2, 2])
>>> y_score = np.array([0.1, 0.4, 0.35, 0.8])
>>> fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score, pos_label=2)
>>> metrics.auc(fpr, tpr)
0.75
"""
@@ -604,10 +604,10 @@ class scores must correspond to the order of ``labels``,
>>> clf = MultiOutputClassifier(clf).fit(X, y)
>>> # get a list of n_output containing probability arrays of shape
>>> # (n_samples, n_classes)
>>> y_pred = clf.predict_proba(X)
>>> y_score = clf.predict_proba(X)
>>> # extract the positive columns for each output
>>> y_pred = np.transpose([pred[:, 1] for pred in y_pred])
>>> roc_auc_score(y, y_pred, average=None)
>>> y_score = np.transpose([score[:, 1] for score in y_score])
>>> roc_auc_score(y, y_score, average=None)
array([0.82..., 0.86..., 0.94..., 0.85... , 0.94...])
>>> from sklearn.linear_model import RidgeClassifierCV
>>> clf = RidgeClassifierCV().fit(X, y)