
[MRG] Adds support for a multimetric callable returning a dictionary #15126


Merged: 42 commits, Jul 16, 2020
Changes from all commits
Commits (42)
9d090da
WIP
thomasjpfan Oct 3, 2019
315c335
ENH Increase compability
thomasjpfan Oct 3, 2019
702cf1b
ENH Refactories _fit_and_score
thomasjpfan Oct 3, 2019
a7d2efb
RFC Moves support into a function
thomasjpfan Oct 4, 2019
c77afd7
BUG Fix old numpy bug
thomasjpfan Oct 4, 2019
5ab8693
TST Removes tests for error on multimetric
thomasjpfan Oct 4, 2019
9c53783
Merge remote-tracking branch 'upstream/master' into multimetric_refactor
thomasjpfan Dec 3, 2019
8676c04
Merge remote-tracking branch 'upstream/master' into multimetric_refactor
thomasjpfan Dec 4, 2019
e8f8c9f
DOC Indent
thomasjpfan Dec 4, 2019
5f50a32
CLN Refactors multimetric check
thomasjpfan Dec 4, 2019
ad829e1
Merge remote-tracking branch 'upstream/master' into multimetric_refactor
thomasjpfan Jan 7, 2020
c27f592
Merge remote-tracking branch 'upstream/master' into multimetric_refactor
thomasjpfan Feb 4, 2020
57c390a
CLN Address comments
thomasjpfan Feb 4, 2020
524fd87
Merge remote-tracking branch 'upstream/master' into multimetric_refactor
thomasjpfan Feb 5, 2020
1b28907
CLN Simplifies checking
thomasjpfan Feb 5, 2020
2cf9ba8
CLN Simplifies aggregation
thomasjpfan Feb 5, 2020
f336d64
CLN Less code the better
thomasjpfan Feb 5, 2020
a86eaf0
CLN Moves definition closer to usage
thomasjpfan Feb 5, 2020
b1782ae
CLN Update error handling
thomasjpfan Feb 6, 2020
d4782b2
Merge remote-tracking branch 'upstream/master' into multimetric_refactor
thomasjpfan Feb 6, 2020
c14463c
Merge remote-tracking branch 'upstream/master' into multimetric_refactor
thomasjpfan May 24, 2020
762d644
Merge remote-tracking branch 'upstream/master' into multimetric_refactor
thomasjpfan May 24, 2020
c5f9b42
REV Less diffs
thomasjpfan May 24, 2020
0e79b59
CLN Address comments
thomasjpfan May 24, 2020
49e8c03
REV
thomasjpfan May 24, 2020
4f6ecd7
STY Flake
thomasjpfan May 24, 2020
4fa5eb6
ENH Fix error
thomasjpfan May 24, 2020
97b1db2
REV Less diffs
thomasjpfan May 25, 2020
286bb86
DOC Adds comments
thomasjpfan May 25, 2020
9799297
Merge remote-tracking branch 'upstream/master' into multimetric_refactor
thomasjpfan May 25, 2020
5e1b72b
Merge remote-tracking branch 'upstream/master' into multimetric_refactor
thomasjpfan Jun 6, 2020
fe7dae3
Merge remote-tracking branch 'upstream/master' into multimetric_refactor
thomasjpfan Jun 8, 2020
7c85d54
Merge remote-tracking branch 'upstream/master' into multimetric_refactor
thomasjpfan Jul 8, 2020
5da1571
CLN Removes some state
thomasjpfan Jul 8, 2020
e541de3
CLN Address comments
thomasjpfan Jul 9, 2020
e6116c5
Merge remote-tracking branch 'upstream/master' into multimetric_refactor
thomasjpfan Jul 9, 2020
657ef89
BUG Fix score
thomasjpfan Jul 9, 2020
b0cdc57
CLN Adds to glossary
thomasjpfan Jul 9, 2020
714372f
CLN Uses f-strings
thomasjpfan Jul 9, 2020
b09e303
Merge remote-tracking branch 'upstream/master' into multimetric_refactor
thomasjpfan Jul 12, 2020
83a4a76
CLN Address comments
thomasjpfan Jul 12, 2020
333ff25
STY Fix
thomasjpfan Jul 12, 2020
8 changes: 4 additions & 4 deletions doc/glossary.rst
@@ -1583,10 +1583,10 @@ functions or non-estimator constructors.
in the User Guide.

Where multiple metrics can be evaluated, ``scoring`` may be given
either as a list of unique strings or a dictionary with names as keys
and callables as values. Note that this does *not* specify which score
function is to be maximized, and another parameter such as ``refit``
maybe used for this purpose.
either as a list of unique strings, a dictionary with names as keys and
callables as values or a callable that returns a dictionary. Note that
this does *not* specify which score function is to be maximized, and
another parameter such as ``refit`` maybe used for this purpose.


The ``scoring`` parameter is validated and interpreted using
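The glossary wording above notes that multi-metric ``scoring`` does not itself decide which score is maximized, and that a parameter such as ``refit`` serves that purpose. As a minimal sketch (not part of this diff, using the public ``GridSearchCV`` API), a dict of named scorers combined with ``refit`` keyed to one of them looks like this:

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X, y = make_classification(n_classes=2, random_state=0)

# Two named metrics; refit="acc" selects which one drives the final refit.
search = GridSearchCV(
    LogisticRegression(max_iter=1000),
    param_grid={"C": [0.1, 1.0, 10.0]},
    scoring={"acc": "accuracy", "prec": "precision"},
    refit="acc",
    cv=5,
)
search.fit(X, y)
print(search.cv_results_["mean_test_acc"])   # one column per metric name
print(search.cv_results_["mean_test_prec"])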
24 changes: 11 additions & 13 deletions doc/modules/model_evaluation.rst
@@ -249,7 +249,7 @@ Using multiple metric evaluation
Scikit-learn also permits evaluation of multiple metrics in ``GridSearchCV``,
``RandomizedSearchCV`` and ``cross_validate``.

There are two ways to specify multiple scoring metrics for the ``scoring``
There are three ways to specify multiple scoring metrics for the ``scoring``
parameter:

- As an iterable of string metrics::
@@ -261,25 +261,23 @@ parameter:
>>> scoring = {'accuracy': make_scorer(accuracy_score),
... 'prec': 'precision'}

Note that the dict values can either be scorer functions or one of the
predefined metric strings.

Currently only those scorer functions that return a single score can be passed
inside the dict. Scorer functions that return multiple values are not
permitted and will require a wrapper to return a single metric::
- As a callable that returns a dictionary of scores::

>>> from sklearn.model_selection import cross_validate
>>> from sklearn.metrics import confusion_matrix
>>> # A sample toy binary classification dataset
>>> X, y = datasets.make_classification(n_classes=2, random_state=0)
>>> svm = LinearSVC(random_state=0)
>>> def tn(y_true, y_pred): return confusion_matrix(y_true, y_pred)[0, 0]
>>> def fp(y_true, y_pred): return confusion_matrix(y_true, y_pred)[0, 1]
>>> def fn(y_true, y_pred): return confusion_matrix(y_true, y_pred)[1, 0]
>>> def tp(y_true, y_pred): return confusion_matrix(y_true, y_pred)[1, 1]
>>> scoring = {'tp': make_scorer(tp), 'tn': make_scorer(tn),
... 'fp': make_scorer(fp), 'fn': make_scorer(fn)}
>>> cv_results = cross_validate(svm, X, y, cv=5, scoring=scoring)
>>> def confusion_matrix_scorer(clf, X, y):
... y_pred = clf.predict(X)
... cm = confusion_matrix(y, y_pred)
... return {'tn': cm[0, 0], 'fp': cm[0, 1],
... 'fn': cm[1, 0], 'tp': cm[1, 1]}
>>> cv_results = cross_validate(svm, X, y, cv=5,
... scoring=confusion_matrix_scorer)
>>> # Getting the test set true positive scores
>>> print(cv_results['test_tp'])
[10 9 8 7 8]
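As a minimal sketch (not part of the diff) of the feature this PR adds, the ``confusion_matrix_scorer`` from the documentation excerpt above can also be passed to ``GridSearchCV``; here ``refit`` is assumed to name one of the keys the callable returns:

from sklearn.datasets import make_classification
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC

def confusion_matrix_scorer(clf, X, y):
    # One callable producing several named scores per CV split.
    y_pred = clf.predict(X)
    cm = confusion_matrix(y, y_pred)
    return {"tn": cm[0, 0], "fp": cm[0, 1], "fn": cm[1, 0], "tp": cm[1, 1]}

X, y = make_classification(n_classes=2, random_state=0)
search = GridSearchCV(
    LinearSVC(random_state=0),
    param_grid={"C": [0.1, 1.0]},
    scoring=confusion_matrix_scorer,
    refit="tp",   # must match one of the keys returned by the callable
    cv=5,
)
search.fit(X, y)
print(search.cv_results_["mean_test_tp"])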
115 changes: 45 additions & 70 deletions sklearn/metrics/_scorer.py
@@ -423,99 +423,74 @@ def check_scoring(estimator, scoring=None, *, allow_none=False):
" None. %r was passed" % scoring)


def _check_multimetric_scoring(estimator, scoring=None):
def _check_multimetric_scoring(estimator, scoring):
"""Check the scoring parameter in cases when multiple metrics are allowed

Parameters
----------
estimator : sklearn estimator instance
The estimator for which the scoring will be applied.

scoring : str, callable, list, tuple or dict, default=None
scoring : list, tuple or dict
A single string (see :ref:`scoring_parameter`) or a callable
(see :ref:`scoring`) to evaluate the predictions on the test set.

For evaluating multiple metrics, either give a list of (unique) strings
or a dict with names as keys and callables as values.

NOTE that when using custom scorers, each scorer should return a single
value. Metric functions returning a list/array of values can be wrapped
into multiple scorers that return one value each.

See :ref:`multimetric_grid_search` for an example.

If None the estimator's score method is used.
The return value in that case will be ``{'score': <default_scorer>}``.
If the estimator's score method is not available, a ``TypeError``
is raised.

Returns
-------
scorers_dict : dict
A dict mapping each scorer name to its validated scorer.

is_multimetric : bool
True if scorer is a list/tuple or dict of callables
False if scorer is None/str/callable
"""
if callable(scoring) or scoring is None or isinstance(scoring,
str):
scorers = {"score": check_scoring(estimator, scoring=scoring)}
return scorers, False
else:
err_msg_generic = ("scoring should either be a single string or "
"callable for single metric evaluation or a "
"list/tuple of strings or a dict of scorer name "
"mapped to the callable for multiple metric "
"evaluation. Got %s of type %s"
% (repr(scoring), type(scoring)))

if isinstance(scoring, (list, tuple, set)):
err_msg = ("The list/tuple elements must be unique "
"strings of predefined scorers. ")
invalid = False
try:
keys = set(scoring)
except TypeError:
invalid = True
if invalid:
raise ValueError(err_msg)

if len(keys) != len(scoring):
raise ValueError(err_msg + "Duplicate elements were found in"
" the given list. %r" % repr(scoring))
elif len(keys) > 0:
if not all(isinstance(k, str) for k in keys):
if any(callable(k) for k in keys):
raise ValueError(err_msg +
"One or more of the elements were "
"callables. Use a dict of score name "
"mapped to the scorer callable. "
"Got %r" % repr(scoring))
else:
raise ValueError(err_msg +
"Non-string types were found in "
"the given list. Got %r"
% repr(scoring))
scorers = {scorer: check_scoring(estimator, scoring=scorer)
for scorer in scoring}
else:
raise ValueError(err_msg +
"Empty list was given. %r" % repr(scoring))

elif isinstance(scoring, dict):
err_msg_generic = (
f"scoring is invalid (got {scoring!r}). Refer to the "
"scoring glossary for details: "
"https://scikit-learn.org/stable/glossary.html#term-scoring")

if isinstance(scoring, (list, tuple, set)):
err_msg = ("The list/tuple elements must be unique "
"strings of predefined scorers. ")
invalid = False
try:
keys = set(scoring)
except TypeError:
invalid = True
if invalid:
raise ValueError(err_msg)

if len(keys) != len(scoring):
raise ValueError(f"{err_msg} Duplicate elements were found in"
f" the given list. {scoring!r}")
elif len(keys) > 0:
if not all(isinstance(k, str) for k in keys):
raise ValueError("Non-string types were found in the keys of "
"the given dict. scoring=%r" % repr(scoring))
if len(keys) == 0:
raise ValueError("An empty dict was passed. %r"
% repr(scoring))
scorers = {key: check_scoring(estimator, scoring=scorer)
for key, scorer in scoring.items()}
if any(callable(k) for k in keys):
raise ValueError(f"{err_msg} One or more of the elements "
"were callables. Use a dict of score "
"name mapped to the scorer callable. "
f"Got {scoring!r}")
else:
raise ValueError(f"{err_msg} Non-string types were found "
f"in the given list. Got {scoring!r}")
scorers = {scorer: check_scoring(estimator, scoring=scorer)
for scorer in scoring}
else:
raise ValueError(err_msg_generic)
return scorers, True
raise ValueError(f"{err_msg} Empty list was given. {scoring!r}")

elif isinstance(scoring, dict):
keys = set(scoring)
if not all(isinstance(k, str) for k in keys):
raise ValueError("Non-string types were found in the keys of "
f"the given dict. scoring={scoring!r}")
(Inline review comment from a reviewer: "It seems that we don't check for this case in the test.")
if len(keys) == 0:
raise ValueError(f"An empty dict was passed. {scoring!r}")
scorers = {key: check_scoring(estimator, scoring=scorer)
for key, scorer in scoring.items()}
else:
raise ValueError(err_msg_generic)
return scorers


@_deprecate_positional_args
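The refactor above changes ``_check_multimetric_scoring`` to return only the dict of validated scorers (the ``is_multimetric`` flag is gone), while single-metric inputs are expected to go through ``check_scoring`` instead. A minimal sketch of the new contract, noting that ``_check_multimetric_scoring`` is a private helper and may change between releases:

from sklearn.linear_model import LogisticRegression
from sklearn.metrics._scorer import _check_multimetric_scoring  # private API

est = LogisticRegression()

# A list of unique metric names maps to a dict of callable scorers.
scorers = _check_multimetric_scoring(est, ["accuracy", "precision"])
print(sorted(scorers))          # ['accuracy', 'precision']

# Invalid multi-metric specs raise ValueError with the messages shown above.
try:
    _check_multimetric_scoring(est, ["f1", "f1"])   # duplicate names
except ValueError as exc:
    print(exc)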
115 changes: 56 additions & 59 deletions sklearn/metrics/tests/test_score_objects.py
@@ -206,69 +206,66 @@ def check_scoring_validator_for_single_metric_usecases(scoring_validator):
assert scorer is None


def check_multimetric_scoring_single_metric_wrapper(*args, **kwargs):
# This wraps the _check_multimetric_scoring to take in
# single metric scoring parameter so we can run the tests
# that we will run for check_scoring, for check_multimetric_scoring
# too for single-metric usecases

scorers, is_multi = _check_multimetric_scoring(*args, **kwargs)
# For all single metric use cases, it should register as not multimetric
assert not is_multi
if args[0] is not None:
assert scorers is not None
names, scorers = zip(*scorers.items())
assert len(scorers) == 1
assert names[0] == 'score'
scorers = scorers[0]
return scorers


def test_check_scoring_and_check_multimetric_scoring():
@pytest.mark.parametrize(
"scoring",
(
('accuracy', ), ['precision'],
{'acc': 'accuracy', 'precision': 'precision'},
('accuracy', 'precision'),
['precision', 'accuracy'],
{'accuracy': make_scorer(accuracy_score),
'precision': make_scorer(precision_score)}
), ids=["single_tuple", "single_list", "dict_str",
"multi_tuple", "multi_list", "dict_callable"])
def test_check_scoring_and_check_multimetric_scoring(scoring):
check_scoring_validator_for_single_metric_usecases(check_scoring)
# To make sure the check_scoring is correctly applied to the constituent
# scorers
check_scoring_validator_for_single_metric_usecases(
check_multimetric_scoring_single_metric_wrapper)

# For multiple metric use cases
# Make sure it works for the valid cases
for scoring in (('accuracy',), ['precision'],
{'acc': 'accuracy', 'precision': 'precision'},
('accuracy', 'precision'), ['precision', 'accuracy'],
{'accuracy': make_scorer(accuracy_score),
'precision': make_scorer(precision_score)}):
estimator = LinearSVC(random_state=0)
estimator.fit([[1], [2], [3]], [1, 1, 0])

scorers, is_multi = _check_multimetric_scoring(estimator, scoring)
assert is_multi
assert isinstance(scorers, dict)
assert sorted(scorers.keys()) == sorted(list(scoring))
assert all([isinstance(scorer, _PredictScorer)
for scorer in list(scorers.values())])

if 'acc' in scoring:
assert_almost_equal(scorers['acc'](
estimator, [[1], [2], [3]], [1, 0, 0]), 2. / 3.)
if 'accuracy' in scoring:
assert_almost_equal(scorers['accuracy'](
estimator, [[1], [2], [3]], [1, 0, 0]), 2. / 3.)
if 'precision' in scoring:
assert_almost_equal(scorers['precision'](
estimator, [[1], [2], [3]], [1, 0, 0]), 0.5)

estimator = LinearSVC(random_state=0)
estimator.fit([[1], [2], [3]], [1, 1, 0])

scorers = _check_multimetric_scoring(estimator, scoring)
assert isinstance(scorers, dict)
assert sorted(scorers.keys()) == sorted(list(scoring))
assert all([isinstance(scorer, _PredictScorer)
for scorer in list(scorers.values())])

if 'acc' in scoring:
assert_almost_equal(scorers['acc'](
estimator, [[1], [2], [3]], [1, 0, 0]), 2. / 3.)
if 'accuracy' in scoring:
assert_almost_equal(scorers['accuracy'](
estimator, [[1], [2], [3]], [1, 0, 0]), 2. / 3.)
if 'precision' in scoring:
assert_almost_equal(scorers['precision'](
estimator, [[1], [2], [3]], [1, 0, 0]), 0.5)


@pytest.mark.parametrize("scoring", [
((make_scorer(precision_score), make_scorer(accuracy_score)),
"One or more of the elements were callables"),
([5], "Non-string types were found"),
((make_scorer(precision_score), ),
"One of mor eof the elements were callables"),
((), "Empty list was given"),
(('f1', 'f1'), "Duplicate elements were found"),
({4: 'accuracy'}, "Non-string types were found in the keys"),
({}, "An empty dict was passed"),
], ids=[
"tuple of callables", "list of int",
"tuple of one callable", "empty tuple",
"non-unique str", "non-string key dict",
"empty dict"])
def test_check_scoring_and_check_multimetric_scoring_errors(scoring):
# Make sure it raises errors when scoring parameter is not valid.
# More weird corner cases are tested at test_validation.py
estimator = EstimatorWithFitAndPredict()
estimator.fit([[1]], [1])

# Make sure it raises errors when scoring parameter is not valid.
# More weird corner cases are tested at test_validation.py
error_message_regexp = ".*must be unique strings.*"
for scoring in ((make_scorer(precision_score), # Tuple of callables
make_scorer(accuracy_score)), [5],
(make_scorer(precision_score),), (), ('f1', 'f1')):
with pytest.raises(ValueError, match=error_message_regexp):
_check_multimetric_scoring(estimator, scoring=scoring)
with pytest.raises(ValueError, match=error_message_regexp):
_check_multimetric_scoring(estimator, scoring=scoring)


def test_check_scoring_gridsearchcv():
Expand Down Expand Up @@ -622,7 +619,7 @@ def test_multimetric_scorer_calls_method_once(scorers, expected_predict_count,
mock_est.predict_proba = predict_proba_func
mock_est.decision_function = decision_function_func

scorer_dict, _ = _check_multimetric_scoring(LogisticRegression(), scorers)
scorer_dict = _check_multimetric_scoring(LogisticRegression(), scorers)
multi_scorer = _MultimetricScorer(**scorer_dict)
results = multi_scorer(mock_est, X, y)

@@ -649,7 +646,7 @@ def predict_proba(self, X):
clf.fit(X, y)

scorers = ['roc_auc', 'neg_log_loss']
scorer_dict, _ = _check_multimetric_scoring(clf, scorers)
scorer_dict = _check_multimetric_scoring(clf, scorers)
scorer = _MultimetricScorer(**scorer_dict)
scorer(clf, X, y)

@@ -672,7 +669,7 @@ def predict(self, X):
clf.fit(X, y)

scorers = {'neg_mse': 'neg_mean_squared_error', 'r2': 'roc_auc'}
scorer_dict, _ = _check_multimetric_scoring(clf, scorers)
scorer_dict = _check_multimetric_scoring(clf, scorers)
scorer = _MultimetricScorer(**scorer_dict)
scorer(clf, X, y)

@@ -690,7 +687,7 @@ def test_multimetric_scorer_sanity_check():
clf = DecisionTreeClassifier()
clf.fit(X, y)

scorer_dict, _ = _check_multimetric_scoring(clf, scorers)
scorer_dict = _check_multimetric_scoring(clf, scorers)
multi_scorer = _MultimetricScorer(**scorer_dict)

result = multi_scorer(clf, X, y)
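As a minimal sketch (not part of the diff) of the pattern the updated tests use: the helper's output is wired straight into ``_MultimetricScorer``, which evaluates all metrics in one call. Both names are private scikit-learn APIs, and the ``**scorer_dict`` constructor form shown here is the one used in the tests above; later releases may take different arguments.

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics._scorer import _MultimetricScorer, _check_multimetric_scoring

X, y = make_classification(n_classes=2, random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X, y)

# Validate the multi-metric spec, then score every metric in a single call.
scorer_dict = _check_multimetric_scoring(clf, ["accuracy", "roc_auc"])
multi_scorer = _MultimetricScorer(**scorer_dict)
results = multi_scorer(clf, X, y)
print(results)   # e.g. {'accuracy': 0.9..., 'roc_auc': 0.9...}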