
[WIP] ENH create a generator of applicable metrics depending on the target y #17889

Open: wants to merge 44 commits into base: main
Commits (44):
9d090da  WIP (thomasjpfan, Oct 3, 2019)
315c335  ENH Increase compability (thomasjpfan, Oct 3, 2019)
702cf1b  ENH Refactories _fit_and_score (thomasjpfan, Oct 3, 2019)
a7d2efb  RFC Moves support into a function (thomasjpfan, Oct 4, 2019)
c77afd7  BUG Fix old numpy bug (thomasjpfan, Oct 4, 2019)
5ab8693  TST Removes tests for error on multimetric (thomasjpfan, Oct 4, 2019)
9c53783  Merge remote-tracking branch 'upstream/master' into multimetric_refactor (thomasjpfan, Dec 3, 2019)
8676c04  Merge remote-tracking branch 'upstream/master' into multimetric_refactor (thomasjpfan, Dec 4, 2019)
e8f8c9f  DOC Indent (thomasjpfan, Dec 4, 2019)
5f50a32  CLN Refactors multimetric check (thomasjpfan, Dec 4, 2019)
ad829e1  Merge remote-tracking branch 'upstream/master' into multimetric_refactor (thomasjpfan, Jan 7, 2020)
c27f592  Merge remote-tracking branch 'upstream/master' into multimetric_refactor (thomasjpfan, Feb 4, 2020)
57c390a  CLN Address comments (thomasjpfan, Feb 4, 2020)
524fd87  Merge remote-tracking branch 'upstream/master' into multimetric_refactor (thomasjpfan, Feb 5, 2020)
1b28907  CLN Simplifies checking (thomasjpfan, Feb 5, 2020)
2cf9ba8  CLN Simplifies aggregation (thomasjpfan, Feb 5, 2020)
f336d64  CLN Less code the better (thomasjpfan, Feb 5, 2020)
a86eaf0  CLN Moves definition closer to usage (thomasjpfan, Feb 5, 2020)
b1782ae  CLN Update error handling (thomasjpfan, Feb 6, 2020)
d4782b2  Merge remote-tracking branch 'upstream/master' into multimetric_refactor (thomasjpfan, Feb 6, 2020)
c14463c  Merge remote-tracking branch 'upstream/master' into multimetric_refactor (thomasjpfan, May 24, 2020)
762d644  Merge remote-tracking branch 'upstream/master' into multimetric_refactor (thomasjpfan, May 24, 2020)
c5f9b42  REV Less diffs (thomasjpfan, May 24, 2020)
0e79b59  CLN Address comments (thomasjpfan, May 24, 2020)
49e8c03  REV (thomasjpfan, May 24, 2020)
4f6ecd7  STY Flake (thomasjpfan, May 24, 2020)
4fa5eb6  ENH Fix error (thomasjpfan, May 24, 2020)
97b1db2  REV Less diffs (thomasjpfan, May 25, 2020)
286bb86  DOC Adds comments (thomasjpfan, May 25, 2020)
9799297  Merge remote-tracking branch 'upstream/master' into multimetric_refactor (thomasjpfan, May 25, 2020)
5e1b72b  Merge remote-tracking branch 'upstream/master' into multimetric_refactor (thomasjpfan, Jun 6, 2020)
fe7dae3  Merge remote-tracking branch 'upstream/master' into multimetric_refactor (thomasjpfan, Jun 8, 2020)
7c85d54  Merge remote-tracking branch 'upstream/master' into multimetric_refactor (thomasjpfan, Jul 8, 2020)
5da1571  CLN Removes some state (thomasjpfan, Jul 8, 2020)
e541de3  CLN Address comments (thomasjpfan, Jul 9, 2020)
e6116c5  Merge remote-tracking branch 'upstream/master' into multimetric_refactor (thomasjpfan, Jul 9, 2020)
657ef89  BUG Fix score (thomasjpfan, Jul 9, 2020)
b0cdc57  CLN Adds to glossary (thomasjpfan, Jul 9, 2020)
714372f  CLN Uses f-strings (thomasjpfan, Jul 9, 2020)
346f8e3  ENH create a generator of applicable metrics depending on the target y (glemaitre, Jul 10, 2020)
6e8be2a  iter (glemaitre, Jul 10, 2020)
5a3bab1  iter (glemaitre, Jul 10, 2020)
8732aa4  iter (glemaitre, Jul 10, 2020)
43668af  PEP8 (glemaitre, Jul 10, 2020)
8 changes: 4 additions & 4 deletions doc/glossary.rst
@@ -1583,10 +1583,10 @@ functions or non-estimator constructors.
in the User Guide.

Where multiple metrics can be evaluated, ``scoring`` may be given
either as a list of unique strings or a dictionary with names as keys
and callables as values. Note that this does *not* specify which score
function is to be maximized, and another parameter such as ``refit``
maybe used for this purpose.
either as a list of unique strings, a dictionary with names as keys and
callables as values or a callable that returns a dictionary. Note that
this does *not* specify which score function is to be maximized, and
another parameter such as ``refit`` maybe used for this purpose.


The ``scoring`` parameter is validated and interpreted using
24 changes: 11 additions & 13 deletions doc/modules/model_evaluation.rst
@@ -249,7 +249,7 @@ Using multiple metric evaluation
Scikit-learn also permits evaluation of multiple metrics in ``GridSearchCV``,
``RandomizedSearchCV`` and ``cross_validate``.

There are two ways to specify multiple scoring metrics for the ``scoring``
There are three ways to specify multiple scoring metrics for the ``scoring``
parameter:

- As an iterable of string metrics::
@@ -261,25 +261,23 @@ parameter:
>>> scoring = {'accuracy': make_scorer(accuracy_score),
... 'prec': 'precision'}

Note that the dict values can either be scorer functions or one of the
predefined metric strings.
Note that the dict values can either be scorer functions or one of the
predefined metric strings.

Currently only those scorer functions that return a single score can be passed
inside the dict. Scorer functions that return multiple values are not
permitted and will require a wrapper to return a single metric::
- As a callable that returns a dictionary of scores::

>>> from sklearn.model_selection import cross_validate
>>> from sklearn.metrics import confusion_matrix
>>> # A sample toy binary classification dataset
>>> X, y = datasets.make_classification(n_classes=2, random_state=0)
>>> svm = LinearSVC(random_state=0)
>>> def tn(y_true, y_pred): return confusion_matrix(y_true, y_pred)[0, 0]
>>> def fp(y_true, y_pred): return confusion_matrix(y_true, y_pred)[0, 1]
>>> def fn(y_true, y_pred): return confusion_matrix(y_true, y_pred)[1, 0]
>>> def tp(y_true, y_pred): return confusion_matrix(y_true, y_pred)[1, 1]
>>> scoring = {'tp': make_scorer(tp), 'tn': make_scorer(tn),
... 'fp': make_scorer(fp), 'fn': make_scorer(fn)}
>>> cv_results = cross_validate(svm, X, y, cv=5, scoring=scoring)
>>> def confusion_matrix_scorer(clf, X, y):
... y_pred = clf.predict(X)
... cm = confusion_matrix(y, y_pred)
... return {'tn': cm[0, 0], 'fp': cm[0, 1],
... 'fn': cm[1, 0], 'tp': cm[1, 1]}
>>> cv_results = cross_validate(svm, X, y, cv=5,
... scoring=confusion_matrix_scorer)
>>> # Getting the test set true positive scores
>>> print(cv_results['test_tp'])
[10 9 8 7 8]
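The hunk above lists three ways to pass ``scoring``; the example for the first form (an iterable of metric names) sits outside the displayed range. A minimal, self-contained sketch of that form, using two names from the built-in scorer registry (chosen here only for illustration):

>>> from sklearn.datasets import make_classification
>>> from sklearn.model_selection import cross_validate
>>> from sklearn.svm import LinearSVC
>>> X, y = make_classification(n_classes=2, random_state=0)
>>> svm = LinearSVC(random_state=0)
>>> cv_results = cross_validate(svm, X, y, cv=5,
...                             scoring=['accuracy', 'precision'])
>>> sorted(cv_results.keys())  # each metric gets its own test_<name> entry
['fit_time', 'score_time', 'test_accuracy', 'test_precision']
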
2 changes: 2 additions & 0 deletions sklearn/metrics/__init__.py
@@ -74,6 +74,7 @@
from ._scorer import check_scoring
from ._scorer import make_scorer
from ._scorer import SCORERS
from ._scorer import get_applicable_scorers
from ._scorer import get_scorer

from ._plot.roc_curve import plot_roc_curve
@@ -109,6 +110,7 @@
'f1_score',
'fbeta_score',
'fowlkes_mallows_score',
'get_applicable_scorers',
'get_scorer',
'hamming_loss',
'hinge_loss',
270 changes: 199 additions & 71 deletions sklearn/metrics/_scorer.py
@@ -18,9 +18,12 @@
# Arnaud Joly <arnaud.v.joly@gmail.com>
# License: Simplified BSD

from collections import Counter
from collections import namedtuple
from collections.abc import Iterable
from copy import deepcopy
from inspect import signature
from functools import partial
from collections import Counter

import numpy as np

@@ -423,99 +426,74 @@ def check_scoring(estimator, scoring=None, *, allow_none=False):
" None. %r was passed" % scoring)


def _check_multimetric_scoring(estimator, scoring=None):
def _check_multimetric_scoring(estimator, scoring):
"""Check the scoring parameter in cases when multiple metrics are allowed

Parameters
----------
estimator : sklearn estimator instance
The estimator for which the scoring will be applied.

scoring : str, callable, list, tuple or dict, default=None
scoring : list, tuple or dict
A single string (see :ref:`scoring_parameter`) or a callable
(see :ref:`scoring`) to evaluate the predictions on the test set.

For evaluating multiple metrics, either give a list of (unique) strings
or a dict with names as keys and callables as values.

NOTE that when using custom scorers, each scorer should return a single
value. Metric functions returning a list/array of values can be wrapped
into multiple scorers that return one value each.

See :ref:`multimetric_grid_search` for an example.

If None the estimator's score method is used.
The return value in that case will be ``{'score': <default_scorer>}``.
If the estimator's score method is not available, a ``TypeError``
is raised.

Returns
-------
scorers_dict : dict
A dict mapping each scorer name to its validated scorer.

is_multimetric : bool
True if scorer is a list/tuple or dict of callables
False if scorer is None/str/callable
"""
if callable(scoring) or scoring is None or isinstance(scoring,
str):
scorers = {"score": check_scoring(estimator, scoring=scoring)}
return scorers, False
else:
err_msg_generic = ("scoring should either be a single string or "
"callable for single metric evaluation or a "
"list/tuple of strings or a dict of scorer name "
"mapped to the callable for multiple metric "
"evaluation. Got %s of type %s"
% (repr(scoring), type(scoring)))

if isinstance(scoring, (list, tuple, set)):
err_msg = ("The list/tuple elements must be unique "
"strings of predefined scorers. ")
invalid = False
try:
keys = set(scoring)
except TypeError:
invalid = True
if invalid:
raise ValueError(err_msg)

if len(keys) != len(scoring):
raise ValueError(err_msg + "Duplicate elements were found in"
" the given list. %r" % repr(scoring))
elif len(keys) > 0:
if not all(isinstance(k, str) for k in keys):
if any(callable(k) for k in keys):
raise ValueError(err_msg +
"One or more of the elements were "
"callables. Use a dict of score name "
"mapped to the scorer callable. "
"Got %r" % repr(scoring))
else:
raise ValueError(err_msg +
"Non-string types were found in "
"the given list. Got %r"
% repr(scoring))
scorers = {scorer: check_scoring(estimator, scoring=scorer)
for scorer in scoring}
else:
raise ValueError(err_msg +
"Empty list was given. %r" % repr(scoring))

elif isinstance(scoring, dict):
err_msg_generic = (
f"scoring is invalid (got {scoring!r}). Refer to the "
"scoring glossary for details: "
"https://scikit-learn.org/stable/glossary.html#term-scoring")

if isinstance(scoring, (list, tuple, set)):
err_msg = ("The list/tuple elements must be unique "
"strings of predefined scorers. ")
invalid = False
try:
keys = set(scoring)
except TypeError:
invalid = True
if invalid:
raise ValueError(err_msg)

if len(keys) != len(scoring):
raise ValueError(f"{err_msg} Duplicate elements were found in"
f" the given list. {scoring!r}")
elif len(keys) > 0:
if not all(isinstance(k, str) for k in keys):
raise ValueError("Non-string types were found in the keys of "
"the given dict. scoring=%r" % repr(scoring))
if len(keys) == 0:
raise ValueError("An empty dict was passed. %r"
% repr(scoring))
scorers = {key: check_scoring(estimator, scoring=scorer)
for key, scorer in scoring.items()}
if any(callable(k) for k in keys):
raise ValueError(f"{err_msg} One or more of the elements "
"were callables. Use a dict of score "
"name mapped to the scorer callable. "
f"Got {scoring!r}")
else:
raise ValueError(f"{err_msg} Non-string types were found "
f"in the given list. Got {scoring!r}")
scorers = {scorer: check_scoring(estimator, scoring=scorer)
for scorer in scoring}
else:
raise ValueError(err_msg_generic)
return scorers, True
raise ValueError(f"{err_msg} Empty list was given. {scoring!r}")

elif isinstance(scoring, dict):
keys = set(scoring)
if not all(isinstance(k, str) for k in keys):
raise ValueError("Non-string types were found in the keys of "
f"the given dict. scoring={scoring!r}")
if len(keys) == 0:
raise ValueError(f"An empty dict was passed. {scoring!r}")
scorers = {key: check_scoring(estimator, scoring=scorer)
for key, scorer in scoring.items()}
else:
raise ValueError(err_msg_generic)
return scorers


@_deprecate_positional_args
@@ -711,3 +689,153 @@ def make_scorer(score_func, *, greater_is_better=True, needs_proba=False,
qualified_name = '{0}_{1}'.format(name, average)
SCORERS[qualified_name] = make_scorer(metric, pos_label=None,
average=average)

ScorerProperty = namedtuple(
"ScorerProperty", ["scorer", "target_type_supported"],
)

SCORERS_PROPERTY = dict(
explained_variance=ScorerProperty(
scorer=explained_variance_scorer,
target_type_supported=("continuous", "continuous-multioutput"),
),
r2=ScorerProperty(
scorer=r2_scorer,
target_type_supported=("continuous", "continuous-multioutput"),
),
max_error=ScorerProperty(
scorer=max_error_scorer,
target_type_supported=("continuous",),
),
neg_median_absolute_error=ScorerProperty(
scorer=neg_median_absolute_error_scorer,
target_type_supported=("continuous", "continuous-multioutput"),
),
neg_mean_absolute_error=ScorerProperty(
scorer=neg_mean_absolute_error_scorer,
target_type_supported=("continuous", "continuous-multioutput"),
),
neg_mean_absolute_percentage_error=ScorerProperty(
scorer=neg_mean_absolute_percentage_error_scorer,
target_type_supported=("continuous", "continuous-multioutput"),
),
neg_mean_squared_error=ScorerProperty(
scorer=neg_mean_squared_error_scorer,
target_type_supported=("continuous", "continuous-multioutput"),
),
neg_mean_squared_log_error=ScorerProperty(
scorer=neg_mean_squared_log_error_scorer,
target_type_supported=("continuous", "continuous-multioutput"),
),
neg_root_mean_squared_error=ScorerProperty(
scorer=neg_root_mean_squared_error_scorer,
target_type_supported=("continuous", "continuous-multioutput"),
),
neg_mean_poisson_deviance=ScorerProperty(
scorer=neg_mean_poisson_deviance_scorer,
target_type_supported=("continuous",),
),
neg_mean_gamma_deviance=ScorerProperty(
scorer=neg_mean_gamma_deviance_scorer,
target_type_supported=("continuous",),
),
accuracy=ScorerProperty(
scorer=accuracy_scorer,
target_type_supported=("binary", "multiclass", "multilabel-indicator"),
),
roc_auc=ScorerProperty(
scorer=roc_auc_scorer,
target_type_supported=("binary", "multilabel-indicator"),
),
roc_auc_ovr=ScorerProperty(
scorer=roc_auc_ovr_scorer,
target_type_supported=("multiclass"),
),
roc_auc_ovo=ScorerProperty(
scorer=roc_auc_ovo_scorer,
target_type_supported=("multiclass"),
),
roc_auc_ovr_weighted=ScorerProperty(
scorer=roc_auc_ovr_weighted_scorer,
target_type_supported=("multiclass"),
),
roc_auc_ovo_weighted=ScorerProperty(
scorer=roc_auc_ovo_weighted_scorer,
target_type_supported=("multiclass"),
),
balanced_accuracy=ScorerProperty(
scorer=balanced_accuracy_scorer,
target_type_supported=("binary", "multiclass"),
),
jaccard=ScorerProperty(
scorer=make_scorer(jaccard_score),
target_type_supported=("binary", "multilabel-indicator"),
),
average_precision=ScorerProperty(
scorer=average_precision_scorer,
target_type_supported=("binary", "multilabel-indicator"),
),
neg_log_loss=ScorerProperty(
scorer=neg_log_loss_scorer,
target_type_supported=("binary", "multiclass"),
),
neg_brier_score=ScorerProperty(
scorer=neg_brier_score_scorer,
target_type_supported=("binary"),
),
)

for name, metric in [('precision', precision_score),
('recall', recall_score), ('f1', f1_score),
('jaccard', jaccard_score)]:
SCORERS_PROPERTY[name] = ScorerProperty(
scorer=make_scorer(metric, average='binary'),
target_type_supported=("binary",),
)
for average in ['macro', 'micro', 'samples', 'weighted']:
qualified_name = f'{name}_{average}'
SCORERS_PROPERTY[qualified_name] = ScorerProperty(
scorer=make_scorer(metric, pos_label=None, average=average),
target_type_supported=("multilabel-indicator"),
)


def get_applicable_scorers(y, **scorers_params):
"""Utility providing scorers to be used on `y`.

This utility creates a dictionary containing the scorers which can be used
on `y`. The dictionary returned can be used directly in a
:class:`~sklearn.model_selection.GridSearchCV`.

Additional parameters taken by the different metrics can be passed as
keyword argument.

Parameters
----------
y : array-like
The target used to infer the metrics which can be used.

**scorers_params
Additional parameters to be passed to the scorers when present in their
signature.

Returns
-------
scorers : dict
A dictionary containing the scorer name as key and a scorer callable as
value.
"""
target_type = type_of_target(y)

scorers = {}
for scorer_name, scorer_property in SCORERS_PROPERTY.items():
if target_type in scorer_property.target_type_supported:
scorers[scorer_name] = deepcopy(scorer_property.scorer)
scorer_sig = signature(scorers[scorer_name]._score_func)
for param_name, param_value in scorers_params.items():
if param_name in scorer_sig.parameters:
scorers[scorer_name]._kwargs[param_name] = param_value

if not scorers:
raise ValueError("No compatible scorer with the target 'y' was found.")
return scorers
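
A quick sketch of the refactored private helper, based only on the diff above: ``_check_multimetric_scoring`` now returns a plain dict of validated scorers instead of the former ``(scorers, is_multimetric)`` tuple.

>>> from sklearn.svm import LinearSVC
>>> from sklearn.metrics._scorer import _check_multimetric_scoring
>>> # private helper; the dict-only return value reflects this branch's refactor
>>> scorers = _check_multimetric_scoring(LinearSVC(), ['accuracy', 'recall'])
>>> sorted(scorers)
['accuracy', 'recall']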
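
And a usage sketch for the new ``get_applicable_scorers`` helper, assuming only what this WIP branch adds (names and behaviour may still change): the returned dict can be passed directly as ``scoring`` to ``GridSearchCV``, and keyword arguments such as ``pos_label`` are forwarded to the scorers whose score function accepts them.

>>> from sklearn.datasets import make_classification
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.metrics import get_applicable_scorers  # exported by this branch
>>> from sklearn.model_selection import GridSearchCV
>>> X, y = make_classification(n_classes=2, random_state=0)
>>> scorers = get_applicable_scorers(y, pos_label=1)
>>> 'accuracy' in scorers, 'r2' in scorers  # regression metrics are filtered out
(True, False)
>>> grid = GridSearchCV(LogisticRegression(), {'C': [0.1, 1.0]},
...                     scoring=scorers, refit='accuracy').fit(X, y)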