diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index 35b82e16f0bca..6cea0b62bee5e 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -1664,6 +1664,67 @@ Here is a small example of usage of this function::
   * Tsoumakas, G., Katakis, I., & Vlahavas, I. (2010). Mining multi-label data.
     In Data mining and knowledge discovery handbook (pp. 667-685). Springer US.

+.. _ndcg:
+
+Normalized Discounted Cumulative Gain
+-------------------------------------
+
+Discounted Cumulative Gain (DCG) and Normalized Discounted Cumulative Gain
+(NDCG) are ranking metrics; they compare a predicted order to ground-truth
+scores, such as the relevance of answers to a query.
+
+From the Wikipedia page for Discounted Cumulative Gain:
+
+"Discounted cumulative gain (DCG) is a measure of ranking quality. In
+information retrieval, it is often used to measure effectiveness of web search
+engine algorithms or related applications. Using a graded relevance scale of
+documents in a search-engine result set, DCG measures the usefulness, or gain,
+of a document based on its position in the result list. The gain is accumulated
+from the top of the result list to the bottom, with the gain of each result
+discounted at lower ranks"
+
+DCG orders the true targets (e.g. relevance of query answers) in the predicted
+order, then multiplies them by a logarithmic decay and sums the result. The sum
+can be truncated after the first :math:`K` results, in which case we call it
+DCG@K.
+NDCG, or NDCG@K, is DCG divided by the DCG obtained by a perfect prediction, so
+that it is always between 0 and 1. Usually, NDCG is preferred to DCG.
+
+Compared with the ranking loss, NDCG can take into account relevance scores,
+rather than a ground-truth ranking. So if the ground-truth consists only of an
+ordering, the ranking loss should be preferred; if the ground-truth consists of
+actual usefulness scores (e.g. 0 for irrelevant, 1 for relevant, 2 for very
+relevant), NDCG can be used.
+
+For one sample, given the vector of continuous ground-truth values for each
+target :math:`y \in \mathbb{R}^{M}`, where :math:`M` is the number of outputs, and
+the prediction :math:`\hat{y}`, which induces the ranking function :math:`f`, the
+DCG score is
+
+.. math::
+  \sum_{r=1}^{\min(K, M)}\frac{y_{f(r)}}{\log(1 + r)}
+
+and the NDCG score is the DCG score divided by the DCG score obtained for
+:math:`y`.
+
+.. topic:: References:
+
+  * Wikipedia entry for Discounted Cumulative Gain:
+    https://en.wikipedia.org/wiki/Discounted_cumulative_gain
+
+  * Jarvelin, K., & Kekalainen, J. (2002).
+    Cumulated gain-based evaluation of IR techniques. ACM Transactions on
+    Information Systems (TOIS), 20(4), 422-446.
+
+  * Wang, Y., Wang, L., Li, Y., He, D., Chen, W., & Liu, T. Y. (2013, May).
+    A theoretical analysis of NDCG ranking measures. In Proceedings of the 26th
+    Annual Conference on Learning Theory (COLT 2013)
+
+  * McSherry, F., & Najork, M. (2008, March). Computing information retrieval
+    performance measures efficiently in the presence of tied scores. In
+    European conference on information retrieval (pp. 414-421). Springer,
+    Berlin, Heidelberg.
+
 .. _regression_metrics:

 Regression metrics
diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst
index 188e52a27d925..9d02080b0afe9 100644
--- a/doc/whats_new/v0.22.rst
+++ b/doc/whats_new/v0.22.rst
@@ -200,6 +200,11 @@ Changelog
 :mod:`sklearn.metrics`
 ......................
+- |Feature| New ranking metrics :func:`metrics.ndcg_score` and + :func:`metrics.dcg_score` have been added to compute Discounted Cumulative + Gain and Normalized Discounted Cumulative Gain. :pr:`9951` by :user:`Jérôme + Dockès `. + - |MajorFeature| :func:`metrics.plot_roc_curve` has been added to plot roc curves. This function introduces the visualization API described in the :ref:`User Guide `. :pr:`14357` by `Thomas Fan`_. diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index f176669b27fd8..d0b65ad1f4cfa 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -7,8 +7,10 @@ from .ranking import auc from .ranking import average_precision_score from .ranking import coverage_error +from .ranking import dcg_score from .ranking import label_ranking_average_precision_score from .ranking import label_ranking_loss +from .ranking import ndcg_score from .ranking import precision_recall_curve from .ranking import roc_auc_score from .ranking import roc_curve @@ -95,6 +97,7 @@ 'confusion_matrix', 'consensus_score', 'coverage_error', + 'dcg_score', 'davies_bouldin_score', 'euclidean_distances', 'explained_variance_score', @@ -123,6 +126,7 @@ 'median_absolute_error', 'multilabel_confusion_matrix', 'mutual_info_score', + 'ndcg_score', 'normalized_mutual_info_score', 'pairwise_distances', 'pairwise_distances_argmin', diff --git a/sklearn/metrics/ranking.py b/sklearn/metrics/ranking.py index cf0396a2bbfe6..5c88072b395d5 100644 --- a/sklearn/metrics/ranking.py +++ b/sklearn/metrics/ranking.py @@ -1012,3 +1012,381 @@ def label_ranking_loss(y_true, y_score, sample_weight=None): loss[np.logical_or(n_positives == 0, n_positives == n_labels)] = 0. return np.average(loss, weights=sample_weight) + + +def _dcg_sample_scores(y_true, y_score, k=None, + log_base=2, ignore_ties=False): + """Compute Discounted Cumulative Gain. + + Sum the true scores ranked in the order induced by the predicted scores, + after applying a logarithmic discount. + + This ranking metric yields a high value if true labels are ranked high by + ``y_score``. + + Parameters + ---------- + y_true : ndarray, shape (n_samples, n_labels) + True targets of multilabel classification, or true scores of entities + to be ranked. + + y_score : ndarray, shape (n_samples, n_labels) + Target scores, can either be probability estimates, confidence values, + or non-thresholded measure of decisions (as returned by + "decision_function" on some classifiers). + + k : int, optional (default=None) + Only consider the highest k scores in the ranking. If None, use all + outputs. + + log_base : float, optional (default=2) + Base of the logarithm used for the discount. A low value means a + sharper discount (top results are more important). + + ignore_ties : bool, optional (default=False) + Assume that there are no ties in y_score (which is likely to be the + case if y_score is continuous) for efficiency gains. + + Returns + ------- + discounted_cumulative_gain : ndarray, shape (n_samples,) + The DCG score for each sample. + + See also + -------- + ndcg_score : + The Discounted Cumulative Gain divided by the Ideal Discounted + Cumulative Gain (the DCG obtained for a perfect ranking), in order to + have a score between 0 and 1. 
+ + """ + discount = 1 / (np.log(np.arange(y_true.shape[1]) + 2) / np.log(log_base)) + if k is not None: + discount[k:] = 0 + if ignore_ties: + ranking = np.argsort(y_score)[:, ::-1] + ranked = y_true[np.arange(ranking.shape[0])[:, np.newaxis], ranking] + cumulative_gains = discount.dot(ranked.T) + else: + discount_cumsum = np.cumsum(discount) + cumulative_gains = [_tie_averaged_dcg(y_t, y_s, discount_cumsum) + for y_t, y_s in zip(y_true, y_score)] + cumulative_gains = np.asarray(cumulative_gains) + return cumulative_gains + + +def _tie_averaged_dcg(y_true, y_score, discount_cumsum): + """ + Compute DCG by averaging over possible permutations of ties. + + The gain (`y_true`) of an index falling inside a tied group (in the order + induced by `y_score`) is replaced by the average gain within this group. + The discounted gain for a tied group is then the average `y_true` within + this group times the sum of discounts of the corresponding ranks. + + This amounts to averaging scores for all possible orderings of the tied + groups. + + (note in the case of dcg@k the discount is 0 after index k) + + Parameters + ---------- + y_true : ndarray + The true relevance scores + + y_score : ndarray + Predicted scores + + discount_cumsum : ndarray + Precomputed cumulative sum of the discounts. + + Returns + ------- + The discounted cumulative gain. + + References + ---------- + McSherry, F., & Najork, M. (2008, March). Computing information retrieval + performance measures efficiently in the presence of tied scores. In + European conference on information retrieval (pp. 414-421). Springer, + Berlin, Heidelberg. + + """ + _, inv, counts = np.unique( + - y_score, return_inverse=True, return_counts=True) + ranked = np.zeros(len(counts)) + np.add.at(ranked, inv, y_true) + ranked /= counts + groups = np.cumsum(counts) - 1 + discount_sums = np.empty(len(counts)) + discount_sums[0] = discount_cumsum[groups[0]] + discount_sums[1:] = np.diff(discount_cumsum[groups]) + return (ranked * discount_sums).sum() + + +def _check_dcg_target_type(y_true): + y_type = type_of_target(y_true) + supported_fmt = ("multilabel-indicator", "continuous-multioutput", + "multiclass-multioutput") + if y_type not in supported_fmt: + raise ValueError( + "Only {} formats are supported. Got {} instead".format( + supported_fmt, y_type)) + + +def dcg_score(y_true, y_score, k=None, + log_base=2, sample_weight=None, ignore_ties=False): + """Compute Discounted Cumulative Gain. + + Sum the true scores ranked in the order induced by the predicted scores, + after applying a logarithmic discount. + + This ranking metric yields a high value if true labels are ranked high by + ``y_score``. + + Usually the Normalized Discounted Cumulative Gain (NDCG, computed by + ndcg_score) is preferred. + + Parameters + ---------- + y_true : ndarray, shape (n_samples, n_labels) + True targets of multilabel classification, or true scores of entities + to be ranked. + + y_score : ndarray, shape (n_samples, n_labels) + Target scores, can either be probability estimates, confidence values, + or non-thresholded measure of decisions (as returned by + "decision_function" on some classifiers). + + k : int, optional (default=None) + Only consider the highest k scores in the ranking. If None, use all + outputs. + + log_base : float, optional (default=2) + Base of the logarithm used for the discount. A low value means a + sharper discount (top results are more important). + + sample_weight : ndarray, shape (n_samples,), optional (default=None) + Sample weights. 
If None, all samples are given the same weight.
+
+    ignore_ties : bool, optional (default=False)
+        Assume that there are no ties in y_score (which is likely to be the
+        case if y_score is continuous) for efficiency gains.
+
+    Returns
+    -------
+    discounted_cumulative_gain : float
+        The averaged sample DCG scores.
+
+    See also
+    --------
+    ndcg_score :
+        The Discounted Cumulative Gain divided by the Ideal Discounted
+        Cumulative Gain (the DCG obtained for a perfect ranking), in order to
+        have a score between 0 and 1.
+
+    References
+    ----------
+    `Wikipedia entry for Discounted Cumulative Gain
+    <https://en.wikipedia.org/wiki/Discounted_cumulative_gain>`_
+
+    Jarvelin, K., & Kekalainen, J. (2002).
+    Cumulated gain-based evaluation of IR techniques. ACM Transactions on
+    Information Systems (TOIS), 20(4), 422-446.
+
+    Wang, Y., Wang, L., Li, Y., He, D., Chen, W., & Liu, T. Y. (2013, May).
+    A theoretical analysis of NDCG ranking measures. In Proceedings of the 26th
+    Annual Conference on Learning Theory (COLT 2013)
+
+    McSherry, F., & Najork, M. (2008, March). Computing information retrieval
+    performance measures efficiently in the presence of tied scores. In
+    European conference on information retrieval (pp. 414-421). Springer,
+    Berlin, Heidelberg.
+
+    Examples
+    --------
+    >>> from sklearn.metrics import dcg_score
+    >>> # we have ground-truth relevance of some answers to a query:
+    >>> true_relevance = np.asarray([[10, 0, 0, 1, 5]])
+    >>> # we predict scores for the answers
+    >>> scores = np.asarray([[.1, .2, .3, 4, 70]])
+    >>> dcg_score(true_relevance, scores) # doctest: +ELLIPSIS
+    9.49...
+    >>> # we can set k to truncate the sum; only top k answers contribute
+    >>> dcg_score(true_relevance, scores, k=2) # doctest: +ELLIPSIS
+    5.63...
+    >>> # now we have some ties in our prediction
+    >>> scores = np.asarray([[1, 0, 0, 0, 1]])
+    >>> # by default ties are averaged, so here we get the average true
+    >>> # relevance of our top predictions: (10 + 5) / 2 = 7.5
+    >>> dcg_score(true_relevance, scores, k=1) # doctest: +ELLIPSIS
+    7.5
+    >>> # we can choose to ignore ties for faster results, but only
+    >>> # if we know there aren't ties in our scores, otherwise we get
+    >>> # wrong results:
+    >>> dcg_score(true_relevance,
+    ...           scores, k=1, ignore_ties=True) # doctest: +ELLIPSIS
+    5.0
+
+    """
+    y_true = check_array(y_true, ensure_2d=False)
+    y_score = check_array(y_score, ensure_2d=False)
+    check_consistent_length(y_true, y_score, sample_weight)
+    _check_dcg_target_type(y_true)
+    return np.average(
+        _dcg_sample_scores(
+            y_true, y_score, k=k, log_base=log_base,
+            ignore_ties=ignore_ties),
+        weights=sample_weight)
+
+
+def _ndcg_sample_scores(y_true, y_score, k=None, ignore_ties=False):
+    """Compute Normalized Discounted Cumulative Gain.
+
+    Sum the true scores ranked in the order induced by the predicted scores,
+    after applying a logarithmic discount. Then divide by the best possible
+    score (Ideal DCG, obtained for a perfect ranking) to obtain a score between
+    0 and 1.
+
+    This ranking metric yields a high value if true labels are ranked high by
+    ``y_score``.
+
+    Parameters
+    ----------
+    y_true : ndarray, shape (n_samples, n_labels)
+        True targets of multilabel classification, or true scores of entities
+        to be ranked.
+
+    y_score : ndarray, shape (n_samples, n_labels)
+        Target scores, can either be probability estimates, confidence values,
+        or non-thresholded measure of decisions (as returned by
+        "decision_function" on some classifiers).
+
+    k : int, optional (default=None)
+        Only consider the highest k scores in the ranking.
If None, use all
+        outputs.
+
+    ignore_ties : bool, optional (default=False)
+        Assume that there are no ties in y_score (which is likely to be the
+        case if y_score is continuous) for efficiency gains.
+
+    Returns
+    -------
+    normalized_discounted_cumulative_gain : ndarray, shape (n_samples,)
+        The NDCG score for each sample (float in [0., 1.]).
+
+    See also
+    --------
+    dcg_score : Discounted Cumulative Gain (not normalized).
+
+    """
+    gain = _dcg_sample_scores(y_true, y_score, k, ignore_ties=ignore_ties)
+    # Here we use the order induced by y_true so we can ignore ties since
+    # the gain associated to tied indices is the same (permuting ties doesn't
+    # change the value of the re-ordered y_true)
+    normalizing_gain = _dcg_sample_scores(y_true, y_true, k, ignore_ties=True)
+    all_irrelevant = normalizing_gain == 0
+    gain[all_irrelevant] = 0
+    gain[~all_irrelevant] /= normalizing_gain[~all_irrelevant]
+    return gain
+
+
+def ndcg_score(y_true, y_score, k=None, sample_weight=None, ignore_ties=False):
+    """Compute Normalized Discounted Cumulative Gain.
+
+    Sum the true scores ranked in the order induced by the predicted scores,
+    after applying a logarithmic discount. Then divide by the best possible
+    score (Ideal DCG, obtained for a perfect ranking) to obtain a score between
+    0 and 1.
+
+    This ranking metric yields a high value if true labels are ranked high by
+    ``y_score``.
+
+    Parameters
+    ----------
+    y_true : ndarray, shape (n_samples, n_labels)
+        True targets of multilabel classification, or true scores of entities
+        to be ranked.
+
+    y_score : ndarray, shape (n_samples, n_labels)
+        Target scores, can either be probability estimates, confidence values,
+        or non-thresholded measure of decisions (as returned by
+        "decision_function" on some classifiers).
+
+    k : int, optional (default=None)
+        Only consider the highest k scores in the ranking. If None, use all
+        outputs.
+
+    sample_weight : ndarray, shape (n_samples,), optional (default=None)
+        Sample weights. If None, all samples are given the same weight.
+
+    ignore_ties : bool, optional (default=False)
+        Assume that there are no ties in y_score (which is likely to be the
+        case if y_score is continuous) for efficiency gains.
+
+    Returns
+    -------
+    normalized_discounted_cumulative_gain : float in [0., 1.]
+        The averaged NDCG scores for all samples.
+
+    See also
+    --------
+    dcg_score : Discounted Cumulative Gain (not normalized).
+
+    References
+    ----------
+    `Wikipedia entry for Discounted Cumulative Gain
+    <https://en.wikipedia.org/wiki/Discounted_cumulative_gain>`_
+
+    Jarvelin, K., & Kekalainen, J. (2002).
+    Cumulated gain-based evaluation of IR techniques. ACM Transactions on
+    Information Systems (TOIS), 20(4), 422-446.
+
+    Wang, Y., Wang, L., Li, Y., He, D., Chen, W., & Liu, T. Y. (2013, May).
+    A theoretical analysis of NDCG ranking measures. In Proceedings of the 26th
+    Annual Conference on Learning Theory (COLT 2013)
+
+    McSherry, F., & Najork, M. (2008, March). Computing information retrieval
+    performance measures efficiently in the presence of tied scores. In
+    European conference on information retrieval (pp. 414-421). Springer,
+    Berlin, Heidelberg.
+
+    Examples
+    --------
+    >>> from sklearn.metrics import ndcg_score
+    >>> # we have ground-truth relevance of some answers to a query:
+    >>> true_relevance = np.asarray([[10, 0, 0, 1, 5]])
+    >>> # we predict some scores (relevance) for the answers
+    >>> scores = np.asarray([[.1, .2, .3, 4, 70]])
+    >>> ndcg_score(true_relevance, scores) # doctest: +ELLIPSIS
+    0.69...
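+    >>> # (as a rough sketch of where 0.69 comes from, with the default
+    >>> #  log base 2: DCG = 5/log2(2) + 1/log2(3) + 10/log2(6) ~ 9.50, the
+    >>> #  ideal DCG = 10/log2(2) + 5/log2(3) + 1/log2(4) ~ 13.65, so ~ 0.69)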
+ >>> scores = np.asarray([[.05, 1.1, 1., .5, .0]]) + >>> ndcg_score(true_relevance, scores) # doctest: +ELLIPSIS + 0.49... + >>> # we can set k to truncate the sum; only top k answers contribute. + >>> ndcg_score(true_relevance, scores, k=4) # doctest: +ELLIPSIS + 0.35... + >>> # the normalization takes k into account so a perfect answer + >>> # would still get 1.0 + >>> ndcg_score(true_relevance, true_relevance, k=4) # doctest: +ELLIPSIS + 1.0 + >>> # now we have some ties in our prediction + >>> scores = np.asarray([[1, 0, 0, 0, 1]]) + >>> # by default ties are averaged, so here we get the average (normalized) + >>> # true relevance of our top predictions: (10 / 10 + 5 / 10) / 2 = .75 + >>> ndcg_score(true_relevance, scores, k=1) # doctest: +ELLIPSIS + 0.75 + >>> # we can choose to ignore ties for faster results, but only + >>> # if we know there aren't ties in our scores, otherwise we get + >>> # wrong results: + >>> ndcg_score(true_relevance, + ... scores, k=1, ignore_ties=True) # doctest: +ELLIPSIS + 0.5 + + """ + y_true = check_array(y_true, ensure_2d=False) + y_score = check_array(y_score, ensure_2d=False) + check_consistent_length(y_true, y_score, sample_weight) + _check_dcg_target_type(y_true) + gain = _ndcg_sample_scores(y_true, y_score, k=k, ignore_ties=ignore_ties) + return np.average(gain, weights=sample_weight) diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 8d62caa8a16c6..6459f93c68449 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -56,6 +56,8 @@ from sklearn.metrics import roc_auc_score from sklearn.metrics import roc_curve from sklearn.metrics import zero_one_loss +from sklearn.metrics import ndcg_score +from sklearn.metrics import dcg_score from sklearn.metrics.base import _average_binary_score @@ -237,6 +239,8 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): partial(average_precision_score, average="micro"), "label_ranking_average_precision_score": label_ranking_average_precision_score, + "ndcg_score": ndcg_score, + "dcg_score": dcg_score } ALL_METRICS = dict() @@ -266,6 +270,8 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "unnormalized_multilabel_confusion_matrix_sample", "label_ranking_loss", "label_ranking_average_precision_score", + "dcg_score", + "ndcg_score" } # Those metrics don't support multiclass inputs @@ -388,6 +394,10 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): "samples_average_precision_score", "micro_average_precision_score", "coverage_error", "label_ranking_loss", + + "ndcg_score", + "dcg_score", + "label_ranking_average_precision_score", } diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py index c202aef1added..03a28c958264a 100644 --- a/sklearn/metrics/tests/test_ranking.py +++ b/sklearn/metrics/tests/test_ranking.py @@ -27,6 +27,8 @@ from sklearn.metrics import label_ranking_loss from sklearn.metrics import roc_auc_score from sklearn.metrics import roc_curve +from sklearn.metrics.ranking import _ndcg_sample_scores, _dcg_sample_scores +from sklearn.metrics.ranking import ndcg_score, dcg_score from sklearn.exceptions import UndefinedMetricWarning @@ -1254,6 +1256,125 @@ def test_ranking_loss_ties_handling(): assert_almost_equal(label_ranking_loss([[1, 1, 0]], [[0.25, 0.5, 0.5]]), 1) +def test_dcg_score(): + _, y_true = make_multilabel_classification(random_state=0, n_classes=10) + y_score = - y_true + 1 + _test_dcg_score_for(y_true, y_score) + y_true, y_score = 
np.random.RandomState(0).random_sample((2, 100, 10)) + _test_dcg_score_for(y_true, y_score) + + +def _test_dcg_score_for(y_true, y_score): + discount = np.log2(np.arange(y_true.shape[1]) + 2) + ideal = _dcg_sample_scores(y_true, y_true) + score = _dcg_sample_scores(y_true, y_score) + assert (score <= ideal).all() + assert (_dcg_sample_scores(y_true, y_true, k=5) <= ideal).all() + assert ideal.shape == (y_true.shape[0], ) + assert score.shape == (y_true.shape[0], ) + assert ideal == pytest.approx( + (np.sort(y_true)[:, ::-1] / discount).sum(axis=1)) + + +def test_dcg_ties(): + y_true = np.asarray([np.arange(5)]) + y_score = np.zeros(y_true.shape) + dcg = _dcg_sample_scores(y_true, y_score) + dcg_ignore_ties = _dcg_sample_scores(y_true, y_score, ignore_ties=True) + discounts = 1 / np.log2(np.arange(2, 7)) + assert dcg == pytest.approx([discounts.sum() * y_true.mean()]) + assert dcg_ignore_ties == pytest.approx( + [(discounts * y_true[:, ::-1]).sum()]) + y_score[0, 3:] = 1 + dcg = _dcg_sample_scores(y_true, y_score) + dcg_ignore_ties = _dcg_sample_scores(y_true, y_score, ignore_ties=True) + assert dcg_ignore_ties == pytest.approx( + [(discounts * y_true[:, ::-1]).sum()]) + assert dcg == pytest.approx([ + discounts[:2].sum() * y_true[0, 3:].mean() + + discounts[2:].sum() * y_true[0, :3].mean() + ]) + + +def test_ndcg_ignore_ties_with_k(): + a = np.arange(12).reshape((2, 6)) + assert ndcg_score(a, a, k=3, ignore_ties=True) == pytest.approx( + ndcg_score(a, a, k=3, ignore_ties=True)) + + +def test_ndcg_invariant(): + y_true = np.arange(70).reshape(7, 10) + y_score = y_true + np.random.RandomState(0).uniform( + -.2, .2, size=y_true.shape) + ndcg = ndcg_score(y_true, y_score) + ndcg_no_ties = ndcg_score(y_true, y_score, ignore_ties=True) + assert ndcg == pytest.approx(ndcg_no_ties) + assert ndcg == pytest.approx(1.) + y_score += 1000 + assert ndcg_score(y_true, y_score) == pytest.approx(1.) + + +@pytest.mark.parametrize('ignore_ties', [True, False]) +def test_ndcg_toy_examples(ignore_ties): + y_true = 3 * np.eye(7)[:5] + y_score = np.tile(np.arange(6, -1, -1), (5, 1)) + y_score_noisy = y_score + np.random.RandomState(0).uniform( + -.2, .2, size=y_score.shape) + assert _dcg_sample_scores( + y_true, y_score, ignore_ties=ignore_ties) == pytest.approx( + 3 / np.log2(np.arange(2, 7))) + assert _dcg_sample_scores( + y_true, y_score_noisy, ignore_ties=ignore_ties) == pytest.approx( + 3 / np.log2(np.arange(2, 7))) + assert _ndcg_sample_scores( + y_true, y_score, ignore_ties=ignore_ties) == pytest.approx( + 1 / np.log2(np.arange(2, 7))) + assert _dcg_sample_scores(y_true, y_score, log_base=10, + ignore_ties=ignore_ties) == pytest.approx( + 3 / np.log10(np.arange(2, 7))) + assert ndcg_score( + y_true, y_score, ignore_ties=ignore_ties) == pytest.approx( + (1 / np.log2(np.arange(2, 7))).mean()) + assert dcg_score( + y_true, y_score, ignore_ties=ignore_ties) == pytest.approx( + (3 / np.log2(np.arange(2, 7))).mean()) + y_true = 3 * np.ones((5, 7)) + expected_dcg_score = (3 / np.log2(np.arange(2, 9))).sum() + assert _dcg_sample_scores( + y_true, y_score, ignore_ties=ignore_ties) == pytest.approx( + expected_dcg_score * np.ones(5)) + assert _ndcg_sample_scores( + y_true, y_score, ignore_ties=ignore_ties) == pytest.approx(np.ones(5)) + assert dcg_score( + y_true, y_score, ignore_ties=ignore_ties) == pytest.approx( + expected_dcg_score) + assert ndcg_score( + y_true, y_score, ignore_ties=ignore_ties) == pytest.approx(1.) 
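+
+
+def test_dcg_ties_brute_force():
+    # A hypothetical extra check, added here only as an illustration: the
+    # tie-averaged DCG should equal the plain DCG averaged over every
+    # possible ordering of the tied group (cf. McSherry & Najork, 2008).
+    import itertools
+    y_true = np.asarray([[3., 1., 2.]])
+    y_score = np.asarray([[1., 1., 0.]])  # the two top predictions are tied
+    discount = 1 / np.log2(np.arange(2, 5))
+    # brute-force average of the DCG over both orderings of the tied pair
+    expected = np.mean([
+        (np.asarray(list(perm) + [2.]) * discount).sum()
+        for perm in itertools.permutations([3., 1.])])
+    assert _dcg_sample_scores(y_true, y_score) == pytest.approx([expected])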
+ + +def test_ndcg_score(): + _, y_true = make_multilabel_classification(random_state=0, n_classes=10) + y_score = - y_true + 1 + _test_ndcg_score_for(y_true, y_score) + y_true, y_score = np.random.RandomState(0).random_sample((2, 100, 10)) + _test_ndcg_score_for(y_true, y_score) + + +def _test_ndcg_score_for(y_true, y_score): + ideal = _ndcg_sample_scores(y_true, y_true) + score = _ndcg_sample_scores(y_true, y_score) + assert (score <= ideal).all() + all_zero = (y_true == 0).all(axis=1) + assert ideal[~all_zero] == pytest.approx(np.ones((~all_zero).sum())) + assert ideal[all_zero] == pytest.approx(np.zeros(all_zero.sum())) + assert score[~all_zero] == pytest.approx( + _dcg_sample_scores(y_true, y_score)[~all_zero] / + _dcg_sample_scores(y_true, y_true)[~all_zero]) + assert score[all_zero] == pytest.approx(np.zeros(all_zero.sum())) + assert ideal.shape == (y_true.shape[0], ) + assert score.shape == (y_true.shape[0], ) + + def test_partial_roc_auc_score(): # Check `roc_auc_score` for max_fpr != `None` y_true = np.array([0, 0, 1, 1])