[MRG] Plotting API starting with ROC curve #14357

Merged
merged 46 commits into from
Aug 8, 2019
46 commits
aa753b1
WIP
thomasjpfan Jul 5, 2019
d5ba421
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Jul 10, 2019
a0e4723
ENH Adds plot_roc_curve
thomasjpfan Jul 10, 2019
8ae4c70
DOC Adds docs
thomasjpfan Jul 11, 2019
eaac39c
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Jul 11, 2019
763d723
BUG Fix
thomasjpfan Jul 11, 2019
2330e06
DOC Adds label to parameters
thomasjpfan Jul 11, 2019
4e33b28
DOC Adds label as a parameter
thomasjpfan Jul 11, 2019
40381ae
API Update with kwargs
thomasjpfan Jul 11, 2019
5f83a80
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Jul 14, 2019
663fc58
BLD Add tests to setup
thomasjpfan Jul 14, 2019
dea2b1b
DOC Update docs
thomasjpfan Jul 14, 2019
3e65c77
DOC Update ordering
thomasjpfan Jul 14, 2019
b4a8d0e
ENH Adds auc to labels
thomasjpfan Jul 17, 2019
eba2453
CLN Updates example with plotting api
thomasjpfan Jul 17, 2019
74b8d7b
TST Updates test
thomasjpfan Jul 17, 2019
272845f
CLN Moves name to plot
thomasjpfan Jul 17, 2019
c51da17
TST Adds more tests
thomasjpfan Jul 18, 2019
5fc6f29
DOC Removes line_kw parameter docstring
thomasjpfan Jul 22, 2019
8832788
DOC Fix docs
thomasjpfan Jul 22, 2019
56ff821
CLN Removes unused import
thomasjpfan Jul 22, 2019
43f1787
CLN Uses kwargs
thomasjpfan Jul 22, 2019
bdf782a
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Jul 23, 2019
6cc8712
CLN Does computation in plot_* function
thomasjpfan Jul 23, 2019
7051e5e
DOC Adds type
thomasjpfan Jul 23, 2019
d461543
CLN Address comments
thomasjpfan Jul 29, 2019
bda4435
BUG Spelling
thomasjpfan Jul 29, 2019
af4209c
TST Assert message
thomasjpfan Jul 29, 2019
014851e
CLN Moves plot to _plot
thomasjpfan Jul 29, 2019
50fb1cd
WIP
thomasjpfan Jul 29, 2019
e620d8c
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Jul 30, 2019
4d8534c
DOC Adds user guide
thomasjpfan Jul 30, 2019
46cc0c5
BLD Fix
thomasjpfan Jul 30, 2019
6771d8d
BLD Build docs [doc build]
thomasjpfan Jul 31, 2019
dd7f3cf
CLN Address comments
thomasjpfan Jul 31, 2019
e668471
CLN Address comments
thomasjpfan Aug 1, 2019
94f29e3
STY Update styling
thomasjpfan Aug 1, 2019
9546755
STY Update styling
thomasjpfan Aug 1, 2019
71b19a3
DOC Adds note about parameters stored as attributes
thomasjpfan Aug 1, 2019
f7f2e2e
CLN Updates to roc_auc
thomasjpfan Aug 5, 2019
e2f42fa
ENH Adds check for response_method
thomasjpfan Aug 5, 2019
30925a8
DOC Adds parameters
thomasjpfan Aug 5, 2019
2e8db1f
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan Aug 5, 2019
fab9d07
CLN Renames to display
thomasjpfan Aug 5, 2019
759d1b4
CLN Address comments
thomasjpfan Aug 7, 2019
140287d
improve error message
qinhanmin2014 Aug 8, 2019
51 changes: 51 additions & 0 deletions doc/developers/contributing.rst
@@ -1647,3 +1647,54 @@ make this task easier and faster (in no particular order).
<https://git-scm.com/docs/git-grep#_examples>`_) is also extremely
useful to see every occurrence of a pattern (e.g. a function call or a
variable) in the code base.


.. _plotting_api:

Plotting API
============

Scikit-learn defines a simple API for creating visualizations for machine
learning. The key features of this API are to run calculations once and to have
the flexibility to adjust the visualizations after the fact. This logic is
Review comment (Member):

It'd be nice to actually have a very simple example that illustrates this feature.

Reply (Member):

link to example of user guide now?

encapsulated into a display object where the computed data is stored and
the plotting is done in a `plot` method. The display object's `__init__`
method contains only the data needed to create the visualization. The `plot`
method takes in parameters that only have to do with visualization, such as a
matplotlib axes. The `plot` method will store the matplotlib artists as
attributes allowing for style adjustments through the display object. A
`plot_*` helper function accepts parameters to do the computation and the
parameters used for plotting. After the helper function creates the display
object with the computed values, it calls the display's plot method. Note
that the `plot` method defines attributes related to matplotlib, such as the
line artist. This allows for customizations after calling the `plot` method.

For example, the `RocCurveDisplay` defines the following methods and
attributes:

.. code-block:: python

   class RocCurveDisplay:
       def __init__(self, fpr, tpr, roc_auc, estimator_name):
           ...
           self.fpr = fpr
           self.tpr = tpr
           self.roc_auc = roc_auc
           self.estimator_name = estimator_name

       def plot(self, ax=None, name=None, **kwargs):
           ...
           self.line_ = ...
           self.ax_ = ax
           self.figure_ = ax.figure

   def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
                      drop_intermediate=True, response_method="auto",
                      name=None, ax=None, **kwargs):
       # do computation
Review comment (Member):

I read this comment as referring to the line below, which is not the intent.

Maybe

# ...
# Do computation
# ...

       viz = RocCurveDisplay(fpr, tpr, roc_auc,
                             estimator.__class__.__name__)
       return viz.plot(ax=ax, name=name, **kwargs)

Read more in the :ref:`User Guide <visualizations>`.
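
Note: the display-object pattern described in this section can be tried out without scikit-learn or matplotlib. The following is a minimal sketch, not the implementation in this PR. `RecordingAxes`, `CurveDisplay`, and `plot_curve` are hypothetical names introduced only for illustration, and `RecordingAxes` is an assumed stand-in that mimics just the `plot` method and `figure` attribute of a matplotlib axes.

```python
class RecordingAxes:
    """Hypothetical stand-in for a matplotlib axes; it only records calls."""

    def __init__(self):
        self.figure = object()  # placeholder for the parent figure
        self.lines = []

    def plot(self, x, y, **style):
        line = {"x": list(x), "y": list(y), "style": style}
        self.lines.append(line)
        return [line]  # like Axes.plot, return a list of created artists


class CurveDisplay:
    """Hypothetical display object: stores computed data, draws on request."""

    def __init__(self, fpr, tpr, roc_auc, estimator_name):
        # Only the data needed to draw the curve is stored here.
        self.fpr = fpr
        self.tpr = tpr
        self.roc_auc = roc_auc
        self.estimator_name = estimator_name

    def plot(self, ax, name=None, **kwargs):
        # Only visualization parameters are accepted here.
        name = self.estimator_name if name is None else name
        label = "{} (AUC = {:0.2f})".format(name, self.roc_auc)
        self.line_ = ax.plot(self.fpr, self.tpr, label=label, **kwargs)[0]
        self.ax_ = ax
        self.figure_ = ax.figure
        return self


def plot_curve(fpr, tpr, roc_auc, estimator_name, ax, **kwargs):
    """Hypothetical helper: build the display, then delegate to plot."""
    display = CurveDisplay(fpr, tpr, roc_auc, estimator_name)
    return display.plot(ax, **kwargs)


first_ax, second_ax = RecordingAxes(), RecordingAxes()
disp = plot_curve([0.0, 0.5, 1.0], [0.0, 1.0, 1.0], 0.75, "SVC", first_ax)
# Reuse the stored values on another axes with a different style.
disp.plot(second_ax, name="SVC (replotted)", alpha=0.8)
```

The second call to `plot` reads only the stored attributes, so nothing is recomputed. The artist attributes (`line_`, `ax_`, `figure_`) always refer to the most recent call.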
20 changes: 20 additions & 0 deletions doc/modules/classes.rst
@@ -1007,6 +1007,26 @@ See the :ref:`metrics` section of the user guide for further details.
metrics.pairwise_distances_chunked


Plotting
--------

See the :ref:`visualizations` section of the user guide for further details.

.. currentmodule:: sklearn

.. autosummary::
   :toctree: generated/
   :template: function.rst

   metrics.plot_roc_curve

.. autosummary::
   :toctree: generated/
   :template: class.rst

   metrics.RocCurveDisplay


.. _mixture_ref:

:mod:`sklearn.mixture`: Gaussian Mixture Models
1 change: 1 addition & 0 deletions doc/user_guide.rst
@@ -19,6 +19,7 @@ User Guide
unsupervised_learning.rst
model_selection.rst
inspection.rst
visualizations.rst
data_transforms.rst
Dataset loading utilities <datasets/index.rst>
modules/computing.rst
83 changes: 83 additions & 0 deletions doc/visualizations.rst
@@ -0,0 +1,83 @@
.. include:: includes/big_toc_css.rst

.. _visualizations:

==============
Visualizations
==============

Scikit-learn defines a simple API for creating visualizations for machine
learning. The key feature of this API is to allow for quick plotting and
visual adjustments without recalculation. In the following example, we plot a
ROC curve for a fitted support vector machine:

.. code-block:: python

   from sklearn.model_selection import train_test_split
   from sklearn.svm import SVC
   from sklearn.metrics import plot_roc_curve
   from sklearn.datasets import load_wine

   # Load the wine dataset and turn it into a binary problem (class 2 vs. rest)
   X, y = load_wine(return_X_y=True)
   y = y == 2

   X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
   svc = SVC(random_state=42)
   svc.fit(X_train, y_train)

   svc_disp = plot_roc_curve(svc, X_test, y_test)

.. figure:: ../auto_examples/images/sphx_glr_plot_roc_curve_visualization_api_001.png
   :target: ../auto_examples/plot_roc_curve_visualization_api.html
   :align: center
   :scale: 75%

The returned `svc_disp` object allows us to continue using the already computed
ROC curve for SVC in future plots. In this case, the `svc_disp` is a
:class:`~sklearn.metrics.RocCurveDisplay` that stores the computed values as
attributes called `roc_auc`, `fpr`, and `tpr`. Next, we train a random forest
classifier and plot the previously computed ROC curve again by using the `plot`
method of the `Display` object.

.. code-block:: python

   import matplotlib.pyplot as plt
   from sklearn.ensemble import RandomForestClassifier

   rfc = RandomForestClassifier(random_state=42)
   rfc.fit(X_train, y_train)

   ax = plt.gca()
   rfc_disp = plot_roc_curve(rfc, X_test, y_test, ax=ax, alpha=0.8)
   svc_disp.plot(ax=ax, alpha=0.8)

.. figure:: ../auto_examples/images/sphx_glr_plot_roc_curve_visualization_api_002.png
   :target: ../auto_examples/plot_roc_curve_visualization_api.html
   :align: center
   :scale: 75%

Notice that we pass `alpha=0.8` to the plot functions to adjust the alpha
values of the curves.

.. topic:: Examples:

   * :ref:`sphx_glr_auto_examples_plot_roc_curve_visualization_api.py`
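
Note: to make the stored attributes `fpr`, `tpr`, and `roc_auc` concrete, here is a minimal pure-Python sketch of how ROC points and the area under them can be computed for binary labels. It is an illustration under simplifying assumptions, not the code path used by `plot_roc_curve`. The helper names `roc_points` and `trapezoid_auc` are hypothetical, and the sketch does no thinning of intermediate points (nothing like `drop_intermediate`).

```python
def roc_points(y_true, y_score):
    """Return (fpr, tpr) lists for binary labels, one point per distinct score."""
    n_pos = sum(1 for label in y_true if label)
    n_neg = len(y_true) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("both classes must be present")
    # Visit samples from the highest score to the lowest.
    ranked = sorted(zip(y_score, y_true), key=lambda pair: pair[0], reverse=True)
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    for index, (score, label) in enumerate(ranked):
        if label:
            tp += 1
        else:
            fp += 1
        # Emit a point only after the last sample of a group of tied scores.
        is_last = index + 1 == len(ranked)
        if is_last or ranked[index + 1][0] != score:
            fpr.append(fp / n_neg)
            tpr.append(tp / n_pos)
    return fpr, tpr


def trapezoid_auc(fpr, tpr):
    """Area under the piecewise-linear curve through the (fpr, tpr) points."""
    return sum(
        (x1 - x0) * (y0 + y1) / 2.0
        for x0, x1, y0, y1 in zip(fpr, fpr[1:], tpr, tpr[1:])
    )


y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
fpr, tpr = roc_points(y_true, y_score)
roc_auc = trapezoid_auc(fpr, tpr)  # 0.75 for this toy input
```

These three values are all a display object needs to redraw the curve later, which is why replotting does not require the estimator or the test data.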

Available Plotting Utilities
============================

Functions
---------

.. currentmodule:: sklearn

.. autosummary::

   metrics.plot_roc_curve


Display Objects
---------------

.. currentmodule:: sklearn

.. autosummary::

   metrics.RocCurveDisplay
52 changes: 24 additions & 28 deletions examples/model_selection/plot_roc_crossval.py
@@ -36,7 +36,8 @@
 import matplotlib.pyplot as plt

 from sklearn import svm, datasets
-from sklearn.metrics import roc_curve, auc
+from sklearn.metrics import auc
+from sklearn.metrics import plot_roc_curve
 from sklearn.model_selection import StratifiedKFold

 # #############################################################################
@@ -65,40 +66,35 @@
 aucs = []
 mean_fpr = np.linspace(0, 1, 100)

-i = 0
-for train, test in cv.split(X, y):
-    probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test])
-    # Compute ROC curve and area the curve
-    fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1])
-    tprs.append(interp(mean_fpr, fpr, tpr))
-    tprs[-1][0] = 0.0
-    roc_auc = auc(fpr, tpr)
-    aucs.append(roc_auc)
-    plt.plot(fpr, tpr, lw=1, alpha=0.3,
-             label='ROC fold %d (AUC = %0.2f)' % (i, roc_auc))
-
-    i += 1
-plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
-         label='Chance', alpha=.8)
+fig, ax = plt.subplots()
+for i, (train, test) in enumerate(cv.split(X, y)):
+    classifier.fit(X[train], y[train])
+    viz = plot_roc_curve(classifier, X[test], y[test],
+                         name='ROC fold {}'.format(i),
+                         alpha=0.3, lw=1, ax=ax)
+    interp_tpr = interp(mean_fpr, viz.fpr, viz.tpr)
+    interp_tpr[0] = 0.0
+    tprs.append(interp_tpr)
+    aucs.append(viz.roc_auc)
+
+ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
+        label='Chance', alpha=.8)

 mean_tpr = np.mean(tprs, axis=0)
 mean_tpr[-1] = 1.0
 mean_auc = auc(mean_fpr, mean_tpr)
 std_auc = np.std(aucs)
-plt.plot(mean_fpr, mean_tpr, color='b',
-         label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc),
-         lw=2, alpha=.8)
+ax.plot(mean_fpr, mean_tpr, color='b',
+        label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc),
+        lw=2, alpha=.8)

 std_tpr = np.std(tprs, axis=0)
 tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
 tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
-plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
-                 label=r'$\pm$ 1 std. dev.')
-
-plt.xlim([-0.05, 1.05])
-plt.ylim([-0.05, 1.05])
-plt.xlabel('False Positive Rate')
-plt.ylabel('True Positive Rate')
-plt.title('Receiver operating characteristic example')
-plt.legend(loc="lower right")
+ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
+                label=r'$\pm$ 1 std. dev.')
+
+ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05],
+       title="Receiver operating characteristic example")
+ax.legend(loc="lower right")
 plt.show()
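
Note: the fold-averaging step in this example (interpolate each fold's curve onto `mean_fpr`, then take the mean and standard deviation per grid point) can be sketched with the standard library alone. This is a simplified illustration, not the example's code. `interp_linear` and `mean_roc` are hypothetical names, `interp_linear` is an assumed stand-in for the `interp` call used above, and the two toy folds are made up for the demonstration.

```python
from bisect import bisect_right
from statistics import mean, pstdev


def interp_linear(x, xp, fp):
    """Piecewise-linear interpolation of the points (xp, fp) at x.

    xp must be non-decreasing; x outside its range is clamped to the end values.
    """
    if x <= xp[0]:
        return fp[0]
    if x >= xp[-1]:
        return fp[-1]
    hi = bisect_right(xp, x)  # first index with xp[hi] > x
    lo = hi - 1
    weight = (x - xp[lo]) / (xp[hi] - xp[lo])
    return fp[lo] + weight * (fp[hi] - fp[lo])


def mean_roc(curves, grid):
    """Average several (fpr, tpr) curves on a shared grid of fpr values."""
    rows = []
    for fpr, tpr in curves:
        row = [interp_linear(x, fpr, tpr) for x in grid]
        row[0] = 0.0  # force every curve to start at the origin
        rows.append(row)
    mean_tpr = [mean(column) for column in zip(*rows)]
    mean_tpr[-1] = 1.0  # force the mean curve to end at (1, 1)
    # Population standard deviation, matching the default of np.std.
    std_tpr = [pstdev(column) for column in zip(*rows)]
    return mean_tpr, std_tpr


grid = [i / 4 for i in range(5)]  # 0.0, 0.25, 0.5, 0.75, 1.0
fold_a = ([0.0, 0.0, 1.0], [0.0, 1.0, 1.0])  # toy fold with a perfect ranking
fold_b = ([0.0, 1.0], [0.0, 1.0])  # toy fold on the chance diagonal
mean_tpr, std_tpr = mean_roc([fold_a, fold_b], grid)
```

Interpolating onto a shared grid is what makes the per-fold curves comparable, since each fold generally yields a different set of `fpr` values.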
55 changes: 55 additions & 0 deletions examples/plot_roc_curve_visualization_api.py
@@ -0,0 +1,55 @@
"""
================================
ROC Curve with Visualization API
================================
Scikit-learn defines a simple API for creating visualizations for machine
learning. The key feature of this API is to allow for quick plotting and
visual adjustments without recalculation. In this example, we will demonstrate
how to use the visualization API by comparing ROC curves.
"""
print(__doc__)

##############################################################################
# Load Data and Train a SVC
# -------------------------
# First, we load the wine dataset and convert it to a binary classification
# problem. Then, we train a support vector classifier on a training dataset.
import matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import plot_roc_curve
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

X, y = load_wine(return_X_y=True)
y = y == 2

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
svc = SVC(random_state=42)
svc.fit(X_train, y_train)

##############################################################################
# Plotting the ROC Curve
# ----------------------
# Next, we plot the ROC curve with a single call to
# :func:`sklearn.metrics.plot_roc_curve`. The returned `svc_disp` object allows
# us to continue using the already computed ROC curve for the SVC in future
# plots.
svc_disp = plot_roc_curve(svc, X_test, y_test)
plt.show()

##############################################################################
# Training a Random Forest and Plotting the ROC Curve
# --------------------------------------------------------
# We train a random forest classifier and create a plot comparing it to the SVC
# ROC curve. Notice how `svc_disp` uses
# :meth:`~sklearn.metrics.RocCurveDisplay.plot` to plot the SVC ROC curve
# without recomputing the values of the ROC curve itself. Furthermore, we
# pass `alpha=0.8` to the plot functions to adjust the alpha values of the
# curves.
rfc = RandomForestClassifier(n_estimators=10, random_state=42)
rfc.fit(X_train, y_train)
ax = plt.gca()
rfc_disp = plot_roc_curve(rfc, X_test, y_test, ax=ax, alpha=0.8)
svc_disp.plot(ax=ax, alpha=0.8)
plt.show()
6 changes: 6 additions & 0 deletions sklearn/metrics/__init__.py
@@ -74,6 +74,10 @@
from .scorer import SCORERS
from .scorer import get_scorer

from ._plot.roc_curve import plot_roc_curve
from ._plot.roc_curve import RocCurveDisplay


__all__ = [
'accuracy_score',
'adjusted_mutual_info_score',
@@ -125,11 +129,13 @@
'pairwise_distances_argmin_min',
'pairwise_distances_chunked',
'pairwise_kernels',
'plot_roc_curve',
'precision_recall_curve',
'precision_recall_fscore_support',
'precision_score',
'r2_score',
'recall_score',
'RocCurveDisplay',
'roc_auc_score',
'roc_curve',
'SCORERS',