Skip to content

MNT Refactor scorer using _get_response #18589

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 29 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
a212a81
MNT Refactor scorer using _get_response
glemaitre Oct 9, 2020
5737d34
fix error message
glemaitre Oct 9, 2020
1ca5e3a
FIX making error message consistent
glemaitre Oct 10, 2020
563de44
remove automatic import pylance
glemaitre Oct 10, 2020
7ac5d22
iter
glemaitre Oct 10, 2020
8dcb2be
iter
glemaitre Oct 10, 2020
0228384
move the function in a better module named
glemaitre Oct 10, 2020
25e8693
use _check_response_method in stacking estimators
glemaitre Oct 10, 2020
68fdce1
use _check_response_method in partial_dependence
glemaitre Oct 10, 2020
f4c224c
minimal changes
glemaitre Oct 10, 2020
e6a597c
TST pending test
glemaitre Oct 10, 2020
e047220
iter
glemaitre Oct 12, 2020
ec69cc6
TST add test for check_response_method
glemaitre Oct 12, 2020
a6aa2b5
regex
glemaitre Oct 12, 2020
efa6db9
iter
glemaitre Oct 12, 2020
40f09b8
TST add test for _get_response
glemaitre Oct 12, 2020
6a9e1d8
Merge remote-tracking branch 'origin/master' into is/18212_alternate
glemaitre Oct 19, 2020
09f5714
ENH accept a list of str
glemaitre Oct 19, 2020
e6c1953
TST adapt test
glemaitre Oct 19, 2020
3dfff72
make _get_response follow the same API
glemaitre Oct 19, 2020
a9ad549
Merge remote-tracking branch 'origin/master' into is/18212_alternate
glemaitre Oct 21, 2020
7de919d
Merge remote-tracking branch 'origin/master' into is/18212_alternate
glemaitre Jan 22, 2021
c29e090
revert import
glemaitre Jan 22, 2021
82b2086
iter
glemaitre Jan 22, 2021
b61561e
iter
glemaitre Jan 22, 2021
08efd84
Merge remote-tracking branch 'origin/main' into is/18212_alternate
glemaitre Jan 22, 2021
d96a653
iter
glemaitre Jan 22, 2021
218447d
xxx
glemaitre Jan 22, 2021
0121d4f
reorder
glemaitre Jan 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions sklearn/ensemble/_stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ..utils.metaestimators import if_delegate_has_method
from ..utils.multiclass import check_classification_targets
from ..utils.validation import check_is_fitted
from ..utils.validation import _check_response_method
from ..utils.validation import column_or_1d
from ..utils.validation import _deprecate_positional_args
from ..utils.fixes import delayed
Expand Down Expand Up @@ -96,18 +97,14 @@ def _concatenate_predictions(self, X, predictions):
def _method_name(name, estimator, method):
if estimator == 'drop':
return None
if method == 'auto':
if getattr(estimator, 'predict_proba', None):
return 'predict_proba'
elif getattr(estimator, 'decision_function', None):
return 'decision_function'
else:
return 'predict'
else:
if not hasattr(estimator, method):
raise ValueError('Underlying estimator {} does not implement '
'the method {}.'.format(name, method))
return method
method = None if method == "auto" else method
try:
method_name = _check_response_method(estimator, method).__name__
except ValueError as e:
raise ValueError(
f"stack_method {method} not defined in {name}"
) from e
return method_name

