
[MRG+1] Adding support for balanced accuracy #8066

Merged: 21 commits into scikit-learn:master on Oct 17, 2017

Conversation

@dalmia (Contributor) commented Dec 16, 2016

Reference Issue

Fixes #6747

What does this implement/fix? Explain your changes.

This is a continuation of #6752. @xyguo isn't available to work on this at the moment, so I have taken over from him; I want to thank him for his contribution, which greatly helped me understand the issue. In this PR, I have made the changes suggested in the last review of the linked PR and resolved the merge conflicts.

@dalmia changed the title from "[WIP] Adding support for balanced accuracy" to "[MRG] Adding support for balanced accuracy" on Dec 21, 2016
@dalmia (Contributor, Author) commented Dec 22, 2016

As discussed in the earlier PR thread, extending this to the multilabel case involves a number of corner cases. So, do we want to implement this for multilabel?

@xyguo (Contributor) commented Dec 22, 2016

That might be difficult to do by simply wrapping recall_score or roc_auc_score. I thought I might have to rewrite it from scratch, which resulted in a function similar to precision_recall_fscore_support. But it is quite ugly and still only a draft...

In addition, the file test_common.py also needs a lot of modification, because the definition of balanced_accuracy is a bit "impure" and it fails several of the common tests: for example, some tests assume that a metric accepting multilabel input also accepts certain parameters, which balanced_accuracy doesn't.

@dalmia (Contributor, Author) commented Dec 22, 2016

Yes, I read the discussion on your thread. Since you have already tried implementing it, do you suggest we should try adding it?

@jnothman (Member) commented Dec 22, 2016 via email

@dalmia changed the title from "[MRG] Adding support for balanced accuracy" to "[WIP] Adding support for balanced accuracy" on Jan 7, 2017
dalmia added 3 commits January 7, 2017 13:25
Conflicts:
	doc/modules/model_evaluation.rst
	sklearn/metrics/scorer.py
@dalmia (Contributor, Author) commented Jan 7, 2017

@jnothman I've gone over the discussion of this enhancement on the issue thread and the previous pull requests, and here is a summary to wrap everything up:

  1. We can extend this to multiclass problems by computing the macro-average over the binarized problems (see the sketch below).
  2. The problem lies in extending it to the multilabel setting: roc_auc_score doesn't support a sparse matrix for y_pred in the multilabel case.

So please let me know whether you feel we can simply support binary problems, or whether it is critical to try something different for the multilabel case.
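
For concreteness, a minimal sketch (illustrative only, not part of this PR) of the macro-averaged-recall formulation from point 1, using only the existing recall_score; for a binary problem it reproduces the balanced accuracy, and the same call generalizes to the multiclass case:

# Minimal sketch: balanced accuracy as macro-averaged recall.
from sklearn.metrics import recall_score

y_true = [0, 1, 0, 0, 1, 0]
y_pred = [0, 1, 0, 0, 0, 1]

# Per-class recall: recall on class 1 is the TPR, recall on class 0 the TNR.
per_class_recall = recall_score(y_true, y_pred, average=None)
print(per_class_recall)                                  # [ 0.75  0.5 ]
print(per_class_recall.mean())                           # 0.625
print(recall_score(y_true, y_pred, average='macro'))     # 0.625, same value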

@xyguo (Contributor) commented Jan 7, 2017

I wrote the balanced accuracy function as follows (based on precision_recall_fscore_support). It does accept sparse input for multilabel cases, but it has to generate a dense matrix internally, because the balanced accuracy needs to compute accuracy on the negative class while the sparse matrix only stores the positive labels. There should be a space-efficient way to do this, since all the information about the negative class can be derived from the sparse matrix, but I don't know whether it can be implemented compactly (see the sketch after the code below for one possibility).

# Imports this draft would need as a standalone snippet; inside
# sklearn/metrics/classification.py (where it is meant to live) these
# names are already available at module level.
import warnings

import numpy as np
from scipy.sparse import csr_matrix

from sklearn.preprocessing import LabelEncoder
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.sparsefuncs import count_nonzero
from sklearn.metrics.classification import _check_targets, _prf_divide

bincount = np.bincount  # stand-in for the module's `bincount` helper


def balanced_accuracy_score(y_true, y_pred, labels=None,
                            average=None, balance=0.5):
    """Compute the balanced accuracy

    The balanced accuracy is used in binary classification problems to deal
    with imbalanced datasets. It can also be extended to multilabel problems.

    It is defined as the weighted arithmetic mean of sensitivity
    (true positive rate, TPR) and specificity (true negative rate, TNR), or
    the weighted average recall obtained on either class:

    balanced accuracy = balance * TPR + (1 - balance) * TNR

    It is also equal to the ROC AUC score for binary inputs when balance is 0.5.

    The best value is 1 and the worst value is 0.

    Note: this implementation is restricted to binary classification tasks
    or multilabel tasks in label indicator format.

    Read more in the :ref:`User Guide <balanced_accuracy_score>`.

    Parameters
    ----------
    y_true : 1d array-like, or label indicator array / sparse matrix
        Ground truth (correct) target values.

    y_pred : 1d array-like, or label indicator array / sparse matrix
        Estimated targets as returned by a classifier.

    labels : list, optional
        The set of labels to include for multilabel problems, and their
        order if ``average is None``. For multilabel targets,
        labels are column indices. By default, all labels in ``y_true`` and
        ``y_pred`` are used in sorted order.

    average : string, [None (default), 'micro', 'macro', 'samples']
        If ``None``, the scores for each class are returned. Otherwise,
        this determines the type of averaging performed on the data:

        ``'micro'``:
            Calculate metrics globally by considering each element of the label
            indicator matrix as a label.
        ``'macro'``:
            Calculate metrics for each label, and find their unweighted
            mean.  This does not take label imbalance into account.
        ``'samples'``:
            Calculate metrics for each instance, and find their average
            (only meaningful for multilabel classification).

    balance : float between 0 and 1, optional (default=0.5)
        Weight given to sensitivity (recall) relative to specificity in the
        final score.

    Returns
    -------
    balanced_accuracy : float, or array of float if ``average is None``
        The weighted average of sensitivity and specificity.

    See also
    --------
    recall_score, roc_auc_score

    References
    ----------
    .. [1] Brodersen, K.H.; Ong, C.S.; Stephan, K.E.; Buhmann, J.M. (2010).
           The balanced accuracy and its posterior distribution.
           Proceedings of the 20th International Conference on Pattern Recognition,
           3121-24.

    Examples
    --------
    >>> from sklearn.metrics import balanced_accuracy_score
    >>> y_true = [0, 1, 0, 0, 1, 0]
    >>> y_pred = [0, 1, 0, 0, 0, 1]
    >>> balanced_accuracy_score(y_true, y_pred)
    0.625
    >>> import numpy as np
    >>> y_true = np.array([[1, 0], [1, 0], [0, 1]])
    >>> y_pred = np.array([[1, 1], [0, 1], [1, 1]])
    >>> balanced_accuracy_score(y_true, y_pred, average=None)
    array([ 0.25,  0.5 ])

    """
    # TODO: handle sparse input in multilabel setting
    # TODO: ensure `sample_weight`'s shape is consistent with `y_true` and `y_pred`
    # TODO: handle situations where only one class is present in `y_true`
    # TODO: fully support a user-supplied `labels` argument
    average_options = (None, 'micro', 'macro', 'samples')
    if average not in average_options:
        raise ValueError('average has to be one of ' +
                         str(average_options))

    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
    present_labels = unique_labels(y_true, y_pred)

    if y_type == 'multiclass':
        raise ValueError('Balanced accuracy is only meaningful '
                         'for binary classification or '
                         'multilabel problems.')
    if labels is None:
        labels = present_labels
        n_labels = None
    else:
        n_labels = len(labels)
        labels = np.hstack([labels, np.setdiff1d(present_labels, labels,
                                                 assume_unique=True)])

    # Calculate tp_sum and true_sum for the positive and negative classes ###

    if y_type.startswith('multilabel'):

        sum_axis = 1 if average == 'samples' else 0

        # All labels are index integers for multilabel.
        # Select labels:
        if not np.all(labels == present_labels):
            if np.max(labels) > np.max(present_labels):
                raise ValueError('All labels must be in [0, n labels). '
                                 'Got %d > %d' %
                                 (np.max(labels), np.max(present_labels)))
            if np.min(labels) < 0:
                raise ValueError('All labels must be in [0, n labels). '
                                 'Got %d < 0' % np.min(labels))

            y_true = y_true[:, labels[:n_labels]]
            y_pred = y_pred[:, labels[:n_labels]]

        # indicator matrix for the negative (zero) class
        # TODO: Inefficient due to the generation of dense matrices.
        y_true_z = np.ones(y_true.shape)
        y_true_z[y_true.nonzero()] = 0
        y_true_z = csr_matrix(y_true_z)
        y_pred_z = np.ones(y_true.shape)
        y_pred_z[y_pred.nonzero()] = 0
        y_pred_z = csr_matrix(y_pred_z)

        # calculate counts for the positive class (sample_weight is still TODO)
        true_and_pred_p = y_true.multiply(y_pred)
        tp_sum_p = count_nonzero(true_and_pred_p, axis=sum_axis)
        true_sum_p = count_nonzero(y_true, axis=sum_axis)

        # calculate counts for the negative class
        true_and_pred_n = y_true_z.multiply(y_pred_z)
        tp_sum_n = count_nonzero(true_and_pred_n, axis=sum_axis)
        true_sum_n = count_nonzero(y_true_z, axis=sum_axis)

        # stack the counts: row 0 = positive class, row 1 = negative class
        tp_sum = np.vstack((tp_sum_p, tp_sum_n))
        true_sum = np.vstack((true_sum_p, true_sum_n))

        if average == 'micro':
            tp_sum = np.array([tp_sum.sum(axis=1)])
            true_sum = np.array([true_sum.sum(axis=1)])

    elif average == 'samples':
        raise ValueError("Sample-based balanced accuracy is "
                         "not meaningful outside multilabel "
                         "problems.")
    else:
        # binary classification case ##
        # warn only when the caller explicitly passed `labels`
        if n_labels is not None:
            warnings.warn("The `labels` argument will be ignored "
                          "in binary classification problems.")

        le = LabelEncoder()
        le.fit(labels)
        y_true = le.transform(y_true)
        y_pred = le.transform(y_pred)

        # labels are now either 0 or 1 -> use bincount
        tp = y_true == y_pred
        tp_bins = y_true[tp]
        tp_bins_weights = None

        if len(tp_bins):
            tp_sum = bincount(tp_bins, weights=tp_bins_weights,
                              minlength=2)
        else:
            # Pathological case
            true_sum = tp_sum = np.zeros(2)
        if len(y_true):
            true_sum = bincount(y_true, minlength=2)

    # Finally, we have all our sufficient statistics. Divide! #

    with np.errstate(divide='ignore', invalid='ignore'):
        # Divide, and on zero-division, set scores to 0 and warn:

        # Oddly, we may get an "invalid" rather than a "divide" error
        # here.
        recalls = _prf_divide(tp_sum, true_sum,
                              'recall', 'true', average, ('recall',))
        # NOTE: the `balance` weight is not applied here yet; this takes the
        # unweighted mean of the two per-class recalls (i.e. balance=0.5).
        bacs = np.average(recalls, axis=0)

    # Average the results
    if average is not None:
        bacs = np.average(bacs)

    return bacs
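
As an aside on the "space-efficient way" mentioned above: the negative-class counts can be derived from the positive-class counts and the matrix shape alone, without building dense matrices. A rough sketch of that idea (illustrative only, not this PR's code), reproducing the multilabel docstring example:

# Sketch: negative-class counts from a sparse label indicator matrix,
# without densifying it (entries are assumed to be 0/1).
import numpy as np
from scipy.sparse import csr_matrix

y_true = csr_matrix(np.array([[1, 0], [1, 0], [0, 1]]))
y_pred = csr_matrix(np.array([[1, 1], [0, 1], [1, 1]]))
n_samples = y_true.shape[0]

true_sum_p = np.asarray(y_true.sum(axis=0)).ravel()                 # positives per label
pred_sum_p = np.asarray(y_pred.sum(axis=0)).ravel()
tp_sum_p = np.asarray(y_true.multiply(y_pred).sum(axis=0)).ravel()  # true positives

# Negative-class counts follow from the positive ones and the shape:
true_sum_n = n_samples - true_sum_p                                  # negatives per label
tp_sum_n = n_samples - (true_sum_p + pred_sum_p - tp_sum_p)          # true negatives

recall_p = tp_sum_p / true_sum_p
recall_n = tp_sum_n / true_sum_n
print((recall_p + recall_n) / 2)    # per-label balanced accuracy: [ 0.25  0.5 ]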

@jnothman (Member) commented Jan 7, 2017

I am okay with simply supporting binary problems. If the multiclass formulation is standard (there are many multiclass ROC formulations), then supporting that makes sense too.

@dalmia (Contributor, Author) commented Jan 8, 2017

I can't claim that the multiclass formulation is standard, but I suggested the formulation above based on this. Please let me know what you think.

@dalmia (Contributor, Author) commented Jan 10, 2017

Do we have an opinion on this?


A reviewer (Member) commented on these docstring lines:

    References
    ----------
    .. [1] Brodersen, K.H.; Ong, C.S.; Stephan, K.E.; Buhmann, J.M. (2010).

This paper only treats the binary case and it's not clear to me that it does the same thing as this code. We need more references.

A follow-up comment (Member):

Oh wait, this PR is only for the binary case? hm...

@amueller (Member) commented:

maybe call this metric binary_balanced_accuracy?

@jnothman (Member) commented Jul 22, 2017 via email

@amueller (Member) commented:

Well there are several extensions. I'd say we call this binary_balanced_accuracy and just do that case for now.

@amueller (Member) commented Sep 6, 2017

@jnothman so do you think it should be called balanced_accuracy and just implement the binary case? I'm also fine with that. I thought binary_balanced_accuracy might be more explicit but might also be redundant.

@jnothman (Member) commented Sep 6, 2017 via email

@maskani-moh (Contributor) commented:

@amueller @jnothman

What's actually left to do in this PR?
It seems like you've agreed to leave the multiclass implementation for later.

Shall I then just change the function name from balanced_accuracy to binary_balanced_accuracy?

@jnothman (Member) commented Oct 9, 2017

Yes, I think this PR is good. I don't know why it's labelled WIP. LGTM.

(I know we could support multilabel balanced accuracy; we can't do that just by wrapping recall_score, but we could with roc_auc_score. Still, I think we should just get this into the library; a sketch of the roc_auc_score route follows.)
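
A rough illustration of the roc_auc_score point (a sketch, not what this PR implements): for hard 0/1 predictions the per-label AUC reduces to the per-label balanced accuracy, provided both classes are present in each label column.

# Sketch: multilabel "balanced accuracy" via roc_auc_score, valid only
# for hard 0/1 predictions with both classes present per label.
import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([[1, 0], [1, 0], [0, 1]])
y_pred = np.array([[1, 1], [0, 1], [1, 1]])

print(roc_auc_score(y_true, y_pred, average=None))       # [ 0.25  0.5 ] per label
print(roc_auc_score(y_true, y_pred, average='macro'))    # 0.375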

@jnothman changed the title from "[WIP] Adding support for balanced accuracy" to "[MRG+1] Adding support for balanced accuracy" on Oct 9, 2017
@jnothman (Member) commented Oct 9, 2017

@amueller, let's do this?

@amueller (Member) commented:

@jnothman yes. Sorry for the slow reply. Not handling the teacher life well.

@amueller merged commit 8daad06 into scikit-learn:master on Oct 17, 2017
maskani-moh pushed a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017
* add function computing balanced accuracy

* documentation for the balanced_accuracy_score

* apply common tests to balanced_accuracy_score

* constrained to binary classification problems only

* add balanced_accuracy_score for CLF test

* add scorer for balanced_accuracy

* reorder the place of importing balanced_accuracy_score to be consistent with others

* eliminate an accidentally added non-ascii character

* remove balanced_accuracy_score from METRICS_WITH_LABELS

* eliminate all non-ascii charaters in the doc of balanced_accuracy_score

* fix doctest for nonexistent scoring function

* fix documentation, clarify linkages to recall and auc

* FIX: added changes as per last review See scikit-learn#6752, fixes scikit-learn#6747

* FIX: fix typo

* FIX: remove flake8 errors

* DOC: merge fixes

* DOC: remove unwanted files

* DOC update what's new
jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017