diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py
index 8f5552ffd6808..60377e3b10f66 100644
--- a/sklearn/metrics/_plot/base.py
+++ b/sklearn/metrics/_plot/base.py
@@ -1,5 +1,3 @@
-import numpy as np
-
 from ...base import is_classifier
 
 
@@ -91,14 +89,18 @@ def _get_response(X, estimator, response_method, pos_label=None):
         raise ValueError(classification_error)
 
     prediction_method = _check_classifier_response_method(estimator, response_method)
-
     y_pred = prediction_method(X)
-
-    if pos_label is not None and pos_label not in estimator.classes_:
-        raise ValueError(
-            "The class provided by 'pos_label' is unknown. Got "
-            f"{pos_label} instead of one of {estimator.classes_}"
-        )
+    if pos_label is not None:
+        try:
+            class_idx = estimator.classes_.tolist().index(pos_label)
+        except ValueError as e:
+            raise ValueError(
+                "The class provided by 'pos_label' is unknown. Got "
+                f"{pos_label} instead of one of {set(estimator.classes_)}"
+            ) from e
+    else:
+        class_idx = 1
+        pos_label = estimator.classes_[class_idx]
 
     if y_pred.ndim != 1:  # `predict_proba`
         y_pred_shape = y_pred.shape[1]
@@ -107,16 +109,8 @@ def _get_response(X, estimator, response_method, pos_label=None):
                 f"{classification_error} fit on multiclass ({y_pred_shape} classes)"
                 " data"
             )
-        if pos_label is None:
-            pos_label = estimator.classes_[1]
-            y_pred = y_pred[:, 1]
-        else:
-            class_idx = np.flatnonzero(estimator.classes_ == pos_label)
-            y_pred = y_pred[:, class_idx]
-    else:
-        if pos_label is None:
-            pos_label = estimator.classes_[1]
-        elif pos_label == estimator.classes_[0]:
-            y_pred *= -1
+        y_pred = y_pred[:, class_idx]
+    elif pos_label == estimator.classes_[0]:  # `decision_function`
+        y_pred *= -1
 
     return y_pred, pos_label
diff --git a/sklearn/metrics/_plot/tests/test_base.py b/sklearn/metrics/_plot/tests/test_base.py
new file mode 100644
index 0000000000000..2f67d7dd223f4
--- /dev/null
+++ b/sklearn/metrics/_plot/tests/test_base.py
@@ -0,0 +1,75 @@
+import numpy as np
+import pytest
+
+from sklearn.datasets import load_iris
+from sklearn.linear_model import LogisticRegression
+from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
+
+from sklearn.metrics._plot.base import _get_response
+
+
+@pytest.mark.parametrize(
+    "estimator, err_msg, params",
+    [
+        (
+            DecisionTreeRegressor(),
+            "Expected 'estimator' to be a binary classifier",
+            {"response_method": "auto"},
+        ),
+        (
+            DecisionTreeClassifier(),
+            "The class provided by 'pos_label' is unknown.",
+            {"response_method": "auto", "pos_label": "unknown"},
+        ),
+        (
+            DecisionTreeClassifier(),
+            "fit on multiclass",
+            {"response_method": "predict_proba"},
+        ),
+    ],
+)
+def test_get_response_error(estimator, err_msg, params):
+    """Check that we raise the proper error messages in `_get_response`."""
+    X, y = load_iris(return_X_y=True)
+
+    estimator.fit(X, y)
+    with pytest.raises(ValueError, match=err_msg):
+        _get_response(X, estimator, **params)
+
+
+def test_get_response_predict_proba():
+    """Check the behaviour of `_get_response` using `predict_proba`."""
+    X, y = load_iris(return_X_y=True)
+    X_binary, y_binary = X[:100], y[:100]
+
+    classifier = DecisionTreeClassifier().fit(X_binary, y_binary)
+    y_proba, pos_label = _get_response(
+        X_binary, classifier, response_method="predict_proba"
+    )
+    np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 1])
+    assert pos_label == 1
+
+    y_proba, pos_label = _get_response(
+        X_binary, classifier, response_method="predict_proba", pos_label=0
+    )
+    np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 0])
+    assert pos_label == 0
+
+
+def test_get_response_decision_function():
+    """Check the behaviour of `get_response` using `decision_function`."""
+    X, y = load_iris(return_X_y=True)
+    X_binary, y_binary = X[:100], y[:100]
+
+    classifier = LogisticRegression().fit(X_binary, y_binary)
+    y_score, pos_label = _get_response(
+        X_binary, classifier, response_method="decision_function"
+    )
+    np.testing.assert_allclose(y_score, classifier.decision_function(X_binary))
+    assert pos_label == 1
+
+    y_score, pos_label = _get_response(
+        X_binary, classifier, response_method="decision_function", pos_label=0
+    )
+    np.testing.assert_allclose(y_score, classifier.decision_function(X_binary) * -1)
+    assert pos_label == 0