-
-
Notifications
You must be signed in to change notification settings - Fork 26k
ENH check_scoring()
has raise_exc
for multimetric scoring
#28992
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
222110f
7a657ab
d1efd10
2cef99d
9cd74d9
8aecd28
f4b68b7
549e764
395d0ae
d9bf3a4
a3c6469
2a30caf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -955,10 +955,11 @@ def get_scorer_names(): | |
None, | ||
], | ||
"allow_none": ["boolean"], | ||
"raise_exc": ["boolean"], | ||
}, | ||
prefer_skip_nested_validation=True, | ||
) | ||
def check_scoring(estimator=None, scoring=None, *, allow_none=False): | ||
def check_scoring(estimator=None, scoring=None, *, allow_none=False, raise_exc=True): | ||
"""Determine scorer from user options. | ||
|
||
A TypeError will be thrown if the estimator cannot be scored. | ||
|
@@ -969,30 +970,43 @@ def check_scoring(estimator=None, scoring=None, *, allow_none=False): | |
The object to use to fit the data. If `None`, then this function may error | ||
depending on `allow_none`. | ||
|
||
scoring : str, callable, list, tuple, or dict, default=None | ||
scoring : str, callable, list, tuple, set, or dict, default=None | ||
Scorer to use. If `scoring` represents a single score, one can use: | ||
|
||
- a single string (see :ref:`scoring_parameter`); | ||
- a callable (see :ref:`scoring`) that returns a single value. | ||
|
||
If `scoring` represents multiple scores, one can use: | ||
|
||
- a list or tuple of unique strings; | ||
- a callable returning a dictionary where the keys are the metric | ||
names and the values are the metric scorers; | ||
- a dictionary with metric names as keys and callables a values. | ||
- a list, tuple or set of unique strings; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we replace There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm, I actually think we should tell the users explicitly what the can pass in. Any other than those wouldn't show the correct behaviour in li. 1054-56: if isinstance(scoring, (list, tuple, set, dict)):
scorers = _check_multimetric_scoring(estimator, scoring=scoring)
return _MultimetricScorer(scorers=scorers, raise_exc=raise_exc) And, any other iterable of strings would raise during the So I would rather leave it as it is. |
||
- a callable returning a dictionary where the keys are the metric names and the | ||
values are the metric scorers; | ||
- a dictionary with metric names as keys and callables a values. The callables | ||
need to have the signature `callable(estimator, X, y)`. | ||
|
||
If None, the provided estimator object's `score` method is used. | ||
|
||
allow_none : bool, default=False | ||
If no scoring is specified and the estimator has no score function, we | ||
can either return None or raise an exception. | ||
Whether to return None or raise an error if no `scoring` is specified and the | ||
estimator has no `score` method. | ||
|
||
raise_exc : bool, default=True | ||
Whether to raise an exception (if a subset of the scorers in multimetric scoring | ||
fails) or to return an error code. | ||
|
||
- If set to `True`, raises the failing scorer's exception. | ||
- If set to `False`, a formatted string of the exception details is passed as | ||
result of the failing scorer(s). | ||
|
||
This applies if `scoring` is list, tuple, set, or dict. Ignored if `scoring` is | ||
a str or a callable. | ||
StefanieSenger marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
.. versionadded:: 1.6 | ||
|
||
Returns | ||
------- | ||
scoring : callable | ||
A scorer callable object / function with signature | ||
``scorer(estimator, X, y)``. | ||
A scorer callable object / function with signature ``scorer(estimator, X, y)``. | ||
|
||
Examples | ||
-------- | ||
|
@@ -1004,6 +1018,19 @@ def check_scoring(estimator=None, scoring=None, *, allow_none=False): | |
>>> scorer = check_scoring(classifier, scoring='accuracy') | ||
>>> scorer(classifier, X, y) | ||
0.96... | ||
|
||
>>> from sklearn.metrics import make_scorer, accuracy_score, mean_squared_log_error | ||
>>> X, y = load_iris(return_X_y=True) | ||
>>> y *= -1 | ||
>>> clf = DecisionTreeClassifier().fit(X, y) | ||
>>> scoring = { | ||
... "accuracy": make_scorer(accuracy_score), | ||
... "mean_squared_log_error": make_scorer(mean_squared_log_error), | ||
... } | ||
>>> scoring_call = check_scoring(estimator=clf, scoring=scoring, raise_exc=False) | ||
>>> scores = scoring_call(clf, X, y) | ||
>>> scores | ||
{'accuracy': 1.0, 'mean_squared_log_error': 'Traceback ...'} | ||
""" | ||
if isinstance(scoring, str): | ||
return get_scorer(scoring) | ||
|
@@ -1026,7 +1053,7 @@ def check_scoring(estimator=None, scoring=None, *, allow_none=False): | |
return get_scorer(scoring) | ||
if isinstance(scoring, (list, tuple, set, dict)): | ||
scorers = _check_multimetric_scoring(estimator, scoring=scoring) | ||
return _MultimetricScorer(scorers=scorers) | ||
return _MultimetricScorer(scorers=scorers, raise_exc=raise_exc) | ||
if scoring is None: | ||
if hasattr(estimator, "score"): | ||
return _PassthroughScorer(estimator) | ||
|
Uh oh!
There was an error while loading. Please reload this page.