ENH use cv_results in the different curve display to add confidence intervals #21211


Draft · glemaitre wants to merge 4 commits into main

Conversation

@glemaitre (Member) commented on Oct 1, 2021

This PR intends to add the capability of plotting the uncertainty of the different curves (calibration, precision-recall, ROC, etc.) by using the results of cross-validation (i.e. the output of cross_validate).

TODO:

  • add a parameter return_indices in cross_validate to store the train-test indices. This is the safest way to keep track of the train-test splits when using stochastic splitting strategies (see the sketch after this list).
  • add a method from_cv_results in the plotting display to take advantage of the CV computation.
  • add unit test for from_cv_results
  • add unit test for the new keyword parameters in CalibrationDisplay
  • add unit test for the new strategy of binning in calibration_curve
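
For reference, a minimal sketch of the result structure the first two items aim for; the "indices" key and its layout are assumptions of this proposal, not a finalized API:

# hypothetical shape of one cross_validate(..., return_estimator=True, return_indices=True) output
result = {
    "test_score": ...,  # per-split scores, as today
    "estimator": [...],  # one fitted estimator per split (return_estimator=True)
    "indices": {"train": [...], "test": [...]},  # one array of indices per split (proposed here)
}
# from_cv_results would pair each estimator with its test indices to recompute the
# curve on every fold and aggregate across folds.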

Usage example

# %%
import numpy as np
from sklearn.datasets import make_classification

X, y = make_classification(
    n_samples=10_000, weights=[0.1, 0.9], random_state=42, class_sep=1
)
sample_weight = np.zeros_like(y, dtype=np.float64)
sample_weight[y == 0] = 0.1
sample_weight[y == 1] = 0.9

# %%
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test, sw_train, sw_test = train_test_split(
    X, y, sample_weight, random_state=42
)

# %%
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV

calibration_method = "isotonic"
models = {
    "LR no weights": LogisticRegression(),
    "LR class weights": LogisticRegression(class_weight="balanced"),
    "Calibrated LR no weights": CalibratedClassifierCV(
        LogisticRegression(),
        method=calibration_method,
    ),
    "Calibrated LR class weights": CalibratedClassifierCV(
        LogisticRegression(class_weight="balanced"),
        method=calibration_method,
    ),
    "Calibrated LR sample weights": CalibratedClassifierCV(
        LogisticRegression(),
        method=calibration_method,
    ),
    "Calibrated LR class and sample weights": CalibratedClassifierCV(
        LogisticRegression(class_weight="balanced"),
        method=calibration_method,
    ),
}

# %%
import matplotlib.pyplot as plt
from sklearn.calibration import CalibrationDisplay
from sklearn.metrics import balanced_accuracy_score

fig, ax = plt.subplots()

calibration_display_params = {
    "n_bins": 20,
    "strategy": "quantile",
}
for name, model in models.items():
    if "sample weights" in name:
        model.fit(X_train, y_train, sample_weight=sw_train)
    else:
        model.fit(X_train, y_train)

    score = balanced_accuracy_score(y_test, model.predict(X_test))
    CalibrationDisplay.from_estimator(
        model,
        X_test,
        y_test,
        name=name + f" - {score:.3f}",
        ax=ax,
        **calibration_display_params,
    )
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), title="Model - Balanced Accuracy")
_ = fig.suptitle(f"Using {calibration_method} calibration")

# %%
from sklearn.model_selection import cross_validate
from sklearn.model_selection import KFold

cv_results = {}
cv = KFold(n_splits=5)
for name, model in models.items():
    if "sample weights" in name:
        fit_params = {"sample_weight": sample_weight}
    else:
        fit_params = {}
    cv_results[name] = cross_validate(
        model,
        X,
        y,
        cv=cv,
        fit_params=fit_params,
        scoring="balanced_accuracy",
        return_estimator=True,
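        # proposed in this PR: keep the train/test indices of each split in the output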
        return_indices=True,
    )

# %%
fig, ax = plt.subplots()
for name, results in cv_results.items():
    CalibrationDisplay.from_cv_results(
        results, X, y, ax=ax, name=name, **calibration_display_params
    )
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

# %%
fig, ax = plt.subplots()
for name, results in cv_results.items():
    CalibrationDisplay.from_cv_results(
        results,
        X,
        y,
        ax=ax,
        name=name,
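        # proposed style: shaded band based on the standard deviation across folds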
        plot_uncertainty_style="fill_between",
        **calibration_display_params,
    )
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

# %%
fig, ax = plt.subplots()
for name, results in cv_results.items():
    CalibrationDisplay.from_cv_results(
        results,
        X,
        y,
        ax=ax,
        name=name,
        plot_uncertainty_style="lines",
        **calibration_display_params,
    )
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

# %%

[Four output figures: per-model calibration curves from from_estimator, and from_cv_results with the default, "fill_between", and "lines" uncertainty styles.]

@glemaitre changed the title from "ENH use uncertainty estimate" to "ENH use cv_results in the different curve display to add confidence intervals" on Oct 1, 2021
@glemaitre marked this pull request as a draft on Oct 1, 2021
@ogrisel self-requested a review on Oct 19, 2021
@@ -1067,6 +1135,22 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
If `True`, plots a reference line representing a perfectly
calibrated classifier.

plot_uncertainty_style : {"errorbar", "fill_between", "lines"}, \
default="errorbar"

@ogrisel (Member) commented on Oct 21, 2021

I think the default should be plot_uncertainty_style="lines", as it is the easiest to understand without being misled. For plot_uncertainty_style="errorbar" and plot_uncertainty_style="fill_between", one needs to know that they are based on the raw standard deviation (as opposed to a pseudo confidence interval based on the standard error of the mean, for instance).
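
For illustration, a minimal sketch (not part of this PR) of the two aggregations being contrasted here, applied to per-fold curve values on a shared grid:

import numpy as np

# one row per CV fold, one column per point of the shared x-grid (values are made up)
fold_curves = np.random.RandomState(0).normal(loc=0.5, scale=0.05, size=(5, 20))

mean_curve = fold_curves.mean(axis=0)
std_curve = fold_curves.std(axis=0)  # raw spread across folds ("errorbar" / "fill_between")
sem_curve = std_curve / np.sqrt(fold_curves.shape[0])  # standard error of the mean
# a pseudo 95% confidence band would be mean_curve +/- 1.96 * sem_curve; the SEM band
# is narrower than the +/- std_curve band by a factor of sqrt(n_folds)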

@ogrisel (Member) commented on Oct 21, 2021

We could also accept plot_uncertainty_style=None to only plot the mean CV calibration curve without any uncertainty markers on the plot.

Also plot_uncertainty_style="shade" or plot_uncertainty_style="shaded_area" might be easier to understand than plot_uncertainty_style="fill_between".

default="errorbar"
Style to plot the uncertainty information. Possibilities are:

- "errorbar": error bars representing one standard deviation;

two standard deviations: 1 above and 1 below.

I assume (I did not check ;)

I checked and I think I am right:

import numpy as np
import matplotlib.pyplot as plt


plt.errorbar(np.arange(5), np.ones(5), np.ones(5))
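# with yerr=np.ones(5), each bar spans from y - 1 to y + 1: one unit above and one below the point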

[plot output: each error bar extends one unit above and one unit below its point]

Comment on lines +173 to +174
return_indices : bool, default=False
Whether to return the train-test indices selected for each split.

Coming from #21664, I agree return_indices is useful. (I wanted to do something like this recently).
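
As a small illustration of why: with return_estimator=True and the proposed return_indices=True, each fold's test predictions can be recomputed after the fact. A minimal sketch reusing the usage example above (the "indices"/"test" key layout is an assumption of this proposal, not a released API):

results = cv_results["LR no weights"]  # one cross_validate output from the example above
for estimator, test_idx in zip(results["estimator"], results["indices"]["test"]):
    y_prob = estimator.predict_proba(X[test_idx])[:, 1]
    # (y[test_idx], y_prob) is what a display can turn into one per-fold curve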

@adrinjalali (Member) commented:

@glemaitre this seems cool, worth continuing!

@glemaitre (Member, Author) commented:

Yep, this is also part of the CZI proposal on inspection. This would be my next effort after the threshold-tuning classifier.
