[MRG] Plotting API starting with ROC curve #14357

Merged
merged 46 commits into from
Aug 8, 2019

Member

@thomasjpfan thomasjpfan commented Jul 14, 2019

(Post Updated 08/05/19)

Reference Issues/PRs

Related to

What does this implement/fix? Explain your changes.

Scikit-learn defines a simple API for creating visualizations for machine learning. The key features of this API are to run calculations once and to have the flexibility to adjust the visualizations after the fact. This logic is encapsulated in a display object, where the computed data is stored and the plotting is done in a plot method. The display object's __init__ method takes only the data needed to create the visualization. The plot method takes only parameters related to visualization, such as a matplotlib axes, and stores the matplotlib artists (such as the line artist) as attributes, which allows style adjustments through the display object after plot has been called. A plot_* helper function accepts both the parameters for the computation and the parameters for plotting. After the helper function creates the display object with the computed values, it calls the display's plot method.

For example, the RocCurveDisplay defines the following methods:

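class RocCurveDisplay:
    def __init__(self, fpr, tpr, roc_auc, estimator_name):
        ...
        # store only the precomputed values needed for plotting
        self.fpr = fpr
        self.tpr = tpr
        self.roc_auc = roc_auc
        self.estimator_name = estimator_name

    def plot(self, ax=None, name=None, **kwargs):
        ...
        # keep the matplotlib artists so they can be restyled later
        self.line_ = ...
        self.ax_ = ax
        self.figure_ = ax.figure  # a matplotlib Axes exposes `figure`, not `figure_`
        return self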

Together with a plotting function to create this object:

def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
                   drop_intermediate=True, response_method="auto",
                   name=None, ax=None, **kwargs):
    # do computation
    viz = RocCurveDisplay(fpr, tpr, roc_auc,
                          estimator.__class__.__name__)
    return viz.plot(ax=ax, name=name, **kwargs)
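
To make the pattern above concrete, here is a minimal sketch (not the code from this PR). `SketchRocDisplay` and `RecordingAxes` are hypothetical names. `RecordingAxes` is a stand-in for a matplotlib Axes so the sketch runs without matplotlib installed; a real Axes should behave the same way here, since only `plot` and `figure` are used. The legend label is simplified to the bare name.

```python
class SketchRocDisplay:
    """Hypothetical display object: holds computed data, plots on demand."""

    def __init__(self, fpr, tpr, roc_auc, estimator_name):
        # Only the data needed to draw the curve is stored here.
        self.fpr = fpr
        self.tpr = tpr
        self.roc_auc = roc_auc
        self.estimator_name = estimator_name

    def plot(self, ax=None, name=None, **kwargs):
        if ax is None:
            # Imported lazily so the class itself does not require matplotlib.
            import matplotlib.pyplot as plt
            _, ax = plt.subplots()
        line_kwargs = {"label": self.estimator_name if name is None else name}
        line_kwargs.update(kwargs)
        # Axes.plot returns a list of line artists; keep the single line.
        self.line_ = ax.plot(self.fpr, self.tpr, **line_kwargs)[0]
        self.ax_ = ax
        self.figure_ = ax.figure
        return self


class RecordingAxes:
    """Stand-in for a matplotlib Axes (assumption: only plot/figure are used)."""

    def __init__(self):
        self.figure = object()
        self.calls = []

    def plot(self, x, y, **kwargs):
        self.calls.append((list(x), list(y), kwargs))
        return [("line", kwargs)]


# Values are passed in precomputed; the display never sees the estimator.
disp = SketchRocDisplay([0.0, 0.5, 1.0], [0.0, 0.8, 1.0], 0.65, "MyClassifier")
first, second = RecordingAxes(), RecordingAxes()
disp.plot(ax=first)                                 # default legend name
disp.plot(ax=second, name="tuned", linestyle="--")  # restyled, nothing recomputed
```

The second `plot` call reuses the stored `fpr` and `tpr`, which is the "adjust after the fact" behaviour the description is after.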

Any other comments?

Here is a notebook demonstrating a workflow for using this API.

Here is what the API implementation looks like for other visualizations:

@thomasjpfan thomasjpfan changed the title from "[MRG] Plotting API (ROC curve)" to "[MRG] Plotting API starting with ROC curve" Jul 14, 2019
Plotting metrics
----------------

.. automodule:: sklearn.metrics.plot
Member

So we've decided to use sklearn.XXX.plot.plot_XXX, right?
We'll put things like plot_decision_boundary in sklearn.inspection, right?

(default='predict_proba')
Method to call estimator to get target scores

label : str or None, optional (default=None)
Member

not used :)

@qinhanmin2014
Member

So we've decided to:
(1) Use sklearn.XXX.plot.plot_XXX, do not allow from sklearn.XXX import plot_XXX
(2) Introduce a Visualizer class and a plot function. The Visualizer class does the calculation in __init__ and the plotting in plot.

I think this is a great solution, but I'd like to see more opinions here (>=+3 maybe). Are you able to reach out to more people during the sprint, or maybe a SLEP/mail in the mailing list?

You need to change the description of our matplotlib dependency in the README and install.rst to something like "classes ending with Visualizer and functions starting with plot_".

We need to take care of plot_partial_dependence, right? Maybe it's acceptable to leave plot_tree as it is?

@qinhanmin2014 qinhanmin2014 added the Needs Decision Requires decision label Jul 14, 2019
@thomasjpfan
Member Author

CC @NicolasHug

@amueller
Member

There's a reimplementation of plot_partial_dependence linked above. I think it's backwards compatible apart from the returned object (which is a bit tricky).

I think it's worth emphasizing that the common simple use-case involves just calling the function, and the user not having to worry about the object. The main motivation for the object is to allow replotting without recomputation.

Maybe if we want to make the SLEPs more lightweight it might make sense to write a slep for this, I'm not sure.

I listed alternatives in #13448 (#13448 (comment))

There's some discussion there as well.

@qinhanmin2014
Member

Regarding plot_roc_curve, will it be useful to return the roc_auc_score?

@amueller
Member

Maybe to summarize, we could:

  1. have a plot_X function that takes the results of some computation. Upside: clear separation of computation and testing, backward compatible, pretty obvious interface. Downside: users have to call two functions (say roc_curve and plot_roc_curve) and need to ensure the arguments match. (An alternative would be to have roc_curve return some ROC object, but that wouldn't be backward compatible and also seems not scikit-learn-like.)

  2. have a plot_X function that does the computation inside. Upside: simple interface for users. Downside: mixing plotting and computation somewhat. Without further work: can't adjust existing plots easily, can't plot or do anything else with computations and potentially need to compute results again just to show them again.

I like being able to plot a roc curve with a single line. I think that should be our goal, so I'm strongly in favor of 2. Having to recompute, say, partial dependence, for plotting it again is not acceptable to me, though. That means we need to return or store the results of the computation somehow.

There's two immediate solutions (maybe there's better ones that I can't think of):
a) have the plotting one-liner be an object that does the computation and plotting in the __init__ and then stores everything in self and can be plotted again.
b) have a utility function that does the plotting and returns an object that stores the results and allows plotting again.

We opted for b) here because it hides the complexity somewhat from the user. The simple case still remains simple with a function call, not a class construction, which seems more natural to me. (The difference is a bit academic, though: for the user it mostly manifests in whether we use CamelCase or snake_case for the thing they call. We could also have plot_roc_curve be a class and just make it look like a function if we wanted to.)
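
A rough sketch of option b), under stated assumptions: the helper does the computation once and returns an object that can be plotted again. Everything here is illustrative and pure Python. `roc_points`, `trapezoid_auc`, `CurveResult`, `plot_roc_sketch`, `CountingScorer` and `ListAxes` are hypothetical names, the ROC computation is a simplified stand-in for `roc_curve`, and `ListAxes` replaces a matplotlib Axes so the example runs without a plotting backend.

```python
def roc_points(y_true, y_score):
    """ROC points from a threshold sweep (assumes labels 0/1, both present)."""
    n_pos = sum(1 for label in y_true if label == 1)
    n_neg = len(y_true) - n_pos
    order = sorted(range(len(y_score)), key=lambda i: -y_score[i])
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    for rank, i in enumerate(order):
        if y_true[i] == 1:
            tp += 1
        else:
            fp += 1
        # Emit a point only when the next score differs (ties share a threshold).
        last = rank == len(order) - 1
        if last or y_score[order[rank + 1]] != y_score[i]:
            fpr.append(fp / n_neg)
            tpr.append(tp / n_pos)
    return fpr, tpr


def trapezoid_auc(fpr, tpr):
    """Area under the piecewise-linear curve."""
    return sum((fpr[k] - fpr[k - 1]) * (tpr[k] + tpr[k - 1]) / 2
               for k in range(1, len(fpr)))


class CurveResult:
    """Stores the computed curve so it can be drawn repeatedly."""

    def __init__(self, fpr, tpr, roc_auc):
        self.fpr, self.tpr, self.roc_auc = fpr, tpr, roc_auc

    def plot(self, ax, **kwargs):
        self.line_ = ax.plot(self.fpr, self.tpr, **kwargs)[0]
        return self


def plot_roc_sketch(estimator, X, y, ax, **kwargs):
    scores = estimator.decision_function(X)  # the only place scores are computed
    fpr, tpr = roc_points(y, scores)
    return CurveResult(fpr, tpr, trapezoid_auc(fpr, tpr)).plot(ax, **kwargs)


class CountingScorer:
    """Stand-in estimator that counts how often it is asked for scores."""

    def __init__(self):
        self.n_calls = 0

    def decision_function(self, X):
        self.n_calls += 1
        return [row[0] for row in X]


class ListAxes:
    """Stand-in for a matplotlib Axes; records what would be drawn."""

    def __init__(self):
        self.lines = []

    def plot(self, x, y, **kwargs):
        self.lines.append((x, y, kwargs))
        return [self.lines[-1]]


est = CountingScorer()
X = [[0.1], [0.4], [0.35], [0.8]]
y = [0, 0, 1, 1]
ax1, ax2 = ListAxes(), ListAxes()
disp = plot_roc_sketch(est, X, y, ax1)  # the simple case is one function call
disp.plot(ax2, color="red")             # replot with a new style
```

After both plots, `est.n_calls` is still 1: the second drawing reuses the stored result, which is the property that option 2 alone would not give.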

@amueller
Member

docstring tests are failing @thomasjpfan

Member

@qinhanmin2014 qinhanmin2014 left a comment

We need an example in the gallery (or we can modify existing example), at least for plot_roc_curve.

return viz
```

Note that the ``__init__`` method defines attributes that are going to be used
Member

this paragraph can be integrated into the previous one IMO.

@amueller
Member

There's two roc curve examples, one for multi-class and one for cross-validation.
I suggest we merge the multi-class roc_auc first (though that only relates to the multi-class example of course).

The cross-validation example we can pretty easily do with the new function, apart from computing the mean.
For the mean, we could either:

  1. just not plot it (I vote this)
  2. in the end pull the results out of all the objects and do the interpolation bit
  3. add support to the plotting method directly (probably not)

For the multi-class one, I would use the plot function and manually apply it to each label for plotting in OVR fashion, and use the newly added multi-class roc to actually compute the metric.
Maybe showing the different kinds of metrics might be nice, but we can always add this later.
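
For reference, option 2 for the mean curve (pulling the results out of the per-fold objects and interpolating) could look roughly like the sketch below. It is pure Python and only illustrative: `FoldCurve` is a hypothetical stand-in for the per-fold objects, and `interp` is a small hand-written linear interpolation rather than a library call.

```python
from bisect import bisect_right
from collections import namedtuple

# Hypothetical stand-in for the objects returned for each CV fold.
FoldCurve = namedtuple("FoldCurve", ["fpr", "tpr"])


def interp(x, xs, ys):
    """Piecewise-linear interpolation; xs must be non-decreasing."""
    if x <= xs[0]:
        return ys[0]
    if x >= xs[-1]:
        return ys[-1]
    hi = bisect_right(xs, x)  # first index with xs[hi] > x
    lo = hi - 1               # xs[lo] <= x < xs[hi], so the slope is defined
    t = (x - xs[lo]) / (xs[hi] - xs[lo])
    return ys[lo] + t * (ys[hi] - ys[lo])


def mean_tpr(curves, grid):
    """Average the TPR of several curves on a shared FPR grid."""
    per_curve = [[interp(x, c.fpr, c.tpr) for x in grid] for c in curves]
    return [sum(vals) / len(vals) for vals in zip(*per_curve)]


folds = [
    FoldCurve(fpr=[0.0, 0.5, 1.0], tpr=[0.0, 1.0, 1.0]),
    FoldCurve(fpr=[0.0, 0.5, 1.0], tpr=[0.0, 0.5, 1.0]),
]
grid = [0.0, 0.25, 0.5, 0.75, 1.0]
mean_curve = mean_tpr(folds, grid)
print(mean_curve)  # [0.0, 0.375, 0.75, 0.875, 1.0]
```

The folds generally have different FPR values, so the curves have to be resampled onto a common grid before averaging; this is the "interpolation bit" mentioned above.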

return self


def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
Member

If there's only one artist, I think we don't need to have a dict as an argument and can just do **kwargs instead of having line_kw={...}

Member

@qinhanmin2014 qinhanmin2014 left a comment

Hmm, current API design seems strange, I think the visualizer should have a consistent interface compared with the helper function, i.e., it should accept estimator&X&y instead of fpr&tpr, otherwise the visualizer seems useless.

import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier

rfc = DecisionTreeClassifier(random_state=42)
Member

RandomForestClassifier?

def __init__(self, fpr, tpr, auc_roc, estimator_name):
self.fpr = fpr
self.tpr = tpr
self.auc_roc = auc_roc
Member

users need to pass auc_roc manually?

Member Author

All the computation is done in plot_roc_curve and passed into RocCurveDisplay.

self.fpr = fpr
self.tpr = tpr
self.auc_roc = auc_roc
self.estimator_name = estimator_name
Member

this seems redundant, we already have name in plot?

Member Author

This gives the plot a default name to display in the legend. This is useful when plotting multiple roc curves on the same axes.

The user can later adjust the name by passing name into plot.

Member Author

This was done because plot applies custom formatting to the label, adding the AUC score for convenience.
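
As an illustration of the two replies above, the label logic might look like the sketch below. `legend_label` is a hypothetical helper and the exact format string is an assumption, not necessarily what the PR uses.

```python
def legend_label(estimator_name, roc_auc, name=None):
    """Build a legend label; `name` overrides the stored estimator name."""
    shown = estimator_name if name is None else name
    # Assumed format: the name followed by the AUC rounded to two decimals.
    return "{} (AUC = {:0.2f})".format(shown, roc_auc)


default_label = legend_label("RandomForestClassifier", 0.9137)
custom_label = legend_label("RandomForestClassifier", 0.9137, name="forest")
print(default_label)  # RandomForestClassifier (AUC = 0.91)
print(custom_label)   # forest (AUC = 0.91)
```

Storing `estimator_name` on the object gives every curve a sensible default label, while `name` in `plot` stays available for relabelling when several curves share one axes.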


if y_pred.ndim != 1:
if y_pred.shape[1] > 2:
raise ValueError("Estimator must be a binary classifier")
Member

binary classifier seems strange, maybe things like binary classification problem.

@GaelVaroquaux
Member

Naming suggestion: change "Visualizer" to "Display": "RocCurveDisplay"

"""
if response_method != "auto":
prediction_method = getattr(estimator, response_method, None)
if prediction_method is None:
Member

you should check that user provided prediction_method is one of the accepted values.
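
The check requested here could be sketched as follows. This is an assumption-laden illustration, not the PR's final code: `resolve_prediction_method` and `MarginOnly` are hypothetical names, and the order in which "auto" tries the methods is assumed.

```python
def resolve_prediction_method(estimator, response_method="auto"):
    """Return the bound method used to obtain target scores."""
    allowed = ("predict_proba", "decision_function", "auto")
    # Reject unknown values up front; otherwise getattr could pick up an
    # unrelated attribute such as `fit`.
    if response_method not in allowed:
        raise ValueError(
            "response_method must be one of {}, got {!r}".format(
                allowed, response_method))
    if response_method != "auto":
        method = getattr(estimator, response_method, None)
        if method is None:
            raise ValueError("{} is not defined in {}".format(
                response_method, type(estimator).__name__))
        return method
    # Assumed "auto" order: predict_proba first, then decision_function.
    for candidate in ("predict_proba", "decision_function"):
        method = getattr(estimator, candidate, None)
        if method is not None:
            return method
    raise ValueError("{} has neither predict_proba nor decision_function"
                     .format(type(estimator).__name__))


class MarginOnly:
    """Hypothetical estimator exposing only decision_function."""

    def fit(self, X, y):
        return self

    def decision_function(self, X):
        return [0.0 for _ in X]


chosen = resolve_prediction_method(MarginOnly())
try:
    resolve_prediction_method(MarginOnly(), response_method="fit")
    rejected = False
except ValueError:
    rejected = True
```

Without the membership check, `response_method="fit"` would pass the `getattr` test in the hunk above, because `fit` exists on the estimator.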

@NicolasHug
Member

Hmm, current API design seems strange, I think the visualizer should have a consistent interface compared with the helper function, i.e., it should accept estimator&X&y instead of fpr&tpr, otherwise the visualizer seems useless.

@qinhanmin2014 the ultimate goal here is to enable changing the style of plots without having to recompute values.

We want to keep functions plot_... because we already have these and users are used to them. Should the visualizer have the same API as the function, that would require the visualizer to perform the computations in init. That would also mean there are 2 ways of doing the same thing (use a function or use an object), which we want to avoid.

@qinhanmin2014
Member

We want to keep functions plot_... because we already have these and users are used to them. Should the visualizer have the same API as the function, that would require the visualizer to perform the computations in init. That would also mean there are 2 ways of doing the same thing (use a function or use an object), which we want to avoid.

Maybe it's common for a helper function to do similar things compared with the class in scikit-learn?

@NicolasHug
Member

True, but I guess estimators are intended to be only used once.

The advantage of the current design is that if you realize you want to change the style of the plot after the fact, your code can stay (almost) the same. With your proposal, users have to replace the function call with an object instantiation.

@amueller
Member

amueller commented Aug 5, 2019

@qinhanmin2014 raises the question in the dev meeting whether we need a visualizer for plot_tree. I don't think that's required but that makes the API for "plotting" somewhat inconsistent. Maybe that's a "different" kind of plotting functions but we should make sure that we have some consistent story.

@amueller
Member

amueller commented Aug 5, 2019

I kinda like the Display idea by @GaelVaroquaux but not 100% sure.

@qinhanmin2014
Member

I think Andy persuaded me during the meeting. I'll vote +1 here.

I don't think that's required but that makes the API for "plotting" somewhat inconsistent.

I agree that we don't need a visualizer for plot_tree, but the inconsistency is indeed a problem.
Maybe we can explain that for some plotting functions which do not require much calculation, we do not provide a visualizer?

I kinda like the Display idea by @GaelVaroquaux but not 100% sure.

+1 for either solution. I don't care much about the name.

Maybe @NicolasHug can organize a vote in the mailing list (since it's a brand new API in scikit-learn) and then we can merge this one.

@NicolasHug
Member

waiting for @thomasjpfan to post an update and ping

(not sure why I should be the one calling for a vote though?)

@qinhanmin2014
Member

(not sure why I should be the one calling for a vote though?)

because you organize the meeting, hmm, strange reason right :)
you can ask me to do this or maybe @thomasjpfan can better summarize this PR.

@thomasjpfan
Member Author

I have updated this PR and the original post with the following summary:

Summary

The key features of this API are to run calculations once and to have the flexibility to adjust the visualizations after the fact. This logic is encapsulated in a display object, where the computed data is stored and the plotting is done in a plot method. The display object's __init__ method takes only the data needed to create the visualization. The plot method takes only parameters related to visualization, such as a matplotlib axes, and stores the matplotlib artists (such as the line artist) as attributes, which allows style adjustments through the display object after plot has been called. A plot_* helper function accepts both the parameters for the computation and the parameters for plotting. After the helper function creates the display object with the computed values, it calls the display's plot method.

For example, the RocCurveDisplay defines the following methods:

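class RocCurveDisplay:
    def __init__(self, fpr, tpr, roc_auc, estimator_name):
        ...
        # store only the precomputed values needed for plotting
        self.fpr = fpr
        self.tpr = tpr
        self.roc_auc = roc_auc
        self.estimator_name = estimator_name

    def plot(self, ax=None, name=None, **kwargs):
        ...
        # keep the matplotlib artists so they can be restyled later
        self.line_ = ...
        self.ax_ = ax
        self.figure_ = ax.figure  # a matplotlib Axes exposes `figure`, not `figure_`
        return self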

Together with a plotting function to create this object:

def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
                   drop_intermediate=True, response_method="auto",
                   name=None, ax=None, **kwargs):
    # do computation
    viz = RocCurveDisplay(fpr, tpr, roc_auc,
                          estimator.__class__.__name__)
    return viz.plot(ax=ax, name=name, **kwargs)

@jnothman
Member

jnothman commented Aug 5, 2019 via email

@adrinjalali
Member

I'm not sure if a vote is going to help here, it'll complicate things and this PR would be merged much much later, if we go for that process, which also includes writing a SLEP.

@qinhanmin2014
Member

So +1 from Joel, Andy, Adrin, Nicolas and me (maybe Gael not sure, sorry if I miss someone),
I think there's enough consensus here. @thomasjpfan please reply to my reviews (#14357 (review)), though they're very trivial :)
and then let's merge.

Available Plotting Utilities
============================

Fucntions

Spelling Error

Member Author

Thank you!

@jnothman
Member

jnothman commented Aug 8, 2019 via email

@amueller
Member

amueller commented Aug 8, 2019

The brave new world of plotting?

@jnothman
Member

jnothman commented Aug 8, 2019 via email

@cmarmo cmarmo removed the Needs Decision Requires decision label Apr 20, 2021