ENH add pos_label in calibration tools #21032

Merged: 8 commits, Oct 2, 2021
9 changes: 9 additions & 0 deletions doc/whats_new/v1.1.rst
@@ -38,13 +38,22 @@ Changelog
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
where 123456 is the *pull request* number, not the issue number.


:mod:`sklearn.calibration`
..........................

- |Enhancement| :func:`calibration.calibration_curve` accepts a parameter
`pos_label` to specify the positive class label.
:pr:`21032` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.linear_model`
...........................

- |Fix| :class:`linear_model.LogisticRegression` now raises a better error
message when the solver does not support sparse matrices with int64 indices.
:pr:`21093` by `Tom Dupre la Tour`_.


:mod:`sklearn.utils`
....................

23 changes: 18 additions & 5 deletions sklearn/calibration.py
@@ -37,12 +37,17 @@

from .utils.multiclass import check_classification_targets
from .utils.fixes import delayed
from .utils.validation import check_is_fitted, check_consistent_length
from .utils.validation import _check_sample_weight, _num_samples
from .utils.validation import (
_check_sample_weight,
_num_samples,
check_consistent_length,
check_is_fitted,
)
from .utils import _safe_indexing
from .isotonic import IsotonicRegression
from .svm import LinearSVC
from .model_selection import check_cv, cross_val_predict
from .metrics._base import _check_pos_label_consistency
from .metrics._plot.base import _get_response


@@ -866,7 +871,9 @@ def predict(self, T):
return expit(-(self.a_ * T + self.b_))


def calibration_curve(y_true, y_prob, *, normalize=False, n_bins=5, strategy="uniform"):
def calibration_curve(
y_true, y_prob, *, pos_label=None, normalize=False, n_bins=5, strategy="uniform"
):
"""Compute true and predicted probabilities for a calibration curve.

The method assumes the inputs come from a binary classifier, and
@@ -884,6 +891,11 @@ def calibration_curve(y_true, y_prob, *, normalize=False, n_bins=5, strategy="un
y_prob : array-like of shape (n_samples,)
Probabilities of the positive class.

pos_label : int or str, default=None
The label of the positive class.

.. versionadded:: 1.1

normalize : bool, default=False
Whether y_prob needs to be normalized into the [0, 1] interval, i.e.
is not a proper probability. If True, the smallest value in y_prob
@@ -934,6 +946,7 @@ def calibration_curve(y_true, y_prob, *, normalize=False, n_bins=5, strategy="un
y_true = column_or_1d(y_true)
y_prob = column_or_1d(y_prob)
check_consistent_length(y_true, y_prob)
pos_label = _check_pos_label_consistency(pos_label, y_true)

if normalize: # Normalize predicted values into interval [0, 1]
y_prob = (y_prob - y_prob.min()) / (y_prob.max() - y_prob.min())
@@ -945,9 +958,9 @@ def calibration_curve(y_true, y_prob, *, normalize=False, n_bins=5, strategy="un
labels = np.unique(y_true)
if len(labels) > 2:
raise ValueError(
"Only binary classification is supported. Provided labels %s." % labels
f"Only binary classification is supported. Provided labels {labels}."
)
y_true = label_binarize(y_true, classes=labels)[:, 0]
y_true = y_true == pos_label

if strategy == "quantile": # Determine bin edges by distribution of data
quantiles = np.linspace(0, 1, n_bins + 1)
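
A minimal usage sketch of the new `pos_label` parameter (the labels and scores below are made up for illustration; only `calibration_curve` and its `pos_label` argument come from this change):

import numpy as np
from sklearn.calibration import calibration_curve

# Made-up string-labeled targets and predicted scores for the "egg" class.
y_true = np.array(["spam", "spam", "spam", "egg", "egg", "egg"])
y_score = np.array([0.1, 0.2, 0.3, 0.7, 0.8, 0.9])

# With string targets, the positive class can now be declared explicitly
# instead of relying on label_binarize's implicit ordering.
prob_true, prob_pred = calibration_curve(y_true, y_score, n_bins=2, pos_label="egg")
# With 2 uniform bins: prob_true -> [0., 1.], prob_pred -> [0.2, 0.8]
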
37 changes: 37 additions & 0 deletions sklearn/tests/test_calibration.py
@@ -786,6 +786,43 @@ def test_calibration_display_ref_line(pyplot, iris_data_binary):
assert labels.count("Perfectly calibrated") == 1


@pytest.mark.parametrize("dtype_y_str", [str, object])
def test_calibration_curve_pos_label_error_str(dtype_y_str):
"""Check error message when a `pos_label` is not specified with `str` targets."""
rng = np.random.RandomState(42)
y1 = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=dtype_y_str)
y2 = rng.randint(0, 2, size=y1.size)

err_msg = (
"y_true takes value in {'eggs', 'spam'} and pos_label is not "
"specified: either make y_true take value in {0, 1} or {-1, 1} or "
"pass pos_label explicitly"
)
with pytest.raises(ValueError, match=err_msg):
calibration_curve(y1, y2)


@pytest.mark.parametrize("dtype_y_str", [str, object])
def test_calibration_curve_pos_label(dtype_y_str):
"""Check the behaviour when passing explicitly `pos_label`."""
y_true = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1])
classes = np.array(["spam", "egg"], dtype=dtype_y_str)
y_true_str = classes[y_true]
y_pred = np.array([0.1, 0.2, 0.3, 0.4, 0.65, 0.7, 0.8, 0.9, 1.0])

# default case
prob_true, _ = calibration_curve(y_true, y_pred, n_bins=4)
assert_allclose(prob_true, [0, 0.5, 1, 1])
# if `y_true` contains `str`, then `pos_label` is required
prob_true, _ = calibration_curve(y_true_str, y_pred, n_bins=4, pos_label="egg")
assert_allclose(prob_true, [0, 0.5, 1, 1])

prob_true, _ = calibration_curve(y_true, 1 - y_pred, n_bins=4, pos_label=0)
assert_allclose(prob_true, [0, 0, 0.5, 1])
prob_true, _ = calibration_curve(y_true_str, 1 - y_pred, n_bins=4, pos_label="spam")
assert_allclose(prob_true, [0, 0, 0.5, 1])


@pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
@pytest.mark.parametrize("ensemble", [True, False])
def test_calibrated_classifier_cv_double_sample_weights_equivalence(method, ensemble):