ENH check_scoring() has raise_exc for multimetric scoring #28992

Merged: 12 commits, May 24, 2024
7 changes: 7 additions & 0 deletions doc/whats_new/v1.6.rst
@@ -80,6 +80,13 @@ Changelog
- |Enhancement| Added a function :func:`base.is_clusterer` which determines
whether a given estimator is of category clusterer.
:pr:`28936` by :user:`Christian Veenhuis <ChVeen>`.

:mod:`sklearn.metrics`
......................

- |Enhancement| :func:`sklearn.metrics.check_scoring` now accepts `raise_exc` to specify
whether to raise an exception if a subset of the scorers in multimetric scoring fails
or to return an error code. :pr:`28992` by :user:`Stefanie Senger <StefanieSenger>`.

Thanks to everyone who has contributed to the maintenance and improvement of
the project since version 1.5, including:
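A minimal sketch of the option announced in the changelog entry above (assuming scikit-learn >= 1.6; the failing scorer is illustrative and mirrors the docstring example later in this diff):

    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score, check_scoring, make_scorer

    X, y = load_iris(return_X_y=True)
    clf = LogisticRegression(max_iter=1000).fit(X, y)

    def failing_scorer(estimator, X, y):
        raise ValueError("boom")

    scoring = {"accuracy": make_scorer(accuracy_score), "failing": failing_scorer}

    # raise_exc=False: the failing scorer's formatted traceback is returned as a
    # string in the result dict instead of being raised.
    scorer = check_scoring(clf, scoring=scoring, raise_exc=False)
    scores = scorer(clf, X, y)  # e.g. {'accuracy': 0.97..., 'failing': 'Traceback ...'}

    # raise_exc=True (the default): the failing scorer's ValueError propagates.
    scorer = check_scoring(clf, scoring=scoring)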
2 changes: 1 addition & 1 deletion sklearn/feature_selection/_rfe.py
@@ -539,7 +539,7 @@ class RFECV(RFE):
``cv`` default value of None changed from 3-fold to 5-fold.

scoring : str, callable or None, default=None
A string (see model evaluation documentation) or
A string (see :ref:`scoring_parameter`) or
a scorer callable object / function with signature
``scorer(estimator, X, y)``.

4 changes: 2 additions & 2 deletions sklearn/linear_model/_logistic.py
@@ -636,7 +636,7 @@ def _log_reg_scoring_path(
values are chosen in a logarithmic scale between 1e-4 and 1e4.

scoring : callable
A string (see model evaluation documentation) or
A string (see :ref:`scoring_parameter`) or
a scorer callable object / function with signature
``scorer(estimator, X, y)``. For a list of scoring functions
that can be used, look at :mod:`sklearn.metrics`.
@@ -1521,7 +1521,7 @@ class LogisticRegressionCV(LogisticRegression, LinearClassifierMixin, BaseEstimator):
solver.

scoring : str or callable, default=None
A string (see model evaluation documentation) or
A string (see :ref:`scoring_parameter`) or
a scorer callable object / function with signature
``scorer(estimator, X, y)``. For a list of scoring functions
that can be used, look at :mod:`sklearn.metrics`. The
49 changes: 38 additions & 11 deletions sklearn/metrics/_scorer.py
@@ -955,10 +955,11 @@ def get_scorer_names():
None,
],
"allow_none": ["boolean"],
"raise_exc": ["boolean"],
},
prefer_skip_nested_validation=True,
)
def check_scoring(estimator=None, scoring=None, *, allow_none=False):
def check_scoring(estimator=None, scoring=None, *, allow_none=False, raise_exc=True):
"""Determine scorer from user options.

A TypeError will be thrown if the estimator cannot be scored.
@@ -969,30 +970,43 @@ def check_scoring(estimator=None, scoring=None, *, allow_none=False):
The object to use to fit the data. If `None`, then this function may error
depending on `allow_none`.

scoring : str, callable, list, tuple, or dict, default=None
scoring : str, callable, list, tuple, set, or dict, default=None
Scorer to use. If `scoring` represents a single score, one can use:

- a single string (see :ref:`scoring_parameter`);
- a callable (see :ref:`scoring`) that returns a single value.

If `scoring` represents multiple scores, one can use:

- a list or tuple of unique strings;
- a callable returning a dictionary where the keys are the metric
names and the values are the metric scorers;
- a dictionary with metric names as keys and callables a values.
- a list, tuple or set of unique strings;
Member:

Should we replace "list, tuple, or set" with "iterable of strings"?

Contributor Author:

Hm, I actually think we should tell users explicitly what they can pass in. Anything other than those wouldn't show the correct behaviour in lines 1054-56:

    if isinstance(scoring, (list, tuple, set, dict)):
        scorers = _check_multimetric_scoring(estimator, scoring=scoring)
        return _MultimetricScorer(scorers=scorers, raise_exc=raise_exc)

And any other iterable of strings would raise during `@validate_params()`.

So I would rather leave it as it is.
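
An illustrative sketch of that point (not part of the PR; the generator is hypothetical): any iterable outside list, tuple, set, or dict should be rejected by `@validate_params` before the `isinstance` check is reached.

    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import check_scoring

    X, y = load_iris(return_X_y=True)
    clf = LogisticRegression(max_iter=1000).fit(X, y)

    # A generator of valid scorer names is still not an accepted type, so
    # parameter validation should reject it (sklearn's InvalidParameterError
    # derives from both ValueError and TypeError).
    try:
        check_scoring(clf, scoring=(name for name in ["accuracy", "f1_macro"]))
    except (ValueError, TypeError) as exc:
        print(exc)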

- a callable returning a dictionary where the keys are the metric names and the
values are the metric scorers;
- a dictionary with metric names as keys and callables as values. The callables
need to have the signature `callable(estimator, X, y)`.

If None, the provided estimator object's `score` method is used.

allow_none : bool, default=False
If no scoring is specified and the estimator has no score function, we
can either return None or raise an exception.
Whether to return None or raise an error if no `scoring` is specified and the
estimator has no `score` method.

raise_exc : bool, default=True
Whether to raise an exception (if a subset of the scorers in multimetric scoring
fails) or to return an error code.

- If set to `True`, raises the failing scorer's exception.
- If set to `False`, a formatted string of the exception details is returned as
the result of the failing scorer(s).

This applies if `scoring` is list, tuple, set, or dict. Ignored if `scoring` is
a str or a callable.

.. versionadded:: 1.6

Returns
-------
scoring : callable
A scorer callable object / function with signature
``scorer(estimator, X, y)``.
A scorer callable object / function with signature ``scorer(estimator, X, y)``.

Examples
--------
@@ -1004,6 +1018,19 @@ def check_scoring(estimator=None, scoring=None, *, allow_none=False):
>>> scorer = check_scoring(classifier, scoring='accuracy')
>>> scorer(classifier, X, y)
0.96...

>>> from sklearn.metrics import make_scorer, accuracy_score, mean_squared_log_error
>>> X, y = load_iris(return_X_y=True)
>>> y *= -1
>>> clf = DecisionTreeClassifier().fit(X, y)
>>> scoring = {
... "accuracy": make_scorer(accuracy_score),
... "mean_squared_log_error": make_scorer(mean_squared_log_error),
... }
>>> scoring_call = check_scoring(estimator=clf, scoring=scoring, raise_exc=False)
>>> scores = scoring_call(clf, X, y)
>>> scores
{'accuracy': 1.0, 'mean_squared_log_error': 'Traceback ...'}
"""
if isinstance(scoring, str):
return get_scorer(scoring)
@@ -1026,7 +1053,7 @@ def check_scoring(estimator=None, scoring=None, *, allow_none=False):
return get_scorer(scoring)
if isinstance(scoring, (list, tuple, set, dict)):
scorers = _check_multimetric_scoring(estimator, scoring=scoring)
return _MultimetricScorer(scorers=scorers)
return _MultimetricScorer(scorers=scorers, raise_exc=raise_exc)
if scoring is None:
if hasattr(estimator, "score"):
return _PassthroughScorer(estimator)
28 changes: 28 additions & 0 deletions sklearn/metrics/tests/test_score_objects.py
@@ -1557,6 +1557,34 @@ def test_multimetric_scorer_repr():
assert str(multi_metric_scorer) == 'MultiMetricScorer("accuracy", "r2")'


def test_check_scoring_multimetric_raise_exc():
"""Test that check_scoring returns error code for a subset of scorers in
multimetric scoring if raise_exc=False and raises otherwise."""

def raising_scorer(estimator, X, y):
raise ValueError("That doesn't work.")

X, y = make_classification(n_samples=150, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)

# "raising_scorer" is raising ValueError and should return an string representation
# of the error of the last scorer:
scoring = {
"accuracy": make_scorer(accuracy_score),
"raising_scorer": raising_scorer,
}
scoring_call = check_scoring(estimator=clf, scoring=scoring, raise_exc=False)
scores = scoring_call(clf, X_test, y_test)
assert "That doesn't work." in scores["raising_scorer"]

# should raise an error
scoring_call = check_scoring(estimator=clf, scoring=scoring, raise_exc=True)
err_msg = "That doesn't work."
with pytest.raises(ValueError, match=err_msg):
scores = scoring_call(clf, X_test, y_test)


@pytest.mark.parametrize("enable_metadata_routing", [True, False])
def test_metadata_routing_multimetric_metadata_routing(enable_metadata_routing):
"""Test multimetric scorer works with and without metadata routing enabled when
2 changes: 1 addition & 1 deletion sklearn/model_selection/_classification_threshold.py
@@ -639,7 +639,7 @@ class TunedThresholdClassifierCV(BaseThresholdClassifier):
The objective metric to be optimized. Can be one of:

* a string associated to a scoring function for binary classification
(see model evaluation documentation);
(see :ref:`scoring_parameter`);
* a scorer callable object created with :func:`~sklearn.metrics.make_scorer`;

response_method : {"auto", "decision_function", "predict_proba"}, default="auto"
2 changes: 1 addition & 1 deletion sklearn/model_selection/_search.py
@@ -972,7 +972,7 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None):
first_test_score = all_out[0]["test_scores"]
self.multimetric_ = isinstance(first_test_score, dict)

# check refit_metric now for a callabe scorer that is multimetric
# check refit_metric now for a callable scorer that is multimetric
if callable(self.scoring) and self.multimetric_:
self._check_refit_for_multimetric(first_test_score)
refit_metric = self.refit
14 changes: 4 additions & 10 deletions sklearn/model_selection/_validation.py
@@ -27,7 +27,7 @@
from ..base import clone, is_classifier
from ..exceptions import FitFailedWarning, UnsetMetadataPassedError
from ..metrics import check_scoring, get_scorer_names
from ..metrics._scorer import _check_multimetric_scoring, _MultimetricScorer
from ..metrics._scorer import _MultimetricScorer
from ..preprocessing import LabelEncoder
from ..utils import Bunch, _safe_indexing, check_random_state, indexable
from ..utils._param_validation import (
Expand Down Expand Up @@ -353,15 +353,9 @@ def cross_validate(

cv = check_cv(cv, y, classifier=is_classifier(estimator))

if callable(scoring):
scorers = scoring
elif scoring is None or isinstance(scoring, str):
scorers = check_scoring(estimator, scoring)
else:
scorers = _check_multimetric_scoring(estimator, scoring)
scorers = _MultimetricScorer(
scorers=scorers, raise_exc=(error_score == "raise")
)
scorers = check_scoring(
estimator, scoring=scoring, raise_exc=(error_score == "raise")
)

if _routing_enabled():
# For estimators, a MetadataRouter is created in get_metadata_routing
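
The simplification above routes every `scoring` type in `cross_validate` through `check_scoring`, deriving `raise_exc` from `error_score`. A minimal sketch of the behaviour this preserves (illustrative; not part of the PR):

    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score, make_scorer
    from sklearn.model_selection import cross_validate

    def failing_scorer(estimator, X, y):
        raise ValueError("boom")

    X, y = load_iris(return_X_y=True)
    scoring = {"accuracy": make_scorer(accuracy_score), "failing": failing_scorer}

    # With error_score != "raise", check_scoring is called with raise_exc=False,
    # so the failing metric should be replaced by error_score (with a warning)
    # for each fold rather than aborting the whole cross-validation.
    results = cross_validate(
        LogisticRegression(max_iter=1000), X, y, scoring=scoring, error_score=np.nan
    )
    # results["test_failing"] should be all nan; results["test_accuracy"] is valid.

    # With error_score="raise", raise_exc=True and the ValueError propagates.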