[MRG+1] Adding support for balanced accuracy #8066


Merged (21 commits, Oct 17, 2017)
1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -779,6 +779,7 @@ details.
metrics.accuracy_score
metrics.auc
metrics.average_precision_score
metrics.balanced_accuracy_score
metrics.brier_score_loss
metrics.classification_report
metrics.cohen_kappa_score
52 changes: 51 additions & 1 deletion doc/modules/model_evaluation.rst
@@ -59,6 +59,7 @@ Scoring Function
============================== ============================================= ==================================
**Classification**
'accuracy' :func:`metrics.accuracy_score`
'balanced_accuracy' :func:`metrics.balanced_accuracy_score` for binary targets
'average_precision' :func:`metrics.average_precision_score`
'brier_score_loss' :func:`metrics.brier_score_loss`
'f1' :func:`metrics.f1_score` for binary targets
@@ -103,7 +104,7 @@ Usage examples:
>>> model = svm.SVC()
>>> cross_val_score(model, X, y, scoring='wrong_choice')
Traceback (most recent call last):
ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_mutual_info_score', 'adjusted_rand_score', 'average_precision', 'brier_score_loss', 'completeness_score', 'explained_variance', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'fowlkes_mallows_score', 'homogeneity_score', 'mutual_info_score', 'neg_log_loss', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_median_absolute_error', 'normalized_mutual_info_score', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc', 'v_measure_score']
ValueError: 'wrong_choice' is not a valid scoring value. Valid options are ['accuracy', 'adjusted_mutual_info_score', 'adjusted_rand_score', 'average_precision', 'balanced_accuracy', 'brier_score_loss', 'completeness_score', 'explained_variance', 'f1', 'f1_macro', 'f1_micro', 'f1_samples', 'f1_weighted', 'fowlkes_mallows_score', 'homogeneity_score', 'mutual_info_score', 'neg_log_loss', 'neg_mean_absolute_error', 'neg_mean_squared_error', 'neg_mean_squared_log_error', 'neg_median_absolute_error', 'normalized_mutual_info_score', 'precision', 'precision_macro', 'precision_micro', 'precision_samples', 'precision_weighted', 'r2', 'recall', 'recall_macro', 'recall_micro', 'recall_samples', 'recall_weighted', 'roc_auc', 'v_measure_score']

.. note::

@@ -279,6 +280,7 @@ Some of these are restricted to the binary classification case:

precision_recall_curve
roc_curve
balanced_accuracy_score


Others also work in the multiclass case:
@@ -419,6 +421,54 @@ In the multilabel case with binary label indicators: ::
for an example of accuracy score usage using permutations of
the dataset.

.. _balanced_accuracy_score:

Balanced accuracy score
-----------------------

The :func:`balanced_accuracy_score` function computes the
`balanced accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_, which
avoids inflated performance estimates on imbalanced datasets. It is defined as the
arithmetic mean of `sensitivity <https://en.wikipedia.org/wiki/Sensitivity_and_specificity>`_
(true positive rate) and `specificity <https://en.wikipedia.org/wiki/Sensitivity_and_specificity>`_
(true negative rate), or the average of `recall scores <https://en.wikipedia.org/wiki/Precision_and_recall>`_
obtained on each class.

If the classifier performs equally well on both classes, this metric reduces to the
conventional accuracy (i.e., the number of correct predictions divided by the total
number of predictions). In contrast, if the conventional accuracy is above chance only
because the classifier takes advantage of an imbalanced test set, then the balanced
accuracy will, as appropriate, drop to chance level (50%).

If :math:`\hat{y}_i\in\{0,1\}` is the predicted value of
the :math:`i`-th sample and :math:`y_i\in\{0,1\}` is the corresponding true value,
then the balanced accuracy is defined as

.. math::

\texttt{balanced-accuracy}(y, \hat{y}) = \frac{1}{2} \left(\frac{\sum_i 1(\hat{y}_i = 1 \land y_i = 1)}{\sum_i 1(y_i = 1)} + \frac{\sum_i 1(\hat{y}_i = 0 \land y_i = 0)}{\sum_i 1(y_i = 0)}\right)

where :math:`1(x)` is the `indicator function <https://en.wikipedia.org/wiki/Indicator_function>`_.
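
As an editorial aside (not part of the documentation diff), the formula above can be
checked directly against the library function; the following is a minimal sketch using
NumPy on the same example data used later in this section::

    import numpy as np
    from sklearn.metrics import balanced_accuracy_score

    y_true = np.array([0, 1, 0, 0, 1, 0])
    y_pred = np.array([0, 1, 0, 0, 0, 1])

    # sensitivity: correctly predicted positives / actual positives
    tpr = np.sum((y_pred == 1) & (y_true == 1)) / np.sum(y_true == 1)
    # specificity: correctly predicted negatives / actual negatives
    tnr = np.sum((y_pred == 0) & (y_true == 0)) / np.sum(y_true == 0)

    print(0.5 * (tpr + tnr))                        # 0.625
    print(balanced_accuracy_score(y_true, y_pred))  # 0.625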

Under this definition, the balanced accuracy coincides with :func:`roc_auc_score`
given binary ``y_true`` and ``y_pred``:

>>> import numpy as np
>>> from sklearn.metrics import balanced_accuracy_score, roc_auc_score
>>> y_true = [0, 1, 0, 0, 1, 0]
>>> y_pred = [0, 1, 0, 0, 0, 1]
>>> balanced_accuracy_score(y_true, y_pred)
0.625
>>> roc_auc_score(y_true, y_pred)
0.625

(In general, however, :func:`roc_auc_score` takes continuous scores, rather than binary predictions, as its second argument.)
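
To illustrate that remark (an editorial aside, with made-up decision scores): once
:func:`roc_auc_score` is given continuous scores rather than hard 0/1 predictions,
the two metrics no longer coincide in general::

    from sklearn.metrics import roc_auc_score

    y_true = [0, 1, 0, 0, 1, 0]
    y_score = [0.1, 0.9, 0.2, 0.3, 0.4, 0.8]  # continuous scores, not labels

    # AUC ranks the scores over all positive/negative pairs
    print(roc_auc_score(y_true, y_score))  # 0.875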

.. note::

Currently this score function is only defined for binary classification problems; you
may need to wrap it yourself if you want to use it for multilabel problems (a possible
wrapper is sketched below).
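
One possible wrapper, shown only as an illustrative sketch (the helper name is made up,
and each label column must contain both classes for the binary metric to be defined),
averages the binary score over the columns of a multilabel indicator matrix::

    import numpy as np
    from sklearn.metrics import balanced_accuracy_score

    def multilabel_balanced_accuracy(Y_true, Y_pred):
        # hypothetical helper: mean of the per-label binary balanced accuracies
        Y_true, Y_pred = np.asarray(Y_true), np.asarray(Y_pred)
        return np.mean([balanced_accuracy_score(Y_true[:, j], Y_pred[:, j])
                        for j in range(Y_true.shape[1])])

    Y_true = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
    Y_pred = np.array([[1, 0], [1, 1], [0, 1], [0, 0]])
    print(multilabel_balanced_accuracy(Y_true, Y_pred))  # 0.75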

.. _cohen_kappa:

Cohen's kappa
6 changes: 6 additions & 0 deletions doc/whats_new/v0.20.rst
@@ -40,6 +40,12 @@ Classifiers and regressors
- Added :class:`naive_bayes.ComplementNB`, which implements the Complement
Naive Bayes classifier described in Rennie et al. (2003).
By :user:`Michael A. Alcorn <airalcorn2>`.

Model evaluation
................

- Added the :func:`metrics.balanced_accuracy_score` metric and a corresponding
``'balanced_accuracy'`` scorer for binary classification.
:issue:`8066` by :user:`xyguo` and :user:`Aman Dalmia <dalmia>`.

Enhancements
............
2 changes: 2 additions & 0 deletions sklearn/metrics/__init__.py
@@ -16,6 +16,7 @@
from .ranking import ndcg_score

from .classification import accuracy_score
from .classification import balanced_accuracy_score
from .classification import classification_report
from .classification import cohen_kappa_score
from .classification import confusion_matrix
@@ -70,6 +71,7 @@
'adjusted_rand_score',
'auc',
'average_precision_score',
'balanced_accuracy_score',
'calinski_harabaz_score',
'classification_report',
'cluster',
61 changes: 61 additions & 0 deletions sklearn/metrics/classification.py
@@ -1360,6 +1360,67 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
return r


def balanced_accuracy_score(y_true, y_pred, sample_weight=None):
"""Compute the balanced accuracy

The balanced accuracy is used in binary classification problems to deal
with imbalanced datasets. It is defined as the arithmetic mean of
sensitivity (true positive rate) and specificity (true negative rate),
or the average recall obtained on either class. It is also equal to the
ROC AUC score given binary inputs.

The best value is 1 and the worst value is 0.

Read more in the :ref:`User Guide <balanced_accuracy_score>`.

Parameters
----------
y_true : 1d array-like
Ground truth (correct) target values.

y_pred : 1d array-like
Estimated targets as returned by a classifier.

sample_weight : array-like of shape = [n_samples], optional
Sample weights.

Returns
-------
balanced_accuracy : float
The average of sensitivity and specificity.

See also
--------
recall_score, roc_auc_score

References
----------
.. [1] Brodersen, K.H.; Ong, C.S.; Stephan, K.E.; Buhmann, J.M. (2010).
The balanced accuracy and its posterior distribution.
Proceedings of the 20th International Conference on Pattern
Recognition, 3121-24.

Review comment (Member): This paper only treats the binary case and it's not clear to me that it does the same thing as this code. We need more references.

Review comment (Member): Oh wait, this PR is only for the binary case? hm...

Examples
--------
>>> from sklearn.metrics import balanced_accuracy_score
>>> y_true = [0, 1, 0, 0, 1, 0]
>>> y_pred = [0, 1, 0, 0, 0, 1]
>>> balanced_accuracy_score(y_true, y_pred)
0.625

"""
y_type, y_true, y_pred = _check_targets(y_true, y_pred)

if y_type != 'binary':
raise ValueError('Balanced accuracy is only meaningful '
'for binary classification problems.')
# simply wrap the ``recall_score`` function
return recall_score(y_true, y_pred,
pos_label=None,
average='macro',
sample_weight=sample_weight)


def classification_report(y_true, y_pred, labels=None, target_names=None,
sample_weight=None, digits=2):
"""Build a text report showing the main classification metrics
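
Since the new ``balanced_accuracy_score`` simply wraps ``recall_score`` with
``average='macro'``, a quick check outside the diff (using the docstring's example
data) confirms the two calls agree on binary input::

    from sklearn.metrics import balanced_accuracy_score, recall_score

    y_true = [0, 1, 0, 0, 1, 0]
    y_pred = [0, 1, 0, 0, 0, 1]

    # macro-averaged recall weights both classes equally, i.e. the mean of
    # sensitivity and specificity in the binary case
    print(recall_score(y_true, y_pred, average='macro'))  # 0.625
    print(balanced_accuracy_score(y_true, y_pred))        # 0.625
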
4 changes: 3 additions & 1 deletion sklearn/metrics/scorer.py
@@ -26,7 +26,7 @@
from . import (r2_score, median_absolute_error, mean_absolute_error,
mean_squared_error, mean_squared_log_error, accuracy_score,
f1_score, roc_auc_score, average_precision_score,
precision_score, recall_score, log_loss,
precision_score, recall_score, log_loss, balanced_accuracy_score,
explained_variance_score, brier_score_loss)

from .cluster import adjusted_rand_score
@@ -500,6 +500,7 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
# Standard Classification Scores
accuracy_scorer = make_scorer(accuracy_score)
f1_scorer = make_scorer(f1_score)
balanced_accuracy_scorer = make_scorer(balanced_accuracy_score)

# Score functions that need decision values
roc_auc_scorer = make_scorer(roc_auc_score, greater_is_better=True,
@@ -543,6 +544,7 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
mean_absolute_error=mean_absolute_error_scorer,
mean_squared_error=mean_squared_error_scorer,
accuracy=accuracy_scorer, roc_auc=roc_auc_scorer,
balanced_accuracy=balanced_accuracy_scorer,
average_precision=average_precision_scorer,
log_loss=log_loss_scorer,
neg_log_loss=neg_log_loss_scorer,
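
With the scorer registered under ``'balanced_accuracy'``, it can be passed anywhere a
``scoring`` parameter is accepted. A small usage sketch (illustrative only; the dataset
and estimator are arbitrary choices, not part of this PR)::

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_score

    # deliberately imbalanced synthetic problem: ~90% negatives, ~10% positives
    X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=0)

    clf = LogisticRegression()
    print(cross_val_score(clf, X, y, cv=5, scoring='balanced_accuracy').mean())
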
4 changes: 4 additions & 0 deletions sklearn/metrics/tests/test_common.py
@@ -25,6 +25,7 @@
from sklearn.utils.testing import _named_check

from sklearn.metrics import accuracy_score
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import average_precision_score
from sklearn.metrics import brier_score_loss
from sklearn.metrics import cohen_kappa_score
@@ -100,6 +101,7 @@

CLASSIFICATION_METRICS = {
"accuracy_score": accuracy_score,
"balanced_accuracy_score": balanced_accuracy_score,
"unnormalized_accuracy_score": partial(accuracy_score, normalize=False),
"confusion_matrix": confusion_matrix,
"hamming_loss": hamming_loss,
@@ -211,6 +213,7 @@
# Those metrics don't support multiclass inputs
METRIC_UNDEFINED_MULTICLASS = [
"brier_score_loss",
"balanced_accuracy_score",

"roc_auc_score",
"micro_roc_auc",
@@ -352,6 +355,7 @@
# Asymmetric with respect to their input arguments y_true and y_pred
# metric(y_true, y_pred) != metric(y_pred, y_true).
NOT_SYMMETRIC_METRICS = [
"balanced_accuracy_score",
"explained_variance_score",
"r2_score",
"confusion_matrix",
3 changes: 2 additions & 1 deletion sklearn/metrics/tests/test_score_objects.py
@@ -47,7 +47,8 @@
'neg_median_absolute_error', 'mean_absolute_error',
'mean_squared_error', 'median_absolute_error']

CLF_SCORERS = ['accuracy', 'f1', 'f1_weighted', 'f1_macro', 'f1_micro',
CLF_SCORERS = ['accuracy', 'balanced_accuracy',
'f1', 'f1_weighted', 'f1_macro', 'f1_micro',
'roc_auc', 'average_precision', 'precision',
'precision_weighted', 'precision_macro', 'precision_micro',
'recall', 'recall_weighted', 'recall_macro', 'recall_micro',