def fit(self, X, y, sample_weight=None):
"""Fit the estimators.
Expand Down
2 changes: 1 addition & 1 deletion sklearn/ensemble/tests/test_stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def fit(self, X, y):
{'estimators': [('lr', LogisticRegression()),
('svm', SVC(max_iter=5e4))],
'stack_method': 'predict_proba'},
ValueError, 'does not implement the method predict_proba'),
ValueError, 'stack_method predict_proba not defined in svm'),
(y_iris,
{'estimators': [('lr', LogisticRegression()),
('cor', NoWeightClassifier())]},
Expand Down
47 changes: 16 additions & 31 deletions sklearn/inspection/_partial_dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
from ..utils import _get_column_indices
from ..utils.validation import check_is_fitted
from ..utils import Bunch
from ..utils.validation import _deprecate_positional_args
from ..utils.validation import (
_check_response_method,
_deprecate_positional_args,
)
from ..tree import DecisionTreeRegressor
from ..ensemble import RandomForestRegressor
from ..exceptions import NotFittedError
Expand Down Expand Up @@ -120,29 +123,7 @@ def _partial_dependence_brute(est, grid, features, X, response_method):
predictions = []
averaged_predictions = []

# define the prediction_method (predict, predict_proba, decision_function).
if is_regressor(est):
prediction_method = est.predict
else:
predict_proba = getattr(est, 'predict_proba', None)
decision_function = getattr(est, 'decision_function', None)
if response_method == 'auto':
# try predict_proba, then decision_function if it doesn't exist
prediction_method = predict_proba or decision_function
else:
prediction_method = (predict_proba if response_method ==
'predict_proba' else decision_function)
if prediction_method is None:
if response_method == 'auto':
raise ValueError(
'The estimator has no predict_proba and no '
'decision_function method.'
)
elif response_method == 'predict_proba':
raise ValueError('The estimator has no predict_proba method.')
else:
raise ValueError(
'The estimator has no decision_function method.')
prediction_method = _check_response_method(est, response_method)

for new_values in grid:
X_eval = X.copy()
Expand Down Expand Up @@ -406,11 +387,15 @@ def partial_dependence(estimator, X, features, *, response_method='auto',
'response_method {} is invalid. Accepted response_method names '
'are {}.'.format(response_method, ', '.join(accepted_responses)))

if is_regressor(estimator) and response_method != 'auto':
raise ValueError(
"The response_method parameter is ignored for regressors and "
"must be 'auto'."
)
if is_regressor(estimator):
if response_method != "auto":
raise ValueError(
"The response_method parameter is ignored for regressors and "
"must be 'auto'."
)
response_method = "predict"
elif response_method == "auto":
response_method = ["predict_proba", "decision_function"]

accepted_methods = ('brute', 'recursion', 'auto')
if method not in accepted_methods:
Expand Down Expand Up @@ -454,10 +439,10 @@ def partial_dependence(estimator, X, features, *, response_method='auto',
"Only the following estimators support the 'recursion' "
"method: {}. Try using method='brute'."
.format(', '.join(supported_classes_recursion)))
if response_method == 'auto':
if isinstance(response_method, list):
response_method = 'decision_function'

if response_method != 'decision_function':
if is_classifier(estimator) and response_method != 'decision_function':
raise ValueError(
"With the 'recursion' method, the response_method must be "
"'decision_function'. Got {}.".format(response_method)
Expand Down
11 changes: 6 additions & 5 deletions sklearn/inspection/tests/test_partial_dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,9 @@ def test_partial_dependence_helpers(est, method, target_feature):
[123]])

if method == 'brute':
pdp, predictions = _partial_dependence_brute(est, grid, features, X,
response_method='auto')
pdp, predictions = _partial_dependence_brute(
est, grid, features, X, response_method='predict'
)
else:
pdp = _partial_dependence_recursion(est, grid, features)

Expand Down Expand Up @@ -415,13 +416,13 @@ def fit(self, X, y):
'response_method blahblah is invalid. Accepted response_method'),
(NoPredictProbaNoDecisionFunction(),
{'features': [0], 'response_method': 'auto'},
'The estimator has no predict_proba and no decision_function method'),
'response_method predict_proba, decision_function not defined'),
(NoPredictProbaNoDecisionFunction(),
{'features': [0], 'response_method': 'predict_proba'},
'The estimator has no predict_proba method.'),
'response_method predict_proba not defined'),
(NoPredictProbaNoDecisionFunction(),
{'features': [0], 'response_method': 'decision_function'},
'The estimator has no decision_function method.'),
'response_method decision_function not defined'),
(LinearRegression(),
{'features': [0], 'method': 'blahblah'},
'blahblah is invalid. Accepted method names are brute, recursion, auto'),
Expand Down
114 changes: 0 additions & 114 deletions sklearn/metrics/_plot/base.py

This file was deleted.

18 changes: 14 additions & 4 deletions sklearn/metrics/_plot/det_curve.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import scipy as sp

from .base import _get_response
from ...base import is_classifier
from ...utils import (
check_matplotlib_support,
_get_response,
)

from .. import det_curve

from ...utils import check_matplotlib_support


class DetCurveDisplay:
"""DET curve visualization.
Expand Down Expand Up @@ -209,8 +211,16 @@ def plot_det_curve(
"""
check_matplotlib_support('plot_det_curve')

if not is_classifier(estimator):
raise ValueError(
f"{estimator.__class__.__name__} should be a binary classifier."
)

if response_method == "auto":
response_method = ["predict_proba", "decision_function"]

y_pred, pos_label = _get_response(
X, estimator, response_method, pos_label=pos_label
estimator, X, y, response_method, pos_label=pos_label
)

fpr, fnr, _ = det_curve(
Expand Down
19 changes: 15 additions & 4 deletions sklearn/metrics/_plot/precision_recall_curve.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .base import _get_response

from .. import average_precision_score
from .. import precision_recall_curve

from ...utils import check_matplotlib_support
from ...base import is_classifier
from ...utils import (
check_matplotlib_support,
_get_response,
)
from ...utils.validation import _deprecate_positional_args


Expand Down Expand Up @@ -202,8 +204,17 @@ def plot_precision_recall_curve(estimator, X, y, *,
"""
check_matplotlib_support("plot_precision_recall_curve")

if not is_classifier(estimator):
raise ValueError(
f"{estimator.__class__.__name__} should be a binary classifier."
)

if response_method == "auto":
response_method = ["predict_proba", "decision_function"]

y_pred, pos_label = _get_response(
X, estimator, response_method, pos_label=pos_label)
estimator, X, y, response_method, pos_label=pos_label
)

precision, recall, _ = precision_recall_curve(y, y_pred,
pos_label=pos_label,
Expand Down
19 changes: 15 additions & 4 deletions sklearn/metrics/_plot/roc_curve.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .base import _get_response

from .. import auc
from .. import roc_curve

from ...utils import check_matplotlib_support
from ...base import is_classifier
from ...utils import (
check_matplotlib_support,
_get_response,
)
from ...utils.validation import _deprecate_positional_args


Expand Down Expand Up @@ -209,8 +211,17 @@ def plot_roc_curve(estimator, X, y, *, sample_weight=None,
"""
check_matplotlib_support('plot_roc_curve')

if not is_classifier(estimator):
raise ValueError(
f"{estimator.__class__.__name__} should be a binary classifier."
)

if response_method == "auto":
response_method = ["predict_proba", "decision_function"]

y_pred, pos_label = _get_response(
X, estimator, response_method, pos_label=pos_label)
estimator, X, y, response_method, pos_label=pos_label
)

fpr, tpr, _ = roc_curve(y, y_pred, pos_label=pos_label,
sample_weight=sample_weight,
Expand Down
Loading