|
27 | 27 | from sklearn.utils.testing import ignore_warnings
|
28 | 28 |
|
29 | 29 | from sklearn.metrics import accuracy_score
|
| 30 | +from sklearn.metrics import balanced_accuracy_score |
30 | 31 | from sklearn.metrics import average_precision_score
|
31 | 32 | from sklearn.metrics import classification_report
|
32 | 33 | from sklearn.metrics import confusion_matrix
|
@@ -127,6 +128,43 @@ def test_multilabel_accuracy_score_subset_accuracy():
|
127 | 128 | assert_equal(accuracy_score(y2, [(), ()], normalize=False), 0)
|
128 | 129 |
|
129 | 130 |
|
| 131 | +def test_balanced_accuracy_score(): |
| 132 | + # Test balanced accuracy score for binary classification task |
| 133 | + |
| 134 | + # test on an imbalanced data set |
| 135 | + y_true = np.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1]) |
| 136 | + y_pred = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) |
| 137 | + |
| 138 | + assert_equal(balanced_accuracy_score(y_true, y_pred), 0.5) |
| 139 | + |
| 140 | + # test the function with the equation defined as |
| 141 | + # 0.5 * true positives / (true positives + false negatives) + |
| 142 | + # 0.5 * true negatives / (true negatives + false positives) |
| 143 | + y_true, y_pred, _ = make_prediction(binary=True) |
| 144 | + tn, fp, fn, tp = np.bincount(y_true * 2 + y_pred, minlength=4) |
| 145 | + bas = 0.5 * tp / (tp + fn) + 0.5 * tn / (tn + fp) |
| 146 | + assert_equal(balanced_accuracy_score(y_true, y_pred), bas) |
| 147 | + |
| 148 | + # test using string labels |
| 149 | + y_true = np.array(['a', 'b', 'a', 'b']) |
| 150 | + y_pred = np.array(['a', 'b', 'a', 'a']) |
| 151 | + |
| 152 | + assert_equal(balanced_accuracy_score(y_true, y_pred), 0.75) |
| 153 | + |
| 154 | + |
| 155 | +def test_balanced_accuracy_score_on_non_binary_class(): |
| 156 | + # Test that balanced_accuracy_score raises an error when trying |
| 157 | + # to compute balanced_accuracy_score for a multiclass task. |
| 158 | + rng = check_random_state(404) |
| 159 | + y_pred = rng.randint(0, 3, size=10) |
| 160 | + |
| 161 | + # y_true contains three different class values |
| 162 | + y_true = rng.randint(0, 3, size=10) |
| 163 | + |
| 164 | + assert_raise_message(ValueError, "multiclass is not supported", |
| 165 | + balanced_accuracy_score, y_true, y_pred) |
| 166 | + |
| 167 | + |
130 | 168 | def test_precision_recall_f1_score_binary():
|
131 | 169 | # Test Precision Recall and F1 Score for binary classification task
|
132 | 170 | y_true, y_pred, _ = make_prediction(binary=True)
|
|
0 commit comments