[MRG] Add balanced accuracy score in metrics #5588

Closed
wants to merge 7 commits into from
1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -796,6 +796,7 @@ details.
:template: function.rst

metrics.accuracy_score
metrics.balanced_accuracy_score
metrics.auc
metrics.average_precision_score
metrics.brier_score_loss
56 changes: 55 additions & 1 deletion doc/modules/model_evaluation.rst
@@ -60,6 +60,7 @@ Scoring Function Comment
**Classification**
'accuracy' :func:`metrics.accuracy_score`
'average_precision' :func:`metrics.average_precision_score`
'balanced_accuracy_score' :func:`metrics.balanced_accuracy_score` for binary targets
'f1' :func:`metrics.f1_score` for binary targets
'f1_micro' :func:`metrics.f1_score` micro-averaged
'f1_macro' :func:`metrics.f1_score` macro-averaged
@@ -221,6 +222,7 @@ Some of these are restricted to the binary classification case:
.. autosummary::
:template: function.rst

balanced_accuracy_score
matthews_corrcoef
precision_recall_curve
roc_curve
@@ -621,7 +623,15 @@ In this context, we can define the notions of precision, recall and F-measure:

.. math::

\text{recall} = \frac{tp}{tp + fn},
\text{recall (also called sensitivity)} = \frac{tp}{tp + fn},

.. math::

\text{specificity} = \frac{tn}{tn + fp},

.. math::

\text{balanced accuracy} = \frac{1}{2} (\text{sensitivity} + \text{specificity}),

.. math::

@@ -636,6 +646,12 @@ Here are some small examples in binary classification::
1.0
>>> metrics.recall_score(y_true, y_pred)
0.5
>>> metrics.balanced_accuracy_score(y_true, y_pred)
0.75
>>> metrics.balanced_accuracy_score(y_true, y_pred, balance=1)
0.5
>>> metrics.balanced_accuracy_score(y_true, y_pred, balance=0)
1.0
>>> metrics.f1_score(y_true, y_pred) # doctest: +ELLIPSIS
0.66...
>>> metrics.fbeta_score(y_true, y_pred, beta=0.5) # doctest: +ELLIPSIS
@@ -862,6 +878,44 @@ method.
The first ``[.9, .1]`` in ``y_pred`` denotes 90% probability that the first
sample has label 0. The log loss is non-negative.

.. _balanced_accuracy_score:

Balanced accuracy score
-----------------------

The :func:`balanced_accuracy_score` function computes the
`Balanced accuracy score (BAC) <http://en.wikipedia.org/wiki/Accuracy_and_precision>`_
for binary classification problems. Quoting Wikipedia:


"The balanced accuracy avoids inflated performance estimates on
imbalanced datasets. It is defined as the arithmetic mean of sensitivity
and specificity, or the average accuracy obtained on either class."

If :math:`tp`, :math:`tn`, :math:`fp` and :math:`fn` are respectively the
number of true positives, true negatives, false positives and false negatives,
the BAC metric is defined as

.. math::

BAC = \frac{1}{2} \left( \frac{tp}{tp + fn} + \frac{tn}{tn + fp} \right)

Here is a small example illustrating the usage of the :func:`balanced_accuracy_score`
function:

>>> from sklearn.metrics import accuracy_score
>>> from sklearn.metrics import balanced_accuracy_score
>>> y_true = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1]
>>> y_pred = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
>>> accuracy_score(y_true, y_pred) # doctest: +ELLIPSIS
0.9...
>>> balanced_accuracy_score(y_true, y_pred)
0.5

In this example accuracy is not the metric to use: its value merely
reflects the imbalanced class distribution of the dataset. See the
`Accuracy paradox <http://en.wikipedia.org/wiki/Accuracy_paradox>`_.
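
Since the balanced accuracy is the average of the per-class recalls, the same
value can be recovered by hand from the confusion matrix (a minimal sketch
using :func:`confusion_matrix`):

>>> import numpy as np
>>> from sklearn.metrics import confusion_matrix
>>> y_true = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1]
>>> y_pred = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
>>> cm = confusion_matrix(y_true, y_pred)
>>> per_class_recall = np.diag(cm) / cm.sum(axis=1).astype(float)
>>> per_class_recall.mean()
0.5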

.. _matthews_corrcoef:

Matthews correlation coefficient
2 changes: 2 additions & 0 deletions sklearn/metrics/__init__.py
@@ -14,6 +14,7 @@
from .ranking import roc_curve

from .classification import accuracy_score
from .classification import balanced_accuracy_score
from .classification import classification_report
from .classification import cohen_kappa_score
from .classification import confusion_matrix
@@ -61,6 +62,7 @@

__all__ = [
'accuracy_score',
'balanced_accuracy_score',
'adjusted_mutual_info_score',
'adjusted_rand_score',
'auc',
74 changes: 74 additions & 0 deletions sklearn/metrics/classification.py
@@ -178,6 +178,80 @@ def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
return _weighted_sum(score, sample_weight, normalize)


def balanced_accuracy_score(y_true, y_pred, balance=0.5):
"""Balanced accuracy classification score.

The formula for the balanced accuracy score ::

balanced accuracy = balance * TP / (TP + FN) + (1 - balance) * TN / (TN + FP)

Because it relies on the notions of true/false positives and negatives,
it only supports binary classification.

The ``balance`` parameter sets the weight given to sensitivity in the
combined score: ``balance -> 1`` lends more weight to sensitivity, while
``balance -> 0`` favors specificity (``balance = 1`` considers only
sensitivity, ``balance = 0`` only specificity).

Read more in the :ref:`User Guide <balanced_accuracy_score>`.

Parameters
----------
y_true : 1d array-like
Ground truth (correct) labels.

y_pred : 1d array-like
Predicted labels, as returned by a classifier.

balance : float, optional (default=0.5)
Weight given to sensitivity (recall) against specificity in the final
score. Must be between 0 and 1.

Returns
-------
score : float
Weighted mean of sensitivity and specificity.

See also
--------
accuracy_score

References
----------
.. [1] `Wikipedia entry for accuracy and precision
<http://en.wikipedia.org/wiki/Accuracy_and_precision>`_

Examples
--------
>>> from sklearn.metrics import balanced_accuracy_score
>>> y_pred = [0, 0, 1]
>>> y_true = [0, 1, 1]
>>> balanced_accuracy_score(y_true, y_pred)
0.75

>>> y_pred = ["cat", "cat", "ant"]
>>> y_true = ["cat", "ant", "ant"]
>>> balanced_accuracy_score(y_true, y_pred)
0.75
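
Using ``balance`` to weight sensitivity only (with labels sorted, "cat" is
treated as the positive class here and is always predicted correctly):

>>> balanced_accuracy_score(y_true, y_pred, balance=1.)
1.0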

"""

if not 0. <= balance <= 1.:
raise ValueError("balance has to be between 0 and 1")

y_type, y_true, y_pred = _check_targets(y_true, y_pred)
if y_type != "binary":
raise ValueError("%s is not supported" % y_type)

# rows of the binary confusion matrix are ordered (negative, positive)
cm = confusion_matrix(y_true, y_pred)
neg, pos = cm.sum(axis=1, dtype='float')
tn, tp = np.diag(cm)

sensitivity = tp / pos
specificity = tn / neg

return balance * sensitivity + (1 - balance) * specificity


def confusion_matrix(y_true, y_pred, labels=None):
"""Compute confusion matrix to evaluate the accuracy of a classification

39 changes: 39 additions & 0 deletions sklearn/metrics/tests/test_classification.py
@@ -28,6 +28,7 @@
from sklearn.utils.mocking import MockDataFrame

from sklearn.metrics import accuracy_score
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import average_precision_score
from sklearn.metrics import classification_report
from sklearn.metrics import cohen_kappa_score
@@ -117,6 +118,44 @@ def test_multilabel_accuracy_score_subset_accuracy():
assert_equal(accuracy_score(y2, np.zeros(y1.shape)), 0)


def test_balanced_accuracy_score_binary():
# Test balanced accuracy score for binary classification task

# with numeric labels
y_pred = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
y_true = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1]

assert_equal(balanced_accuracy_score(y_true, y_pred), 0.5)

# with string labels
y_pred = ["ant", "cat", "ant", "cat", "ant"]
y_true = ["cat", "cat", "ant", "ant", "ant"]

assert_equal(balanced_accuracy_score(y_true, y_pred), 0.5 * 2 / 3 + 0.5 * 1 / 2)

# with specific balance
y_pred = [0, 0, 1]
y_true = [0, 1, 1]

assert_equal(balanced_accuracy_score(y_true, y_pred, balance=0.75), 0.625)

# with wrong balance
assert_raise_message(ValueError, "balance has to be between 0 and 1",
balanced_accuracy_score, y_true, y_pred, balance=2)
assert_raise_message(ValueError, "balance has to be between 0 and 1",
balanced_accuracy_score, y_true, y_pred, balance=-1)
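
# A sketch of an extra consistency check: with the default balance of 0.5
# the score equals the unweighted mean of the per-class recalls
# (1/2 for class 0 and 2/2 for class 1 below).
y_pred = [0, 1, 1, 1]
y_true = [0, 0, 1, 1]
assert_equal(balanced_accuracy_score(y_true, y_pred), 0.75)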


def test_balanced_accuracy_score_no_binary():
# Test balanced_accuracy_score returns an error when trying to
# compute score for multiclass
y_pred = [0, 2, 1, 3]
y_true = [0, 1, 2, 3]

assert_raise_message(ValueError, "multiclass is not supported",
balanced_accuracy_score, y_true, y_pred)


def test_precision_recall_f1_score_binary():
# Test Precision Recall and F1 Score for binary classification task
y_true, y_pred, _ = make_prediction(binary=True)
5 changes: 5 additions & 0 deletions sklearn/metrics/tests/test_common.py
@@ -23,6 +23,7 @@
from sklearn.utils.testing import ignore_warnings

from sklearn.metrics import accuracy_score
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import average_precision_score
from sklearn.metrics import brier_score_loss
from sklearn.metrics import cohen_kappa_score
@@ -98,6 +99,7 @@

CLASSIFICATION_METRICS = {
"accuracy_score": accuracy_score,
"balanced_accuracy_score": balanced_accuracy_score,
"unnormalized_accuracy_score": partial(accuracy_score, normalize=False),
"confusion_matrix": confusion_matrix,
"hamming_loss": hamming_loss,
@@ -207,6 +209,7 @@
"micro_average_precision_score",
"macro_average_precision_score",
"samples_average_precision_score",
"balanced_accuracy_score",

"label_ranking_loss",
"label_ranking_average_precision_score",
@@ -340,6 +343,7 @@
# Asymmetric with respect to their input arguments y_true and y_pred
# metric(y_true, y_pred) != metric(y_pred, y_true).
NOT_SYMMETRIC_METRICS = [
"balanced_accuracy_score",
"explained_variance_score",
"r2_score",
"confusion_matrix",
@@ -359,6 +363,7 @@

# No Sample weight support
METRICS_WITHOUT_SAMPLE_WEIGHT = [
"balanced_accuracy_score",
"cohen_kappa_score",
"confusion_matrix",
"median_absolute_error",