[MRG+2] ENH&BUG Add pos_label parameter and fix a bug in average_precision_score #9980


Merged
11 commits merged on Jul 16, 2018
8 changes: 8 additions & 0 deletions doc/whats_new/v0.20.rst
@@ -383,6 +383,10 @@ Metrics
faster. This avoids some reported freezes and MemoryErrors.
:issue:`11135` by `Joel Nothman`_.

- :func:`metrics.average_precision_score` now supports binary ``y_true``
other than ``{0, 1}`` or ``{-1, 1}`` through the ``pos_label`` parameter.
:issue:`9980` by :user:`Hanmin Qin <qinhanmin2014>`.

Linear, kernelized and related models

- Deprecate ``random_state`` parameter in :class:`svm.OneClassSVM` as the
@@ -612,6 +616,10 @@ Metrics
:func:`metrics.mutual_info_score`.
:issue:`9772` by :user:`Kumar Ashutosh <thechargedneutron>`.

- Fixed a bug where :func:`metrics.average_precision_score` would sometimes return
``nan`` when ``sample_weight`` contains 0.
:issue:`9980` by :user:`Hanmin Qin <qinhanmin2014>`.

- Fixed a bug in :func:`metrics.fowlkes_mallows_score` to avoid integer
overflow. The return value of `contingency_matrix` is now cast to `int64` and the
product of square roots is computed rather than the square root of the product.
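
A quick illustration of the two average_precision_score entries above (a minimal sketch with made-up data; it assumes a scikit-learn build that includes this branch):

import numpy as np
from sklearn.metrics import average_precision_score

# Binary y_true with labels other than {0, 1} or {-1, 1}: name the positive class.
y_true = np.array(["spam", "ham", "spam", "spam", "ham"])
y_score = np.array([0.9, 0.2, 0.6, 0.4, 0.3])
print(average_precision_score(y_true, y_score, pos_label="spam"))

# A zero weight on the top-scoring sample used to propagate nan into the score;
# with the fix the undefined precision entries are treated as 0 instead.
y_true = np.array([1, 1, 0, 0])
y_score = np.array([0.9, 0.6, 0.4, 0.2])
weights = np.array([0.0, 1.0, 1.0, 1.0])
print(average_precision_score(y_true, y_score, sample_weight=weights))
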
27 changes: 20 additions & 7 deletions sklearn/metrics/ranking.py
@@ -20,6 +20,8 @@
from __future__ import division

import warnings
from functools import partial

import numpy as np
from scipy.sparse import csr_matrix
from scipy.stats import rankdata
@@ -125,7 +127,7 @@ def auc(x, y, reorder='deprecated'):
return area


def average_precision_score(y_true, y_score, average="macro",
def average_precision_score(y_true, y_score, average="macro", pos_label=1,
sample_weight=None):
"""Compute average precision (AP) from prediction scores

@@ -150,7 +152,7 @@ def average_precision_score(y_true, y_score, average="macro",
Parameters
----------
y_true : array, shape = [n_samples] or [n_samples, n_classes]
True binary labels (either {0, 1} or {-1, 1}).
True binary labels or binary label indicators.
Review comment (Member): label indicators -> multilabel indicators ??


y_score : array, shape = [n_samples] or [n_samples, n_classes]
Target scores, can either be probability estimates of the positive
@@ -173,6 +175,10 @@ def average_precision_score(y_true, y_score, average="macro",
``'samples'``:
Calculate metrics for each instance, and find their average.

pos_label : int or str (default=1)
The label of the positive class. Only applied to binary ``y_true``.
For multilabel-indicator ``y_true``, ``pos_label`` is fixed to 1.

sample_weight : array-like of shape = [n_samples], optional
Sample weights.

@@ -209,17 +215,23 @@ def average_precision_score(y_true, y_score, average="macro",
are weighted by the change in recall since the last operating point.
"""
def _binary_uninterpolated_average_precision(
y_true, y_score, sample_weight=None):
y_true, y_score, pos_label=1, sample_weight=None):
precision, recall, _ = precision_recall_curve(
y_true, y_score, sample_weight=sample_weight)
y_true, y_score, pos_label=pos_label, sample_weight=sample_weight)
# Return the step function integral
# The following works because the last entry of precision is
# guaranteed to be 1, as returned by precision_recall_curve
return -np.sum(np.diff(recall) * np.array(precision)[:-1])

return _average_binary_score(_binary_uninterpolated_average_precision,
y_true, y_score, average,
sample_weight=sample_weight)
y_type = type_of_target(y_true)
if y_type == "multilabel-indicator" and pos_label != 1:
raise ValueError("Parameter pos_label is fixed to 1 for "
"multilabel-indicator y_true. Do not set "
"pos_label or set pos_label to 1.")
average_precision = partial(_binary_uninterpolated_average_precision,
pos_label=pos_label)
return _average_binary_score(average_precision, y_true, y_score,
average, sample_weight=sample_weight)
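
Quick sanity check on the step-function sum above (an illustrative sketch with toy data, not part of this diff): the negated sum equals the weighted-mean-of-precisions definition referenced in the docstring, AP = sum_n (R_n - R_{n-1}) * P_n, once the curve is read in order of increasing recall.

import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 1, 1, 0, 1])
y_score = np.array([0.1, 0.8, 0.55, 0.45, 0.7])
# precision_recall_curve returns recall decreasing to 0 with a final precision of 1.
precision, recall, _ = precision_recall_curve(y_true, y_score)

step_integral = -np.sum(np.diff(recall) * precision[:-1])

# The same quantity written over increasing recall: AP = sum_n (R_n - R_{n-1}) * P_n.
p, r = precision[::-1], recall[::-1]
textbook_ap = np.sum(np.diff(r) * p[1:])

assert np.isclose(step_integral, textbook_ap)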


def roc_auc_score(y_true, y_score, average="macro", sample_weight=None,
@@ -501,6 +513,7 @@ def precision_recall_curve(y_true, probas_pred, pos_label=None,
sample_weight=sample_weight)

precision = tps / (tps + fps)
precision[np.isnan(precision)] = 0
recall = tps / tps[-1]

# stop when full recall attained
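
The one-line precision[np.isnan(precision)] = 0 fix is easiest to see at the curve level: when every sample above some threshold has zero weight, tps + fps is 0 there and the old code produced nan. A minimal sketch (made-up data, assumes this branch):

import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([1, 1, 0, 0])
y_score = np.array([0.9, 0.6, 0.4, 0.2])
weights = np.array([0.0, 1.0, 1.0, 1.0])  # top-scoring sample carries no weight
precision, recall, thresholds = precision_recall_curve(
    y_true, y_score, sample_weight=weights)
print(precision)  # the undefined point is now reported as 0.0, not nan
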
21 changes: 13 additions & 8 deletions sklearn/metrics/tests/test_common.py
@@ -241,13 +241,6 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
"samples_precision_score",
"samples_recall_score",
"coverage_error",

"average_precision_score",
"weighted_average_precision_score",
"micro_average_precision_score",
"macro_average_precision_score",
"samples_average_precision_score",

"label_ranking_loss",
"label_ranking_average_precision_score",
}
@@ -264,6 +257,12 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
"samples_roc_auc",
"partial_roc_auc",

"average_precision_score",
"weighted_average_precision_score",
"micro_average_precision_score",
"macro_average_precision_score",
"samples_average_precision_score",

# with default average='binary', multiclass is prohibited
"precision_score",
"recall_score",
@@ -299,6 +298,12 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):

"precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score",

"average_precision_score",
"weighted_average_precision_score",
"micro_average_precision_score",
"macro_average_precision_score",
"samples_average_precision_score",

# pos_label support deprecated; to be removed in 0.18:
"weighted_f0.5_score", "weighted_f1_score", "weighted_f2_score",
"weighted_precision_score", "weighted_recall_score",
@@ -667,7 +672,7 @@ def test_thresholded_invariance_string_vs_numbers_labels(name):
err_msg="{0} failed string vs number "
"invariance test".format(name))

measure_with_strobj = metric(y1_str.astype('O'), y2)
measure_with_strobj = metric_str(y1_str.astype('O'), y2)
assert_array_equal(measure_with_number, measure_with_strobj,
err_msg="{0} failed string object vs number "
"invariance test".format(name))
12 changes: 12 additions & 0 deletions sklearn/metrics/tests/test_ranking.py
@@ -681,6 +681,18 @@ def test_average_precision_constant_values():
assert_equal(average_precision_score(y_true, y_score), .25)


def test_average_precision_score_pos_label_multilabel_indicator():
# Raise an error for multilabel-indicator y_true with
# pos_label other than 1
y_true = np.array([[1, 0], [0, 1], [0, 1], [1, 0]])
y_pred = np.array([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8]])
error_message = ("Parameter pos_label is fixed to 1 for multilabel"
"-indicator y_true. Do not set pos_label or set "
"pos_label to 1.")
assert_raise_message(ValueError, error_message, average_precision_score,
y_true, y_pred, pos_label=0)
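
For illustration only, a complementary check one could add (a hypothetical test, not part of this PR): pos_label on relabelled binary data should agree with the plain {0, 1} encoding.

def test_average_precision_score_pos_label_binary_sketch():
    # Hypothetical check: string labels with pos_label match the {0, 1} baseline.
    y_true = np.array([0, 1, 1, 0, 1])
    y_true_str = np.array(["neg", "pos", "pos", "neg", "pos"])
    y_score = np.array([0.2, 0.8, 0.5, 0.4, 0.7])
    assert_almost_equal(
        average_precision_score(y_true, y_score),
        average_precision_score(y_true_str, y_score, pos_label="pos"))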


def test_score_scale_invariance():
# Test that average_precision_score and roc_auc_score are invariant by
# the scaling or shifting of probabilities