ENH annotate metrics to simplify populating SCORERS #1774

Closed · wants to merge 2 commits
45 changes: 43 additions & 2 deletions sklearn/metrics/metrics.py
@@ -30,6 +30,26 @@
from ..utils.multiclass import unique_labels


###############################################################################
# Annotations
###############################################################################

# TODO: is there a better name for this, or its opposite, "categorical"?
def needs_threshold(metric):
metric.needs_threshold = True
return metric


def greater_is_better(metric):
metric.greater_is_better = True
return metric


def lesser_is_better(metric):
metric.greater_is_better = False
return metric


###############################################################################
# General utilities
###############################################################################
@@ -96,6 +116,7 @@ def auc(x, y, reorder=False):
###############################################################################
# Binary classification loss
###############################################################################
@lesser_is_better
def hinge_loss(y_true, pred_decision, pos_label=1, neg_label=-1):
"""Average hinge loss (non-regularized)

@@ -159,6 +180,8 @@ def hinge_loss(y_true, pred_decision, pos_label=1, neg_label=-1):
###############################################################################
# Binary classification scores
###############################################################################
@needs_threshold
@greater_is_better
def average_precision_score(y_true, y_score):
"""Compute average precision (AP) from prediction scores

@@ -205,6 +228,8 @@ def average_precision_score(y_true, y_score):
return auc(recall, precision)


@needs_threshold
@greater_is_better
def auc_score(y_true, y_score):
"""Compute Area Under the Curve (AUC) from prediction scores

@@ -251,6 +276,7 @@ def auc_score(y_true, y_score):
return auc(fpr, tpr, reorder=True)


@greater_is_better
def matthews_corrcoef(y_true, y_pred):
"""Compute the Matthews correlation coefficient (MCC) for binary classes

@@ -306,6 +332,7 @@ def matthews_corrcoef(y_true, y_pred):
return mcc


@needs_threshold
def precision_recall_curve(y_true, probas_pred):
"""Compute precision-recall pairs for different probability thresholds

@@ -418,6 +445,7 @@ def precision_recall_curve(y_true, probas_pred):
return precision, recall, thresholds


@needs_threshold
def roc_curve(y_true, y_score, pos_label=None):
"""Compute Receiver operating characteristic (ROC)

@@ -640,6 +668,7 @@ def confusion_matrix(y_true, y_pred, labels=None):
###############################################################################
# Multiclass loss function
###############################################################################
@lesser_is_better
def zero_one_loss(y_true, y_pred, normalize=True):
"""Zero-one classification loss.

@@ -697,7 +726,7 @@ def zero_one_loss(y_true, y_pred, normalize=True):
"""
y_true, y_pred = check_arrays(y_true, y_pred, allow_lists=True)

if is_multilabel(y_true):
if is_multilabel(y_true):
# Handle mix representation
if type(y_true) != type(y_pred):
labels = unique_labels(y_true, y_pred)
@@ -712,7 +741,7 @@ def zero_one_loss(y_true, y_pred, normalize=True):
# numpy 1.3 : it is required to perform a unique before setxor1d
# to get unique label in numpy 1.3.
# This is needed in order to handle redundant labels.
# FIXME : check if this can be simplified when 1.3 is removed
# FIXME : check if this can be simplified when 1.3 is removed
loss = np.array([np.size(np.setxor1d(np.unique(pred),
np.unique(true))) > 0
for pred, true in zip(y_pred, y_true)])
@@ -730,6 +759,7 @@ def zero_one_loss(y_true, y_pred, normalize=True):
"'zero_one_loss' and will be removed in release 0.15."
"Default behavior is changed from 'normalize=False' to "
"'normalize=True'")
@lesser_is_better
def zero_one(y_true, y_pred, normalize=False):
"""Zero-One classification loss

@@ -771,6 +801,7 @@ def zero_one(y_true, y_pred, normalize=False):
###############################################################################
# Multiclass score functions
###############################################################################
@greater_is_better
def accuracy_score(y_true, y_pred):
"""Accuracy classification score.

@@ -846,6 +877,7 @@ def accuracy_score(y_true, y_pred):
return np.mean(score)


@greater_is_better
def f1_score(y_true, y_pred, labels=None, pos_label=1, average='weighted'):
"""Compute the F1 score, also known as balanced F-score or F-measure

@@ -931,6 +963,7 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='weighted'):
pos_label=pos_label, average=average)


@greater_is_better
def fbeta_score(y_true, y_pred, beta, labels=None, pos_label=1,
average='weighted'):
"""Compute the F-beta score
@@ -1228,6 +1261,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
return avg_precision, avg_recall, avg_fscore, None


@greater_is_better
def precision_score(y_true, y_pred, labels=None, pos_label=1,
average='weighted'):
"""Compute the precision
@@ -1311,6 +1345,7 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1,
return p


@greater_is_better
def recall_score(y_true, y_pred, labels=None, pos_label=1, average='weighted'):
"""Compute the recall

@@ -1393,6 +1428,7 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='weighted'):

@deprecated("Function zero_one_score has been renamed to "
"'accuracy_score' and will be removed in release 0.15.")
@greater_is_better
def zero_one_score(y_true, y_pred):
"""Zero-one classification score (accuracy)

@@ -1508,6 +1544,7 @@ class 2 1.00 1.00 1.00 2
###############################################################################
# Multilabel loss function
###############################################################################
@lesser_is_better
def hamming_loss(y_true, y_pred, classes=None):
"""Compute the average Hamming loss.

@@ -1611,6 +1648,7 @@ def hamming_loss(y_true, y_pred, classes=None):
###############################################################################
# Regression loss functions
###############################################################################
@lesser_is_better
def mean_absolute_error(y_true, y_pred):
"""Mean absolute error regression loss

@@ -1644,6 +1682,7 @@ def mean_absolute_error(y_true, y_pred):
return np.mean(np.abs(y_pred - y_true))


@lesser_is_better
def mean_squared_error(y_true, y_pred):
"""Mean squared error regression loss

@@ -1680,6 +1719,7 @@ def mean_squared_error(y_true, y_pred):
###############################################################################
# Regression score functions
###############################################################################
@greater_is_better
def explained_variance_score(y_true, y_pred):
"""Explained variance regression score function

@@ -1724,6 +1764,7 @@ def explained_variance_score(y_true, y_pred):
return 1 - numerator / denominator


@greater_is_better
def r2_score(y_true, y_pred):
"""R² (coefficient of determination) regression score function.

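To make the annotation decorators introduced at the top of metrics.py concrete, here is a minimal, self-contained sketch; the toy_error metric is purely illustrative and not part of the patch:

def lesser_is_better(metric):
    metric.greater_is_better = False
    return metric

@lesser_is_better
def toy_error(y_true, y_pred):
    # mean absolute difference, used only to demonstrate the annotation
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / float(len(y_true))

print(toy_error.greater_is_better)  # False, the attribute attached by the decorator
print(toy_error([0, 1], [0, 0]))    # 0.5, calling the metric is unchanged
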
71 changes: 42 additions & 29 deletions sklearn/metrics/scorer.py
@@ -39,11 +39,11 @@ class Scorer(object):
Score function (or loss function) with signature
``score_func(y, y_pred, **kwargs)``.

greater_is_better : boolean, default=True
greater_is_better : boolean, default=score_func.greater_is_better or True
Whether score_func is a score function (default), meaning high is good,
or a loss function, meaning low is good.

needs_threshold : bool, default=False
needs_threshold : bool, default=score_func.needs_threshold or False
Whether score_func takes a continuous decision certainty.
For example ``average_precision`` or the area under the roc curve
can not be computed using predictions alone, but need the output of
@@ -61,17 +61,35 @@ class Scorer(object):
>>> grid = GridSearchCV(LinearSVC(), param_grid={'C': [1, 10]},
... scoring=ftwo_scorer)
"""
def __init__(self, score_func, greater_is_better=True,
needs_threshold=False, **kwargs):
def __init__(self, score_func, greater_is_better=None,
needs_threshold=None, **kwargs):
self.score_func = score_func
self.kwargs = kwargs

if greater_is_better is None:
greater_is_better = self._get_annotation(
score_func, 'greater_is_better', True)
self.greater_is_better = greater_is_better

if needs_threshold is None:
needs_threshold = self._get_annotation(
score_func, 'needs_threshold', False)
self.needs_threshold = needs_threshold
self.kwargs = kwargs

@staticmethod
def _get_annotation(func, attr, default):
if hasattr(func, attr):
return getattr(func, attr)
while hasattr(func, 'func'):
func = getattr(func, 'func') # unwrap functools.partial
if hasattr(func, attr):
return getattr(func, attr)
return default

def __repr__(self):
kwargs_string = "".join([", %s=%s" % (str(k), str(v))
for k, v in self.kwargs.items()])
return ("Scorer(score_func=%s, greater_is_better=%s, needs_thresholds="
return ("Scorer(score_func=%s, greater_is_better=%s, needs_threshold="
"%s%s)" % (self.score_func.__name__, self.greater_is_better,
self.needs_threshold, kwargs_string))
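
The _get_annotation helper above looks for the attribute on the callable itself and, failing that, follows the func attribute that functools.partial objects expose, so a partially applied metric keeps its annotations. A standalone sketch of the same lookup, using an illustrative toy_loss metric that is not part of the patch:

import functools

def lesser_is_better(metric):
    metric.greater_is_better = False
    return metric

@lesser_is_better
def toy_loss(y_true, y_pred, power=1):
    return sum(abs(t - p) ** power for t, p in zip(y_true, y_pred))

def get_annotation(func, attr, default):
    # direct lookup first, then unwrap functools.partial wrappers
    if hasattr(func, attr):
        return getattr(func, attr)
    while hasattr(func, 'func'):
        func = func.func
        if hasattr(func, attr):
            return getattr(func, attr)
    return default

squared = functools.partial(toy_loss, power=2)
print(get_annotation(squared, 'greater_is_better', True))  # False, found on the wrapped metric
print(get_annotation(squared, 'needs_threshold', False))   # False, falls back to the default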

@@ -111,26 +129,21 @@ def __call__(self, estimator, X, y):
return self.score_func(y, y_pred, **self.kwargs)


# Standard regression scores
r2_scorer = Scorer(r2_score)
mse_scorer = Scorer(mean_squared_error, greater_is_better=False)

# Standard Classification Scores
accuracy_scorer = Scorer(accuracy_score)
f1_scorer = Scorer(f1_score)

# Score functions that need decision values
auc_scorer = Scorer(auc_score, greater_is_better=True, needs_threshold=True)
average_precision_scorer = Scorer(average_precision_score,
needs_threshold=True)
precision_scorer = Scorer(precision_score)
recall_scorer = Scorer(recall_score)

# Clustering scores
ari_scorer = Scorer(adjusted_rand_score)

SCORERS = dict(r2=r2_scorer, mse=mse_scorer, accuracy=accuracy_scorer,
f1=f1_scorer, roc_auc=auc_scorer,
average_precision=average_precision_scorer,
precision=precision_scorer, recall=recall_scorer,
ari=ari_scorer)
SCORERS = {
name: Scorer(metric)
for name, metric in [
# Regression
('r2', r2_score),
('mse', mean_squared_error),
# Classification
('accuracy', accuracy_score),
('f1', f1_score),
('precision', precision_score),
('recall', recall_score),
# Classification thresholded
('roc_auc', auc_score),
('average_precision', average_precision_score),
# Clustering
('ari', adjusted_rand_score),
]
}
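
Because each metric now carries its own annotations, SCORERS can be built uniformly from bare metric functions and the resulting scorers pick up the right defaults. A usage sketch against the GridSearchCV API of this scikit-learn version; the dataset and estimator here are illustrative:

from sklearn.datasets import make_classification
from sklearn.grid_search import GridSearchCV
from sklearn.metrics import SCORERS
from sklearn.svm import LinearSVC

# The annotations propagate into the pre-built scorers.
print(SCORERS['mse'].greater_is_better)    # False, via the lesser_is_better annotation
print(SCORERS['roc_auc'].needs_threshold)  # True, via the needs_threshold annotation

# A thresholded scorer works with any estimator that exposes decision_function.
X, y = make_classification(random_state=0)
grid = GridSearchCV(LinearSVC(), param_grid={'C': [0.1, 1, 10]},
                    scoring=SCORERS['roc_auc'])
grid.fit(X, y)
print(grid.best_params_)
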
48 changes: 48 additions & 0 deletions sklearn/metrics/tests/test_score_objects.py
@@ -1,9 +1,13 @@
import functools
import pickle

from sklearn.utils.testing import assert_almost_equal
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_raises

from sklearn.metrics import f1_score, r2_score, auc_score, fbeta_score
from sklearn.metrics.metrics import (needs_threshold, greater_is_better,
lesser_is_better)
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics import SCORERS, Scorer
from sklearn.svm import LinearSVC
@@ -15,6 +19,50 @@
from sklearn.grid_search import GridSearchCV


def test_scorer_default_params():
"""Test to ensure correct default Scorer parameters"""
metric = lambda test, pred: 1.
scorer = Scorer(metric)
assert_equal(scorer.greater_is_better, True)
assert_equal(scorer.needs_threshold, False)


def test_scorer_annotated_params():
"""Test to ensure metric annotations affect Scorer params"""
metric = needs_threshold(lambda test, pred: 1.)
scorer = Scorer(metric)
assert_equal(scorer.greater_is_better, True)
assert_equal(scorer.needs_threshold, True)

metric = greater_is_better(lambda test, pred: 1.)
scorer = Scorer(metric)
assert_equal(scorer.greater_is_better, True)
assert_equal(scorer.needs_threshold, False)

metric = lesser_is_better(lambda test, pred: 1.)
scorer = Scorer(metric)
assert_equal(scorer.greater_is_better, False)
assert_equal(scorer.needs_threshold, False)


def test_scorer_wrapped_annotated_params():
"""Test to ensure metric annotations are found within functools.partial"""
metric = functools.partial(
lesser_is_better(lambda test, pred, param=5: 1.), param=1)
scorer = Scorer(metric)
assert_equal(scorer.greater_is_better, False)
assert_equal(scorer.needs_threshold, False)
assert_equal(metric, scorer.score_func) # ensure still wrapped


def test_scorer_constructor_params():
"""Test to ensure constructor params to Scorer override those annotated"""
metric = lesser_is_better(lambda test, pred: 1.)
scorer = Scorer(metric, greater_is_better=True, needs_threshold=True)
assert_equal(scorer.greater_is_better, True)
assert_equal(scorer.needs_threshold, True)


def test_classification_scores():
X, y = make_blobs(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)