Commit de84a06

implement balanced_accuracy_score
1 parent 0117fb5 commit de84a06

5 files changed (+112, -1 lines changed)

sklearn/metrics/__init__.py (+2)

@@ -13,6 +13,7 @@
 from .ranking import roc_curve

 from .classification import accuracy_score
+from .classification import balanced_accuracy_score
 from .classification import classification_report
 from .classification import confusion_matrix
 from .classification import f1_score
@@ -59,6 +60,7 @@

 __all__ = [
     'accuracy_score',
+    'balanced_accuracy_score',
     'adjusted_mutual_info_score',
     'adjusted_rand_score',
     'auc',

sklearn/metrics/classification.py (+64)

@@ -186,6 +186,70 @@ def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
     return _weighted_sum(score, sample_weight, normalize)


+def balanced_accuracy_score(y_true, y_pred):
+    """Balanced accuracy score
+
+    The balanced accuracy score is defined as
+    0.5 * true positives / (true positives + false negatives) +
+    0.5 * true negatives / (true negatives + false positives)
+
+    This is equal to the average of the recall on the positive class
+    and the recall on the negative class.
+
+    Parameters
+    ----------
+    y_true : array, shape = [n_samples]
+        Ground truth (correct) target values.
+
+    y_pred : array, shape = [n_samples]
+        Estimated targets as returned by a classifier.
+
+    Returns
+    -------
+    score : float
+        The balanced accuracy score.
+
+        The best performance is 1.
+
+    References
+    ----------
+    .. [1] `Wikipedia entry for balanced accuracy
+           <http://en.wikipedia.org/wiki/Accuracy_and_precision#In_binary_classification>`_
+
+    Examples
+    --------
+    >>> from sklearn.metrics import balanced_accuracy_score
+    >>> y_true = [0, 0, 1, 1]
+    >>> y_pred = [0, 1, 1, 1]
+    >>> balanced_accuracy_score(y_true, y_pred)  # doctest: +ELLIPSIS
+    0.75...
+
+    >>> y_true = [0, 1, 1, 1, 1]
+    >>> y_pred = [1, 1, 1, 1, 1]
+    >>> balanced_accuracy_score(y_true, y_pred)  # doctest: +ELLIPSIS
+    0.5...
+
+    >>> y_true = ['b', 'a', 'a', 'a']
+    >>> y_pred = ['a', 'a', 'b', 'a']
+    >>> balanced_accuracy_score(y_true, y_pred)  # doctest: +ELLIPSIS
+    0.33...
+
+    """
+    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    if y_type != 'binary':
+        raise ValueError("%s is not supported" % y_type)
+
+    # Label encoding
+    lb = LabelBinarizer()
+    y_true_binary = lb.fit_transform(y_true)
+    y_pred_binary = lb.transform(y_pred)
+
+    pos_recall = recall_score(y_true_binary, y_pred_binary)
+    neg_recall = recall_score(1 - y_true_binary, 1 - y_pred_binary)
+
+    return np.average([pos_recall, neg_recall])
+
+
 def confusion_matrix(y_true, y_pred, labels=None):
     """Compute confusion matrix to evaluate the accuracy of a classification
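As a sanity check on the formula in the new docstring, the snippet below (a minimal sketch, not part of the commit; the helper name check_balanced_accuracy is made up here) recomputes the score from the four confusion-matrix cells and compares it with the average of the per-class recalls:

    import numpy as np
    from sklearn.metrics import recall_score

    def check_balanced_accuracy(y_true, y_pred):
        # Hypothetical helper, not part of scikit-learn: counts the four
        # confusion-matrix cells for 0/1 labels and applies the docstring formula.
        y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
        tp = np.sum((y_true == 1) & (y_pred == 1))
        tn = np.sum((y_true == 0) & (y_pred == 0))
        fp = np.sum((y_true == 0) & (y_pred == 1))
        fn = np.sum((y_true == 1) & (y_pred == 0))
        return 0.5 * tp / (tp + fn) + 0.5 * tn / (tn + fp)

    y_true = [0, 0, 1, 1]
    y_pred = [0, 1, 1, 1]

    # Same value as the average of the positive-class and negative-class recall.
    pos_recall = recall_score(y_true, y_pred, pos_label=1)
    neg_recall = recall_score(y_true, y_pred, pos_label=0)
    print(check_balanced_accuracy(y_true, y_pred))  # 0.75
    print(0.5 * (pos_recall + neg_recall))          # 0.75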

sklearn/metrics/metrics.py (+1)

@@ -12,6 +12,7 @@
 from .ranking import roc_curve

 from .classification import accuracy_score
+from .classification import balanced_accuracy_score
 from .classification import classification_report
 from .classification import confusion_matrix
 from .classification import f1_score

sklearn/metrics/tests/test_classification.py (+38)

@@ -27,6 +27,7 @@
 from sklearn.utils.testing import ignore_warnings

 from sklearn.metrics import accuracy_score
+from sklearn.metrics import balanced_accuracy_score
 from sklearn.metrics import average_precision_score
 from sklearn.metrics import classification_report
 from sklearn.metrics import confusion_matrix
@@ -127,6 +128,43 @@ def test_multilabel_accuracy_score_subset_accuracy():
     assert_equal(accuracy_score(y2, [(), ()], normalize=False), 0)


+def test_balanced_accuracy_score():
+    # Test balanced accuracy score for a binary classification task
+
+    # test on an imbalanced data set
+    y_true = np.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1])
+    y_pred = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
+
+    assert_equal(balanced_accuracy_score(y_true, y_pred), 0.5)
+
+    # test the function against the equation defined as
+    # 0.5 * true positives / (true positives + false negatives) +
+    # 0.5 * true negatives / (true negatives + false positives)
+    y_true, y_pred, _ = make_prediction(binary=True)
+    tn, fp, fn, tp = np.bincount(y_true * 2 + y_pred, minlength=4)
+    bas = 0.5 * tp / (tp + fn) + 0.5 * tn / (tn + fp)
+    assert_equal(balanced_accuracy_score(y_true, y_pred), bas)
+
+    # test using string labels
+    y_true = np.array(['a', 'b', 'a', 'b'])
+    y_pred = np.array(['a', 'b', 'a', 'a'])
+
+    assert_equal(balanced_accuracy_score(y_true, y_pred), 0.75)
+
+
+def test_balanced_accuracy_score_on_non_binary_class():
+    # Test that balanced_accuracy_score raises an error when trying
+    # to compute balanced_accuracy_score for a multiclass task.
+    rng = check_random_state(404)
+    y_pred = rng.randint(0, 3, size=10)
+
+    # y_true contains three different class values
+    y_true = rng.randint(0, 3, size=10)
+
+    assert_raise_message(ValueError, "multiclass is not supported",
+                         balanced_accuracy_score, y_true, y_pred)
+
+
 def test_precision_recall_f1_score_binary():
     # Test Precision Recall and F1 Score for binary classification task
     y_true, y_pred, _ = make_prediction(binary=True)
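One detail worth noting in test_balanced_accuracy_score: for 0/1 labels, y_true * 2 + y_pred maps every sample to a unique confusion-matrix cell (0 = TN, 1 = FP, 2 = FN, 3 = TP), so a single np.bincount call recovers all four counts. A quick sketch (not part of the commit) checking the trick against confusion_matrix:

    import numpy as np
    from sklearn.metrics import confusion_matrix

    y_true = np.array([0, 0, 1, 1, 1, 0])
    y_pred = np.array([0, 1, 1, 0, 1, 0])

    # Each sample lands in one of four bins: 0 -> TN, 1 -> FP, 2 -> FN, 3 -> TP.
    tn, fp, fn, tp = np.bincount(y_true * 2 + y_pred, minlength=4)

    # confusion_matrix returns [[TN, FP], [FN, TP]] for binary 0/1 labels.
    assert np.array_equal(np.array([[tn, fp], [fn, tp]]),
                          confusion_matrix(y_true, y_pred))
    print(tn, fp, fn, tp)  # 2 1 1 2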

sklearn/metrics/tests/test_common.py (+7, -1)

@@ -24,6 +24,7 @@
 from sklearn.utils.testing import ignore_warnings

 from sklearn.metrics import accuracy_score
+from sklearn.metrics import balanced_accuracy_score
 from sklearn.metrics import average_precision_score
 from sklearn.metrics import brier_score_loss
 from sklearn.metrics import confusion_matrix
@@ -97,6 +98,7 @@

 CLASSIFICATION_METRICS = {
     "accuracy_score": accuracy_score,
+    "balanced_accuracy_score": balanced_accuracy_score,
     "unnormalized_accuracy_score": partial(accuracy_score, normalize=False),
     "confusion_matrix": confusion_matrix,
     "hamming_loss": hamming_loss,
@@ -190,6 +192,7 @@
     "samples_precision_score", "samples_recall_score",

     # Those metrics don't support multiclass outputs
+    "balanced_accuracy_score",
     "average_precision_score", "weighted_average_precision_score",
     "micro_average_precision_score", "macro_average_precision_score",
     "samples_average_precision_score",
@@ -331,7 +334,9 @@
     "micro_recall_score",

     "macro_f0.5_score", "macro_f2_score", "macro_precision_score",
-    "macro_recall_score", "log_loss", "hinge_loss"
+    "macro_recall_score", "log_loss", "hinge_loss",
+
+    "balanced_accuracy_score",
 ]


@@ -341,6 +346,7 @@
     "hamming_loss",
     "matthews_corrcoef_score",
     "median_absolute_error",
+    "balanced_accuracy_score",
 ]

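The last two hunks add the metric to further bookkeeping lists in test_common.py whose names fall outside the visible context; judging by the neighbouring entries, one of them collects metrics that are not symmetric in (y_true, y_pred). A small sketch (not from the commit) showing that balanced accuracy is indeed order-dependent, using the average-of-recalls form from the docstring:

    import numpy as np
    from sklearn.metrics import recall_score

    def bas(y_true, y_pred):
        # Average of positive-class and negative-class recall for 0/1 labels,
        # mirroring the new balanced_accuracy_score (illustrative only).
        y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
        return 0.5 * (recall_score(y_true, y_pred) +
                      recall_score(1 - y_true, 1 - y_pred))

    y_true = [0, 0, 1, 1]
    y_pred = [0, 1, 1, 1]
    print(bas(y_true, y_pred))  # 0.75
    print(bas(y_pred, y_true))  # 0.8333..., so swapping the arguments changes the score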
