DOC: Add from_predictions example and other details to visualizations.rst #30825

Merged 27 commits on May 26, 2025

21e0578
Added from_predictions
DeaMariaLeon Feb 13, 2025
c71d9ba
version to propose
DeaMariaLeon Feb 13, 2025
36f31e8
removed error
DeaMariaLeon Feb 13, 2025
0a13cd0
after first review
DeaMariaLeon Feb 14, 2025
3996689
Merge remote-tracking branch 'upstream/main' into from_pred
DeaMariaLeon Feb 14, 2025
96638d7
fix link
DeaMariaLeon Feb 14, 2025
4e9eb83
added arg names, removed unused blank lines
DeaMariaLeon Feb 21, 2025
001546a
Merge remote-tracking branch 'upstream/main' into from_pred
DeaMariaLeon Feb 21, 2025
6132fad
Correcting Display explanation
DeaMariaLeon Mar 7, 2025
1c88ab1
after feedback
DeaMariaLeon Mar 10, 2025
2a25249
Changed classifier to LogisticRegression
DeaMariaLeon Mar 10, 2025
7c8182f
changed Display object phrase
DeaMariaLeon Mar 10, 2025
394a567
fix typo
DeaMariaLeon Mar 10, 2025
3464555
Merge remote-tracking branch 'upstream/main' into from_pred
DeaMariaLeon Mar 10, 2025
21f398a
after another round of feedback
DeaMariaLeon Mar 12, 2025
649da64
Merge remote-tracking branch 'upstream/main' into from_pred
DeaMariaLeon Mar 12, 2025
8ef6f41
corrected random_state to match others
DeaMariaLeon Mar 12, 2025
4b2d6c7
Merge remote-tracking branch 'upstream/main' into from_pred
DeaMariaLeon May 12, 2025
c4a492f
Trying to add all the feedbacks
DeaMariaLeon May 12, 2025
b89de07
fix extra space in plot
DeaMariaLeon May 12, 2025
a37e898
After feedback again
DeaMariaLeon May 12, 2025
c9d6c0e
Merge remote-tracking branch 'upstream/main' into from_pred
DeaMariaLeon May 12, 2025
e1ca7cb
next feedback fix
DeaMariaLeon May 16, 2025
42fd0d9
Merge remote-tracking branch 'upstream/main' into from_pred
DeaMariaLeon May 16, 2025
dff13b2
removed an extra -and-
DeaMariaLeon May 16, 2025
103cbf5
Comply with the lastest feedback
DeaMariaLeon May 22, 2025
4c6c5f2
Merge remote-tracking branch 'upstream/main' into from_pred
DeaMariaLeon May 22, 2025
92 changes: 71 additions & 21 deletions doc/visualizations.rst
@@ -8,37 +8,86 @@ Scikit-learn defines a simple API for creating visualizations for machine
learning. The key feature of this API is to allow for quick plotting and
visual adjustments without recalculation. We provide `Display` classes that
expose two methods for creating plots: `from_estimator` and
`from_predictions`. The `from_estimator` method will take a fitted estimator
and some data (`X` and `y`) and create a `Display` object. Sometimes, we would
like to only compute the predictions once and one should use `from_predictions`
instead. In the following example, we plot a ROC curve for a fitted support
vector machine:
`from_predictions`.

The `from_estimator` method takes a fitted estimator and input data (`X`, `y`),
and creates both a `Display` object and a plot.
The `from_predictions` method takes true and predicted values (`y_test`, `y_pred`),
and likewise creates both a `Display` object and a plot.

Using `from_predictions` avoids having to recompute predictions,
but the user needs to take care that the prediction values passed correspond
to the `pos_label`. For :term:`predict_proba`, select the column corresponding
to the `pos_label` class; for :term:`decision_function`, negate the score
(i.e. multiply by -1) if `pos_label` is not the last class in the
`classes_` attribute of your estimator.
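
As a minimal sketch of the `pos_label` rule above (not part of the changed file), the snippet below prepares scores for a positive class that is not the last entry of `classes_`. It assumes integer labels rather than the boolean labels used in the documentation example, and the variable names (`proba_scores`, `decision_scores`) are illustrative only. Both score types should yield the same ROC AUC once they are aligned with `pos_label`.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
y = (y == 2).astype(int)  # binary target with classes 0 and 1 (assumption: int labels)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.8, random_state=42
)
clf = LogisticRegression(random_state=42, C=0.01).fit(X_train, y_train)

pos_label = 0  # deliberately NOT the last class in clf.classes_

# predict_proba: pick the column that belongs to pos_label
col = list(clf.classes_).index(pos_label)
proba_scores = clf.predict_proba(X_test)[:, col]

# decision_function: scores favor classes_[-1], so negate them otherwise
decision_scores = clf.decision_function(X_test)
if pos_label != clf.classes_[-1]:
    decision_scores = -decision_scores

# Both score types now rank samples by "how likely is pos_label"
fpr_p, tpr_p, _ = roc_curve(y_test, proba_scores, pos_label=pos_label)
fpr_d, tpr_d, _ = roc_curve(y_test, decision_scores, pos_label=pos_label)
auc_proba = auc(fpr_p, tpr_p)
auc_decision = auc(fpr_d, tpr_d)
```

Either score array could then be passed on, for example as `RocCurveDisplay.from_predictions(y_test, proba_scores, pos_label=pos_label)`.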

The `Display` object stores the computed values (e.g., metric values or
feature importance) required for plotting with Matplotlib. These values are the
results derived from the raw predictions passed to `from_predictions`, or
an estimator and `X` passed to `from_estimator`.

Display objects have a `plot` method that creates a Matplotlib plot once the display
object has been initialized (note that we recommend creating display objects
via `from_estimator` or `from_predictions` rather than initializing them directly).
The `plot` method can add to an existing plot when the existing plot's
:class:`matplotlib.axes.Axes` is passed to the `ax` parameter.

In the following example, we plot a ROC curve for a fitted Logistic Regression
model `from_estimator`:

.. plot::
   :context: close-figs
   :align: center

   from sklearn.model_selection import train_test_split
   from sklearn.svm import SVC
   from sklearn.linear_model import LogisticRegression
   from sklearn.metrics import RocCurveDisplay
   from sklearn.datasets import load_wine
   from sklearn.datasets import load_iris

   X, y = load_wine(return_X_y=True)
   X, y = load_iris(return_X_y=True)
   y = y == 2  # make binary
   X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
   svc = SVC(random_state=42)
   svc.fit(X_train, y_train)
   X_train, X_test, y_train, y_test = train_test_split(
       X, y, test_size=.8, random_state=42
   )
   clf = LogisticRegression(random_state=42, C=.01)
   clf.fit(X_train, y_train)

   svc_disp = RocCurveDisplay.from_estimator(svc, X_test, y_test)
   clf_disp = RocCurveDisplay.from_estimator(clf, X_test, y_test)

The returned `svc_disp` object allows us to continue using the already computed
ROC curve for SVC in future plots. In this case, the `svc_disp` is a
:class:`~sklearn.metrics.RocCurveDisplay` that stores the computed values as
attributes called `roc_auc`, `fpr`, and `tpr`. Be aware that we could get
the predictions from the support vector machine and then use `from_predictions`
instead of `from_estimator`. Next, we train a random forest classifier and plot
the previously computed ROC curve again by using the `plot` method of the
`Display` object.
If you already have the prediction values, you could instead use
`from_predictions` to do the same thing (and save on compute):


.. plot::
   :context: close-figs
   :align: center

   from sklearn.model_selection import train_test_split
   from sklearn.linear_model import LogisticRegression
   from sklearn.metrics import RocCurveDisplay
   from sklearn.datasets import load_iris

   X, y = load_iris(return_X_y=True)
   y = y == 2  # make binary
   X_train, X_test, y_train, y_test = train_test_split(
       X, y, test_size=.8, random_state=42
   )
   clf = LogisticRegression(random_state=42, C=.01)
   clf.fit(X_train, y_train)

   # select the probability of the class that we considered to be the positive label
   y_pred = clf.predict_proba(X_test)[:, 1]

   clf_disp = RocCurveDisplay.from_predictions(y_test, y_pred)


The returned `clf_disp` object allows us to add another curve to the already computed
ROC curve. In this case, the `clf_disp` is a :class:`~sklearn.metrics.RocCurveDisplay`
that stores the computed values as attributes called `roc_auc`, `fpr`, and `tpr`.

Next, we train a random forest classifier and plot the previously computed ROC curve
again by using the `plot` method of the `Display` object.

.. plot::
   :context: close-figs
@@ -52,11 +101,12 @@ the previously computed ROC curve again by using the `plot` method of the

   ax = plt.gca()
   rfc_disp = RocCurveDisplay.from_estimator(rfc, X_test, y_test, ax=ax, alpha=0.8)
   svc_disp.plot(ax=ax, alpha=0.8)
   clf_disp.plot(ax=ax, alpha=0.8)

Notice that we pass `alpha=0.8` to the plot functions to adjust the alpha
values of the curves.


.. rubric:: Examples

* :ref:`sphx_glr_auto_examples_miscellaneous_plot_roc_curve_visualization_api.py`