-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG] Plotting API starting with ROC curve #14357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
46 commits
Select commit
Hold shift + click to select a range
aa753b1
WIP
thomasjpfan d5ba421
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan a0e4723
ENH Adds plot_roc_curve
thomasjpfan 8ae4c70
DOC Adds docs
thomasjpfan eaac39c
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan 763d723
BUG Fix
thomasjpfan 2330e06
DOC Adds label to parameters
thomasjpfan 4e33b28
DOC Adds label as a parameter
thomasjpfan 40381ae
API Update with kwargs
thomasjpfan 5f83a80
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan 663fc58
BLD Add tests to setup
thomasjpfan dea2b1b
DOC Update docs
thomasjpfan 3e65c77
DOC Update ordering
thomasjpfan b4a8d0e
ENH Adds auc to labels
thomasjpfan eba2453
CLN Updates example with plotting api
thomasjpfan 74b8d7b
TST Updates test
thomasjpfan 272845f
CLN Moves name to plot
thomasjpfan c51da17
TST Adds more tests
thomasjpfan 5fc6f29
DOC Removes line_kw parameter docstring
thomasjpfan 8832788
DOC Fix docs
thomasjpfan 56ff821
CLN Removes unused import
thomasjpfan 43f1787
CLN Uses kwargs
thomasjpfan bdf782a
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan 6cc8712
CLN Does computation in plot_* function
thomasjpfan 7051e5e
DOC Adds type
thomasjpfan d461543
CLN Address comments
thomasjpfan bda4435
BUG Spelling
thomasjpfan af4209c
TST Assert message
thomasjpfan 014851e
CLN Moves plot to _plot
thomasjpfan 50fb1cd
WIP
thomasjpfan e620d8c
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan 4d8534c
DOC Adds user guide
thomasjpfan 46cc0c5
BLD Fix
thomasjpfan 6771d8d
BLD Build docs [doc build]
thomasjpfan dd7f3cf
CLN Address comments
thomasjpfan e668471
CLN Address comments
thomasjpfan 94f29e3
STY Update styling
thomasjpfan 9546755
STY Update styling
thomasjpfan 71b19a3
DOC Adds note about parameters stored as attributes
thomasjpfan f7f2e2e
CLN Updates to roc_auc
thomasjpfan e2f42fa
ENH Adds check for response_method
thomasjpfan 30925a8
DOC Adds parameters
thomasjpfan 2e8db1f
Merge remote-tracking branch 'upstream/master' into plotting_api
thomasjpfan fab9d07
CLN Renames to display
thomasjpfan 759d1b4
CLN Address comments
thomasjpfan 140287d
improve error message
qinhanmin2014 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1647,3 +1647,54 @@ make this task easier and faster (in no particular order). | |
<https://git-scm.com/docs/git-grep#_examples>`_) is also extremely | ||
useful to see every occurrence of a pattern (e.g. a function call or a | ||
variable) in the code base. | ||
|
||
|
||
.. _plotting_api: | ||
|
||
Plotting API | ||
============ | ||
|
||
Scikit-learn defines a simple API for creating visualizations for machine | ||
learning. The key features of this API is to run calculations once and to have | ||
the flexibility to adjust the visualizations after the fact. This logic is | ||
encapsulated into a display object where the computed data is stored and | ||
the plotting is done in a `plot` method. The display object's `__init__` | ||
method contains only the data needed to create the visualization. The `plot` | ||
method takes in parameters that only have to do with visualization, such as a | ||
matplotlib axes. The `plot` method will store the matplotlib artists as | ||
attributes allowing for style adjustments through the display object. A | ||
`plot_*` helper function accepts parameters to do the computation and the | ||
parameters used for plotting. After the helper function creates the display | ||
object with the computed values, it calls the display's plot method. Note | ||
that the `plot` method defines attributes related to matplotlib, such as the | ||
line artist. This allows for customizations after calling the `plot` method. | ||
|
||
For example, the `RocCurveDisplay` defines the following methods and | ||
attributes: | ||
|
||
.. code-block:: python | ||
|
||
class RocCurveDisplay: | ||
def __init__(self, fpr, tpr, roc_auc, estimator_name): | ||
... | ||
self.fpr = fpr | ||
self.tpr = tpr | ||
self.roc_auc = roc_auc | ||
self.estimator_name = estimator_name | ||
|
||
def plot(self, ax=None, name=None, **kwargs): | ||
... | ||
self.line_ = ... | ||
self.ax_ = ax | ||
self.figure_ = ax.figure_ | ||
|
||
def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None, | ||
drop_intermediate=True, response_method="auto", | ||
name=None, ax=None, **kwargs): | ||
# do computation | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I read this comment as referring to the line below, which is not the intent. Maybe # ...
# Do computation
# ... |
||
viz = RocCurveDisplay(fpr, tpr, roc_auc, | ||
estimator.__class__.__name__) | ||
return viz.plot(ax=ax, name=name, **kwargs) | ||
``` | ||
|
||
Read more in the :ref:`User Guide <visualizations>`. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
.. include:: includes/big_toc_css.rst | ||
|
||
.. _visualizations: | ||
|
||
============== | ||
Visualizations | ||
============== | ||
|
||
Scikit-learn defines a simple API for creating visualizations for machine | ||
learning. The key feature of this API is to allow for quick plotting and | ||
visual adjustments without recalculation. In the following example, we plot a | ||
ROC curve for a fitted support vector machine: | ||
|
||
.. code-block:: python | ||
|
||
from sklearn.model_selection import train_test_split | ||
from sklearn.svm import SVC | ||
from sklearn.metrics import plot_roc_curve | ||
from sklearn.datasets import load_wine | ||
|
||
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) | ||
svc = SVC(random_state=42) | ||
svc.fit(X_train, y_train) | ||
|
||
svc_disp = plot_roc_curve(svc, X_test, y_test) | ||
|
||
.. figure:: ../auto_examples/images/sphx_glr_plot_roc_curve_visualization_api_001.png | ||
:target: ../auto_examples/plot_roc_curve_visualization_api.html | ||
:align: center | ||
:scale: 75% | ||
|
||
The returned `svc_disp` object allows us to continue using the already computed | ||
ROC curve for SVC in future plots. In this case, the `svc_disp` is a | ||
:class:`~sklearn.metrics.RocCurveDisplay` that stores the computed values as | ||
attributes called `roc_auc`, `fpr`, and `tpr`. Next, we train a random forest | ||
classifier and plot the previously computed roc curve again by using the `plot` | ||
method of the `Display` object. | ||
|
||
.. code-block:: python | ||
|
||
import matplotlib.pyplot as plt | ||
from sklearn.ensemble import RandomForestClassifier | ||
|
||
rfc = RandomForestClassifier(random_state=42) | ||
rfc.fit(X_train, y_train) | ||
|
||
ax = plt.gca() | ||
rfc_disp = plot_roc_curve(rfc, X_test, y_test, ax=ax, alpha=0.8) | ||
svc_disp.plot(ax=ax, alpha=0.8) | ||
|
||
.. figure:: ../auto_examples/images/sphx_glr_plot_roc_curve_visualization_api_002.png | ||
:target: ../auto_examples/plot_roc_curve_visualization_api.html | ||
:align: center | ||
:scale: 75% | ||
|
||
Notice that we pass `alpha=0.8` to the plot functions to adjust the alpha | ||
values of the curves. | ||
|
||
.. topic:: Examples: | ||
|
||
* :ref:`sphx_glr_auto_examples_plot_roc_curve_visualization_api.py` | ||
|
||
Available Plotting Utilities | ||
============================ | ||
|
||
Functions | ||
--------- | ||
|
||
.. currentmodule:: sklearn | ||
|
||
.. autosummary:: | ||
|
||
metrics.plot_roc_curve | ||
|
||
|
||
Display Objects | ||
--------------- | ||
|
||
.. currentmodule:: sklearn | ||
|
||
.. autosummary:: | ||
|
||
metrics.RocCurveDisplay |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
""" | ||
================================ | ||
ROC Curve with Visualization API | ||
================================ | ||
Scikit-learn defines a simple API for creating visualizations for machine | ||
learning. The key features of this API is to allow for quick plotting and | ||
visual adjustments without recalculation. In this example, we will demonstrate | ||
how to use the visualization API by comparing ROC curves. | ||
""" | ||
print(__doc__) | ||
|
||
############################################################################## | ||
# Load Data and Train a SVC | ||
# ------------------------- | ||
# First, we load the wine dataset and convert it to a binary classification | ||
# problem. Then, we train a support vector classifier on a training dataset. | ||
import matplotlib.pyplot as plt | ||
from sklearn.svm import SVC | ||
from sklearn.ensemble import RandomForestClassifier | ||
from sklearn.metrics import plot_roc_curve | ||
from sklearn.datasets import load_wine | ||
from sklearn.model_selection import train_test_split | ||
|
||
X, y = load_wine(return_X_y=True) | ||
y = y == 2 | ||
|
||
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) | ||
svc = SVC(random_state=42) | ||
svc.fit(X_train, y_train) | ||
|
||
############################################################################## | ||
# Plotting the ROC Curve | ||
# ---------------------- | ||
# Next, we plot the ROC curve with a single call to | ||
# :func:`sklearn.metrics.plot_roc_curve`. The returned `svc_disp` object allows | ||
# us to continue using the already computed ROC curve for the SVC in future | ||
# plots. | ||
svc_disp = plot_roc_curve(svc, X_test, y_test) | ||
plt.show() | ||
|
||
############################################################################## | ||
# Training a Random Forest and Plotting the ROC Curve | ||
# -------------------------------------------------------- | ||
# We train a random forest classifier and create a plot comparing it to the SVC | ||
# ROC curve. Notice how `svc_disp` uses | ||
# :func:`~sklearn.metrics.RocCurveDisplay.plot` to plot the SVC ROC curve | ||
# without recomputing the values of the roc curve itself. Futhermore, we | ||
# pass `alpha=0.8` to the plot functions to adjust the alpha values of the | ||
# curves. | ||
rfc = RandomForestClassifier(n_estimators=10, random_state=42) | ||
rfc.fit(X_train, y_train) | ||
ax = plt.gca() | ||
rfc_disp = plot_roc_curve(rfc, X_test, y_test, ax=ax, alpha=0.8) | ||
svc_disp.plot(ax=ax, alpha=0.8) | ||
plt.show() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It'd be nice to actually have a very simple example that illustrates this feature.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
link to example of user guide now?