diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst
index d9aa052d740db..c9c9eee31ae95 100644
--- a/doc/whats_new/v0.20.rst
+++ b/doc/whats_new/v0.20.rst
@@ -383,6 +383,10 @@ Metrics
   faster. This avoids some reported freezes and MemoryErrors.
   :issue:`11135` by `Joel Nothman`_.
 
+- :func:`metrics.average_precision_score` now supports binary ``y_true``
+  other than ``{0, 1}`` or ``{-1, 1}`` through the ``pos_label`` parameter.
+  :issue:`9980` by :user:`Hanmin Qin `.
+
 Linear, kernelized and related models
 
 - Deprecate ``random_state`` parameter in :class:`svm.OneClassSVM` as the
@@ -612,6 +616,10 @@ Metrics
   :func:`metrics.mutual_info_score`.
   :issue:`9772` by :user:`Kumar Ashutosh `.
 
+- Fixed a bug where :func:`metrics.average_precision_score` would sometimes
+  return ``nan`` when ``sample_weight`` contains 0.
+  :issue:`9980` by :user:`Hanmin Qin `.
+
 - Fixed a bug in :func:`metrics.fowlkes_mallows_score` to avoid integer
   overflow. Casted return value of `contingency_matrix` to `int64` and computed
   product of square roots rather than square root of product.
diff --git a/sklearn/metrics/ranking.py b/sklearn/metrics/ranking.py
index 5039c5f874a5e..fd6e28a20ae0c 100644
--- a/sklearn/metrics/ranking.py
+++ b/sklearn/metrics/ranking.py
@@ -20,6 +20,8 @@
 from __future__ import division
 
 import warnings
+from functools import partial
+
 import numpy as np
 from scipy.sparse import csr_matrix
 from scipy.stats import rankdata
@@ -125,7 +127,7 @@ def auc(x, y, reorder='deprecated'):
     return area
 
 
-def average_precision_score(y_true, y_score, average="macro",
+def average_precision_score(y_true, y_score, average="macro", pos_label=1,
                             sample_weight=None):
     """Compute average precision (AP) from prediction scores
 
@@ -150,7 +152,7 @@ def average_precision_score(y_true, y_score, average="macro",
     Parameters
     ----------
     y_true : array, shape = [n_samples] or [n_samples, n_classes]
-        True binary labels (either {0, 1} or {-1, 1}).
+        True binary labels or binary label indicators.
 
     y_score : array, shape = [n_samples] or [n_samples, n_classes]
         Target scores, can either be probability estimates of the positive
@@ -173,6 +175,10 @@ def average_precision_score(y_true, y_score, average="macro",
         ``'samples'``:
             Calculate metrics for each instance, and find their average.
 
+    pos_label : int or str (default=1)
+        The label of the positive class. Only applied to binary ``y_true``.
+        For multilabel-indicator ``y_true``, ``pos_label`` is fixed to 1.
+
     sample_weight : array-like of shape = [n_samples], optional
         Sample weights.
 
@@ -209,17 +215,23 @@ def average_precision_score(y_true, y_score, average="macro",
     are weighted by the change in recall since the last operating point.
     """
     def _binary_uninterpolated_average_precision(
-            y_true, y_score, sample_weight=None):
+            y_true, y_score, pos_label=1, sample_weight=None):
         precision, recall, _ = precision_recall_curve(
-            y_true, y_score, sample_weight=sample_weight)
+            y_true, y_score, pos_label=pos_label, sample_weight=sample_weight)
         # Return the step function integral
         # The following works because the last entry of precision is
         # guaranteed to be 1, as returned by precision_recall_curve
         return -np.sum(np.diff(recall) * np.array(precision)[:-1])
 
-    return _average_binary_score(_binary_uninterpolated_average_precision,
-                                 y_true, y_score, average,
-                                 sample_weight=sample_weight)
+    y_type = type_of_target(y_true)
+    if y_type == "multilabel-indicator" and pos_label != 1:
+        raise ValueError("Parameter pos_label is fixed to 1 for "
+                         "multilabel-indicator y_true. Do not set "
+                         "pos_label or set pos_label to 1.")
+    average_precision = partial(_binary_uninterpolated_average_precision,
+                                pos_label=pos_label)
+    return _average_binary_score(average_precision, y_true, y_score,
+                                 average, sample_weight=sample_weight)
 
 
 def roc_auc_score(y_true, y_score, average="macro", sample_weight=None,
@@ -501,6 +513,7 @@ def precision_recall_curve(y_true, probas_pred, pos_label=None,
                                              sample_weight=sample_weight)
 
     precision = tps / (tps + fps)
+    precision[np.isnan(precision)] = 0
     recall = tps / tps[-1]
 
     # stop when full recall attained
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index b858868a74545..51ebd00f00ccc 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -241,13 +241,6 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
     "samples_precision_score", "samples_recall_score", "coverage_error",
-
-    "average_precision_score",
-    "weighted_average_precision_score",
-    "micro_average_precision_score",
-    "macro_average_precision_score",
-    "samples_average_precision_score",
-
     "label_ranking_loss",
     "label_ranking_average_precision_score",
 }
@@ -264,6 +257,12 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
     "samples_roc_auc", "partial_roc_auc",
 
+    "average_precision_score",
+    "weighted_average_precision_score",
+    "micro_average_precision_score",
+    "macro_average_precision_score",
+    "samples_average_precision_score",
+
     # with default average='binary', multiclass is prohibited
     "precision_score", "recall_score",
@@ -299,6 +298,12 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
     "precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score",
 
+    "average_precision_score",
+    "weighted_average_precision_score",
+    "micro_average_precision_score",
+    "macro_average_precision_score",
+    "samples_average_precision_score",
+
     # pos_label support deprecated; to be removed in 0.18:
     "weighted_f0.5_score", "weighted_f1_score", "weighted_f2_score",
     "weighted_precision_score", "weighted_recall_score",
@@ -667,7 +672,7 @@ def test_thresholded_invariance_string_vs_numbers_labels(name):
             err_msg="{0} failed string vs number "
                     "invariance test".format(name))
 
-        measure_with_strobj = metric(y1_str.astype('O'), y2)
+        measure_with_strobj = metric_str(y1_str.astype('O'), y2)
         assert_array_equal(measure_with_number, measure_with_strobj,
                            err_msg="{0} failed string object vs number "
                                    "invariance test".format(name))
diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py
index 5e9a8a0c847ac..d7915eab60973 100644
--- a/sklearn/metrics/tests/test_ranking.py
+++ b/sklearn/metrics/tests/test_ranking.py
@@ -681,6 +681,18 @@ def test_average_precision_constant_values():
     assert_equal(average_precision_score(y_true, y_score), .25)
 
 
+def test_average_precision_score_pos_label_multilabel_indicator():
+    # Raise an error for multilabel-indicator y_true with
+    # pos_label other than 1
+    y_true = np.array([[1, 0], [0, 1], [0, 1], [1, 0]])
+    y_pred = np.array([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8]])
+    error_message = ("Parameter pos_label is fixed to 1 for multilabel"
+                     "-indicator y_true. Do not set pos_label or set "
+                     "pos_label to 1.")
+    assert_raise_message(ValueError, error_message, average_precision_score,
+                         y_true, y_pred, pos_label=0)
+
+
 def test_score_scale_invariance():
     # Test that average_precision_score and roc_auc_score are invariant by
     # the scaling or shifting of probabilities
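
Usage sketch (illustrative only, not part of the patch): how the new ``pos_label``
parameter is intended to be used once this change is applied. The labels and scores
below are made-up example data.

    import numpy as np
    from sklearn.metrics import average_precision_score

    # Binary y_true whose labels are not {0, 1} or {-1, 1}: pos_label names
    # the positive class explicitly, which this patch makes possible.
    y_true = np.array(["spam", "ham", "spam", "spam", "ham"])
    y_score = np.array([0.9, 0.2, 0.6, 0.4, 0.1])
    print(average_precision_score(y_true, y_score, pos_label="spam"))

    # Multilabel-indicator y_true: pos_label must be left at its default of 1;
    # any other value raises the ValueError added in this patch.
    y_true_ml = np.array([[1, 0], [0, 1], [0, 1], [1, 0]])
    y_score_ml = np.array([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8]])
    print(average_precision_score(y_true_ml, y_score_ml))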