Skip to content

ENH _get_response_values handles predict_log_proba #27719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions sklearn/utils/_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,11 @@ def _get_response_values(
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input values.

response_method : {"predict_proba", "decision_function", "predict"} or \
list of such str
response_method : {"predict_proba", "predict_log_proba", "decision_function", \
"predict"} or list of such str
Specifies the response method to use to get predictions from an estimator
(i.e. :term:`predict_proba`, :term:`decision_function` or
:term:`predict`). Possible choices are:
(i.e. :term:`predict_proba`, :term:`predict_log_proba`,
:term:`decision_function` or :term:`predict`). Possible choices are:

- if `str`, it corresponds to the name of the method to return;
- if a list of `str`, it provides the method names in order of
Expand Down Expand Up @@ -209,7 +209,7 @@ def _get_response_values(

y_pred = prediction_method(X)

if prediction_method.__name__ == "predict_proba":
if prediction_method.__name__ in ("predict_proba", "predict_log_proba"):
y_pred = _process_predict_proba(
y_pred=y_pred,
target_type=target_type,
Expand Down
26 changes: 17 additions & 9 deletions sklearn/utils/tests/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
X_binary, y_binary = X[:100], y[:100]


@pytest.mark.parametrize("response_method", ["decision_function", "predict_proba"])
@pytest.mark.parametrize(
"response_method", ["decision_function", "predict_proba", "predict_log_proba"]
)
def test_get_response_values_regressor_error(response_method):
"""Check the error message with a regressor and an unsupported response
method."""
Expand Down Expand Up @@ -82,7 +84,7 @@ def test_get_response_values_outlier_detection(

@pytest.mark.parametrize(
"response_method",
["predict_proba", "decision_function", "predict"],
["predict_proba", "decision_function", "predict", "predict_log_proba"],
)
def test_get_response_values_classifier_unknown_pos_label(response_method):
"""Check that `_get_response_values` raises the proper error message with
Expand All @@ -101,7 +103,10 @@ def test_get_response_values_classifier_unknown_pos_label(response_method):
)


def test_get_response_values_classifier_inconsistent_y_pred_for_binary_proba():
@pytest.mark.parametrize("response_method", ["predict_proba", "predict_log_proba"])
def test_get_response_values_classifier_inconsistent_y_pred_for_binary_proba(
response_method,
):
"""Check that `_get_response_values` will raise an error when `y_pred` has a
single class with `predict_proba` or `predict_log_proba`."""
X, y_two_class = make_classification(n_samples=10, n_classes=2, random_state=0)
Expand All @@ -113,7 +118,7 @@ def test_get_response_values_classifier_inconsistent_y_pred_for_binary_proba():
r"two classes"
)
with pytest.raises(ValueError, match=err_msg):
_get_response_values(classifier, X, response_method="predict_proba")
_get_response_values(classifier, X, response_method=response_method)


@pytest.mark.parametrize("return_response_method_used", [True, False])
Expand Down Expand Up @@ -159,8 +164,9 @@ def test_get_response_values_binary_classifier_decision_function(


@pytest.mark.parametrize("return_response_method_used", [True, False])
@pytest.mark.parametrize("response_method", ["predict_proba", "predict_log_proba"])
def test_get_response_values_binary_classifier_predict_proba(
return_response_method_used,
return_response_method_used, response_method
):
"""Check `_get_response_values` with `predict_proba` or `predict_log_proba`
and a binary classifier."""
Expand All @@ -171,7 +177,6 @@ def test_get_response_values_binary_classifier_predict_proba(
random_state=0,
)
classifier = LogisticRegression().fit(X, y)
response_method = "predict_proba"

# default `pos_label`
results = _get_response_values(
Expand All @@ -181,11 +186,11 @@ def test_get_response_values_binary_classifier_predict_proba(
pos_label=None,
return_response_method_used=return_response_method_used,
)
assert_allclose(results[0], classifier.predict_proba(X)[:, 1])
assert_allclose(results[0], getattr(classifier, response_method)(X)[:, 1])
assert results[1] == 1
if return_response_method_used:
assert len(results) == 3
assert results[2] == "predict_proba"
assert results[2] == response_method
else:
assert len(results) == 2

Expand All @@ -197,7 +202,7 @@ def test_get_response_values_binary_classifier_predict_proba(
pos_label=classifier.classes_[0],
return_response_method_used=return_response_method_used,
)
assert_allclose(y_pred, classifier.predict_proba(X)[:, 0])
assert_allclose(y_pred, getattr(classifier, response_method)(X)[:, 0])
assert pos_label == 0


Expand Down Expand Up @@ -271,6 +276,7 @@ def test_get_response_decision_function():
"estimator, response_method",
[
(DecisionTreeClassifier(max_depth=2, random_state=0), "predict_proba"),
(DecisionTreeClassifier(max_depth=2, random_state=0), "predict_log_proba"),
(LogisticRegression(), "decision_function"),
],
)
Expand All @@ -287,6 +293,8 @@ def test_get_response_values_multiclass(estimator, response_method):
assert predictions.shape == (X.shape[0], len(estimator.classes_))
if response_method == "predict_proba":
assert np.logical_and(predictions >= 0, predictions <= 1).all()
elif response_method == "predict_log_proba":
assert (predictions <= 0.0).all()


def test_get_response_values_with_response_list():
Expand Down
8 changes: 4 additions & 4 deletions sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1923,11 +1923,11 @@ def _check_response_method(estimator, response_method):
estimator : estimator instance
Classifier or regressor to check.

response_method : {"predict_proba", "decision_function", "predict"} or \
list of such str
response_method : {"predict_proba", "predict_log_proba", "decision_function", \
"predict"} or list of such str
Specifies the response method to use to get predictions from an estimator
(i.e. :term:`predict_proba`, :term:`decision_function` or
:term:`predict`). Possible choices are:
(i.e. :term:`predict_proba`, :term:`predict_log_proba`,
:term:`decision_function` or :term:`predict`). Possible choices are:
- if `str`, it corresponds to the name of the method to return;
- if a list of `str`, it provides the method names in order of
preference. The method returned corresponds to the first method in
Expand Down