diff --git a/sklearn/utils/_response.py b/sklearn/utils/_response.py index 8836491623ae1..e647ba3a4f009 100644 --- a/sklearn/utils/_response.py +++ b/sklearn/utils/_response.py @@ -143,11 +143,11 @@ def _get_response_values( X : {array-like, sparse matrix} of shape (n_samples, n_features) Input values. - response_method : {"predict_proba", "decision_function", "predict"} or \ - list of such str + response_method : {"predict_proba", "predict_log_proba", "decision_function", \ + "predict"} or list of such str Specifies the response method to use get prediction from an estimator - (i.e. :term:`predict_proba`, :term:`decision_function` or - :term:`predict`). Possible choices are: + (i.e. :term:`predict_proba`, :term:`predict_log_proba`, + :term:`decision_function` or :term:`predict`). Possible choices are: - if `str`, it corresponds to the name to the method to return; - if a list of `str`, it provides the method names in order of @@ -209,7 +209,7 @@ def _get_response_values( y_pred = prediction_method(X) - if prediction_method.__name__ == "predict_proba": + if prediction_method.__name__ in ("predict_proba", "predict_log_proba"): y_pred = _process_predict_proba( y_pred=y_pred, target_type=target_type, diff --git a/sklearn/utils/tests/test_response.py b/sklearn/utils/tests/test_response.py index e4464559ef066..c84bf6030336a 100644 --- a/sklearn/utils/tests/test_response.py +++ b/sklearn/utils/tests/test_response.py @@ -25,7 +25,9 @@ X_binary, y_binary = X[:100], y[:100] -@pytest.mark.parametrize("response_method", ["decision_function", "predict_proba"]) +@pytest.mark.parametrize( + "response_method", ["decision_function", "predict_proba", "predict_log_proba"] +) def test_get_response_values_regressor_error(response_method): """Check the error message with regressor an not supported response method.""" @@ -82,7 +84,7 @@ def test_get_response_values_outlier_detection( @pytest.mark.parametrize( "response_method", - ["predict_proba", "decision_function", "predict"], + ["predict_proba", "decision_function", "predict", "predict_log_proba"], ) def test_get_response_values_classifier_unknown_pos_label(response_method): """Check that `_get_response_values` raises the proper error message with @@ -101,7 +103,10 @@ def test_get_response_values_classifier_unknown_pos_label(response_method): ) -def test_get_response_values_classifier_inconsistent_y_pred_for_binary_proba(): +@pytest.mark.parametrize("response_method", ["predict_proba", "predict_log_proba"]) +def test_get_response_values_classifier_inconsistent_y_pred_for_binary_proba( + response_method, +): """Check that `_get_response_values` will raise an error when `y_pred` has a single class with `predict_proba`.""" X, y_two_class = make_classification(n_samples=10, n_classes=2, random_state=0) @@ -113,7 +118,7 @@ def test_get_response_values_classifier_inconsistent_y_pred_for_binary_proba(): r"two classes" ) with pytest.raises(ValueError, match=err_msg): - _get_response_values(classifier, X, response_method="predict_proba") + _get_response_values(classifier, X, response_method=response_method) @pytest.mark.parametrize("return_response_method_used", [True, False]) @@ -159,8 +164,9 @@ def test_get_response_values_binary_classifier_decision_function( @pytest.mark.parametrize("return_response_method_used", [True, False]) +@pytest.mark.parametrize("response_method", ["predict_proba", "predict_log_proba"]) def test_get_response_values_binary_classifier_predict_proba( - return_response_method_used, + return_response_method_used, response_method ): """Check that `_get_response_values` with `predict_proba` and binary classifier.""" @@ -171,7 +177,6 @@ def test_get_response_values_binary_classifier_predict_proba( random_state=0, ) classifier = LogisticRegression().fit(X, y) - response_method = "predict_proba" # default `pos_label` results = _get_response_values( @@ -181,11 +186,11 @@ def test_get_response_values_binary_classifier_predict_proba( pos_label=None, return_response_method_used=return_response_method_used, ) - assert_allclose(results[0], classifier.predict_proba(X)[:, 1]) + assert_allclose(results[0], getattr(classifier, response_method)(X)[:, 1]) assert results[1] == 1 if return_response_method_used: assert len(results) == 3 - assert results[2] == "predict_proba" + assert results[2] == response_method else: assert len(results) == 2 @@ -197,7 +202,7 @@ def test_get_response_values_binary_classifier_predict_proba( pos_label=classifier.classes_[0], return_response_method_used=return_response_method_used, ) - assert_allclose(y_pred, classifier.predict_proba(X)[:, 0]) + assert_allclose(y_pred, getattr(classifier, response_method)(X)[:, 0]) assert pos_label == 0 @@ -271,6 +276,7 @@ def test_get_response_decision_function(): "estimator, response_method", [ (DecisionTreeClassifier(max_depth=2, random_state=0), "predict_proba"), + (DecisionTreeClassifier(max_depth=2, random_state=0), "predict_log_proba"), (LogisticRegression(), "decision_function"), ], ) @@ -287,6 +293,8 @@ def test_get_response_values_multiclass(estimator, response_method): assert predictions.shape == (X.shape[0], len(estimator.classes_)) if response_method == "predict_proba": assert np.logical_and(predictions >= 0, predictions <= 1).all() + elif response_method == "predict_log_proba": + assert (predictions <= 0.0).all() def test_get_response_values_with_response_list(): diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index a5b4a8555de63..742d0047fb08d 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1923,11 +1923,11 @@ def _check_response_method(estimator, response_method): estimator : estimator instance Classifier or regressor to check. - response_method : {"predict_proba", "decision_function", "predict"} or \ - list of such str + response_method : {"predict_proba", "predict_log_proba", "decision_function", + "predict"} or list of such str Specifies the response method to use get prediction from an estimator - (i.e. :term:`predict_proba`, :term:`decision_function` or - :term:`predict`). Possible choices are: + (i.e. :term:`predict_proba`, :term:`predict_log_proba`, + :term:`decision_function` or :term:`predict`). Possible choices are: - if `str`, it corresponds to the name to the method to return; - if a list of `str`, it provides the method names in order of preference. The method returned corresponds to the first method in