[MRG+1] Adding support for balanced accuracy #8066
Conversation
As discussed in the earlier PR thread, extending this to the multilabel case introduces a lot of edge and corner cases. So, do we want to implement this for multilabel?
That might be difficult by simply wrapping … In addition, the file …
Yes, I read the discussion on your thread. Since you have already tried implementing it, do you suggest we should try adding it?
I would've thought multilabel is easy to do like other metrics; multiclass was disputed.
Conflicts: doc/modules/model_evaluation.rst, sklearn/metrics/scorer.py
@jnothman So, I've gone over the discussion of this enhancement on the issue thread and the previous pull requests, and am presenting a summary to wrap everything up:
So please let me know whether you feel we can simply support binary problems, or whether it is critical to try something different for the multilabel case.
I wrote the balanced_acc as follows (based on the …):

import warnings

import numpy as np
from scipy.sparse import csr_matrix

# NOTE: these are scikit-learn private helpers; the function is written to
# live in sklearn/metrics/classification.py, where they are already in scope.
from sklearn.metrics.classification import _check_targets, _prf_divide
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.fixes import bincount  # np.bincount works equally well
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.sparsefuncs import count_nonzero


def balanced_accuracy_score(y_true, y_pred, labels=None,
                            average=None, balance=0.5):
    """Compute the balanced accuracy

    The balanced accuracy is used in binary classification problems to deal
    with imbalanced datasets. It can also be extended to multilabel problems.
    It is defined as the weighted arithmetic mean of sensitivity
    (true positive rate, TPR) and specificity (true negative rate, TNR), or
    the weighted average recall obtained on either class:

        balanced accuracy = balance * TPR + (1 - balance) * TNR

    It is also equal to the ROC AUC score for binary inputs when ``balance``
    is 0.5. The best value is 1 and the worst value is 0.

    Note: this implementation is restricted to binary classification tasks
    or multilabel tasks in label indicator format.

    Read more in the :ref:`User Guide <balanced_accuracy_score>`.

    Parameters
    ----------
    y_true : 1d array-like
        Ground truth (correct) target values.

    y_pred : 1d array-like
        Estimated targets as returned by a classifier.

    labels : list, optional
        The set of labels to include for multilabel problems, and their
        order if ``average is None``. For multilabel targets, labels are
        column indices. By default, all labels in ``y_true`` and ``y_pred``
        are used in sorted order.

    average : string, [None (default), 'micro', 'macro', 'samples']
        If ``None``, the scores for each class are returned. Otherwise,
        this determines the type of averaging performed on the data:

        ``'micro'``:
            Calculate metrics globally by considering each element of the
            label indicator matrix as a label.
        ``'macro'``:
            Calculate metrics for each label, and find their unweighted
            mean. This does not take label imbalance into account.
        ``'samples'``:
            Calculate metrics for each instance, and find their average
            (only meaningful for multilabel classification).

    balance : float between 0 and 1, optional (default=0.5)
        Weight given to sensitivity (recall of the positive class) relative
        to specificity (recall of the negative class) in the final score.

    Returns
    -------
    balanced_accuracy : float
        The weighted average of sensitivity and specificity.

    See also
    --------
    recall_score, roc_auc_score

    References
    ----------
    .. [1] Brodersen, K.H.; Ong, C.S.; Stephan, K.E.; Buhmann, J.M. (2010).
       The balanced accuracy and its posterior distribution. Proceedings of
       the 20th International Conference on Pattern Recognition, 3121-24.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.metrics import balanced_accuracy_score
    >>> y_true = [0, 1, 0, 0, 1, 0]
    >>> y_pred = [0, 1, 0, 0, 0, 1]
    >>> balanced_accuracy_score(y_true, y_pred)
    0.625
    >>> y_true = np.array([[1, 0], [1, 0], [0, 1]])
    >>> y_pred = np.array([[1, 1], [0, 1], [1, 1]])
    >>> balanced_accuracy_score(y_true, y_pred, average=None)
    array([ 0.25, 0.5 ])
    """
    # TODO: handle sparse input in the multilabel setting
    # TODO: ensure `sample_weight`'s shape is consistent with `y_true` and
    #       `y_pred`
    # TODO: handle situations where only one class is present in `y_true`
    # TODO: accept a `labels` argument in the binary case
    average_options = (None, 'micro', 'macro', 'samples')
    if average not in average_options:
        raise ValueError('average has to be one of ' + str(average_options))

    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
    present_labels = unique_labels(y_true, y_pred)
    if y_type == 'multiclass':
        raise ValueError('Balanced accuracy is only meaningful '
                         'for binary classification or '
                         'multilabel problems.')

    if labels is None:
        labels = present_labels
        n_labels = None
    else:
        n_labels = len(labels)
        labels = np.hstack([labels, np.setdiff1d(present_labels, labels,
                                                 assume_unique=True)])

    # Calculate tp_sum and true_sum ###
    if y_type.startswith('multilabel'):
        sum_axis = 1 if average == 'samples' else 0

        # All labels are index integers for multilabel.
        # Select labels:
        if not np.all(labels == present_labels):
            if np.max(labels) > np.max(present_labels):
                raise ValueError('All labels must be in [0, n labels). '
                                 'Got %d > %d' %
                                 (np.max(labels), np.max(present_labels)))
            if np.min(labels) < 0:
                raise ValueError('All labels must be in [0, n labels). '
                                 'Got %d < 0' % np.min(labels))
            y_true = y_true[:, labels[:n_labels]]
            y_pred = y_pred[:, labels[:n_labels]]

        # Indicator matrices for the negative (zero) class.
        # TODO: inefficient due to the generation of dense matrices.
        y_true_z = np.ones(y_true.shape)
        y_true_z[y_true.nonzero()] = 0
        y_true_z = csr_matrix(y_true_z)
        y_pred_z = np.ones(y_true.shape)
        y_pred_z[y_pred.nonzero()] = 0
        y_pred_z = csr_matrix(y_pred_z)

        # Counts for the positive class.
        true_and_pred_p = y_true.multiply(y_pred)
        tp_sum_p = count_nonzero(true_and_pred_p, axis=sum_axis)
        true_sum_p = count_nonzero(y_true, axis=sum_axis)

        # Counts for the negative class.
        true_and_pred_n = y_true_z.multiply(y_pred_z)
        tp_sum_n = count_nonzero(true_and_pred_n, axis=sum_axis)
        true_sum_n = count_nonzero(y_true_z, axis=sum_axis)

        # Stack the positive-class row on top of the negative-class row.
        tp_sum = np.vstack((tp_sum_p, tp_sum_n))
        true_sum = np.vstack((true_sum_p, true_sum_n))

        if average == 'micro':
            tp_sum = np.array([tp_sum.sum(axis=1)])
            true_sum = np.array([true_sum.sum(axis=1)])
    elif average == 'samples':
        raise ValueError("Sample-based balanced accuracy is "
                         "not meaningful outside multilabel "
                         "problems.")
    else:
        # Binary classification case ##
        if n_labels is not None:
            # Only warn when the caller actually passed `labels`
            # (`labels` itself is never None at this point).
            warnings.warn("The `labels` argument will be ignored "
                          "in binary classification problems.")
        le = LabelEncoder()
        le.fit(labels)
        y_true = le.transform(y_true)
        y_pred = le.transform(y_pred)

        # Labels are now either 0 or 1 -> use bincount.
        tp = y_true == y_pred
        tp_bins = y_true[tp]
        tp_bins_weights = None

        if len(tp_bins):
            tp_sum = bincount(tp_bins, weights=tp_bins_weights, minlength=2)
        else:
            # Pathological case
            true_sum = tp_sum = np.zeros(2)
        if len(y_true):
            true_sum = bincount(y_true, minlength=2)

    # Finally, we have all our sufficient statistics. Divide! #
    with np.errstate(divide='ignore', invalid='ignore'):
        # Divide, and on zero-division, set scores to 0 and warn:
        # Oddly, we may get an "invalid" rather than a "divide" error here.
        recalls = _prf_divide(tp_sum, true_sum,
                              'recall', 'true', average, ('recall',))

    # Weight the per-class recalls: `balance` for sensitivity (recall of the
    # positive class), `1 - balance` for specificity. Note the row order
    # differs between branches: the multilabel branch stacks the positive
    # class first, whereas bincount puts class 0 (the negative class) first.
    if y_type.startswith('multilabel'):
        weights = [balance, 1 - balance]
    else:
        weights = [1 - balance, balance]
    if recalls.ndim == 1 or recalls.shape[0] == 2:
        bacs = np.average(recalls, axis=0, weights=weights)
    else:
        # 'micro' has already collapsed the per-class rows; fall back to an
        # unweighted mean over what remains.
        bacs = np.average(recalls, axis=0)

    # Average the results
    if average is not None:
        bacs = np.average(bacs)
    return bacs
I am okay with simply supporting binary problems. If the multiclass formulation is standard (there are many multiclass ROC formulations), then supporting that makes sense too.
I can't claim that the multiclass formulation is standard, but I proposed the formulation above based on this. Please let me know what you think.
Do we have an opinion on this?
References
----------
.. [1] Brodersen, K.H.; Ong, C.S.; Stephan, K.E.; Buhmann, J.M. (2010).
This paper only treats the binary case and it's not clear to me that it does the same thing as this code. We need more references.
Oh wait, this PR is only for the binary case? Hm...
maybe call this metric binary_balanced_accuracy?
I think balanced accuracy is ordinarily binary. The attempt to extend it to the multiclass case without clear references is a major reason the contribution stalled.
Well, there are several extensions. I'd say we call this balanced_accuracy.
@jnothman so do you think it should be called balanced_accuracy and just implement the binary case? I'm also fine with that. I thought binary_balanced_accuracy might be more explicit but might also be redundant.
I think that would be a very good place to start. It's a useful and uncontroversial metric in the binary case, but lots of people won't use it if we don't provide it by that name.
Yes, I think this PR is good. I don't know why it's labelled WIP. LGTM. (I know we could support multilabel balanced accuracy: we can't do that simply with recall_score, but we could with roc_auc_score, as sketched below. Still, I think we should just get this into the library.)
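For reference, a minimal sketch of the roc_auc_score route (reusing the multilabel example from the docstring above): on a label-indicator matrix with hard 0/1 predictions, the per-label ROC AUC reduces to (TPR + TNR) / 2, i.e. the per-label balanced accuracy, which is why roc_auc_score can stand in where recall_score cannot.

import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([[1, 0], [1, 0], [0, 1]])
y_pred = np.array([[1, 1], [0, 1], [1, 1]])

# Per-label AUC of hard predictions equals per-label balanced accuracy.
print(roc_auc_score(y_true, y_pred, average=None))     # [0.25 0.5]
print(roc_auc_score(y_true, y_pred, average='macro'))  # 0.375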
@amueller, let's do this?
@jnothman yes. Sorry for the slow reply. Not handling the teacher life well.
* add function computing balanced accuracy
* documentation for the balanced_accuracy_score
* apply common tests to balanced_accuracy_score
* constrained to binary classification problems only
* add balanced_accuracy_score for CLF test
* add scorer for balanced_accuracy (see the sketch after this list)
* reorder the place of importing balanced_accuracy_score to be consistent with others
* eliminate an accidentally added non-ascii character
* remove balanced_accuracy_score from METRICS_WITH_LABELS
* eliminate all non-ascii characters in the doc of balanced_accuracy_score
* fix doctest for nonexistent scoring function
* fix documentation, clarify linkages to recall and auc
* FIX: added changes as per last review. See scikit-learn#6752, fixes scikit-learn#6747
* FIX: fix typo
* FIX: remove flake8 errors
* DOC: merge fixes
* DOC: remove unwanted files
* DOC: update what's new
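The "add scorer for balanced_accuracy" commit wires the metric into the scoring API. A minimal sketch of what that enables (assuming balanced_accuracy_score is importable from this branch; the dataset and estimator are illustrative):

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_val_score

# make_scorer wraps the metric so it can be passed anywhere a `scoring`
# argument is accepted, mirroring how scorer.py registers built-in scorers.
balanced_accuracy_scorer = make_scorer(balanced_accuracy_score)

X, y = make_classification(weights=[0.9, 0.1], random_state=0)
scores = cross_val_score(LogisticRegression(), X, y,
                         scoring=balanced_accuracy_scorer)
print(scores.mean())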
Reference Issue
Fixes #6747
What does this implement/fix? Explain your changes.
This is a continuation of #6752. @xyguo isn't available to work on this at the moment, so I have taken over from him; I want to thank him for his contribution, which greatly helped me understand the issue. I have made the changes suggested in the last review of the linked PR and resolved the merge conflicts.