From 6e5bacb6cb4f835539e22b5c14ec4bf3eee52051 Mon Sep 17 00:00:00 2001
From: Nawar Halabi
Date: Mon, 11 Oct 2021 23:18:04 +0200
Subject: [PATCH 01/31] added unit tests

---
 sklearn/metrics/tests/test_ranking.py | 82 +++++++++++++++++++++++++++
 1 file changed, 82 insertions(+)

diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py
index 01de37b189733..a45a0c2e202f6 100644
--- a/sklearn/metrics/tests/test_ranking.py
+++ b/sklearn/metrics/tests/test_ranking.py
@@ -28,6 +28,7 @@
 from sklearn.metrics import label_ranking_loss
 from sklearn.metrics import roc_auc_score
 from sklearn.metrics import roc_curve
+from sklearn.metrics import lift_curve
 from sklearn.metrics._ranking import _ndcg_sample_scores, _dcg_sample_scores
 from sklearn.metrics import ndcg_score, dcg_score
 from sklearn.metrics import top_k_accuracy_score
@@ -44,6 +45,7 @@
     det_curve,
     precision_recall_curve,
     roc_curve,
+    lift_curve,
 ]


@@ -437,6 +439,86 @@ def test_roc_curve_fpr_tpr_increasing():
     assert (np.diff(tpr) < 0).sum() == 0


+def test_lift_curve():
+    # Test lift curve function
+    y_true, _, y_score = make_prediction(binary=True)
+    lift, percentages, thresholds = lift_curve(y_true, y_score)
+    assert_almost_equal(np.max(percentages), 100)
+    assert_almost_equal(np.min(percentages), 0)
+    assert_almost_equal(np.min(lift), 1)
+    assert_almost_equal(percentages[0], 0)
+    assert_almost_equal(thresholds[0], np.max(y_score) + 1)
+    assert lift.shape == percentages.shape
+    assert percentages.shape == thresholds.shape
+    assert np.all(np.diff(percentages, 1) >= 0)
+    assert np.all(np.diff(thresholds, 1) <= 0)
+
+
+def test_lift_curve_toydata():
+    # Binary classification
+    y_true = [0, 1]
+    y_score = [0, 1]
+    lift, percentages, thresholds = lift_curve(y_true, y_score)
+    assert_array_almost_equal(lift, [2, 2, 1])
+    assert_array_almost_equal(percentages, [0, 50, 100])
+    assert_array_almost_equal(thresholds, [2, 1, 0])
+
+    y_true = [0, 1]
+    y_score = [1, 0]
+    lift, percentages, thresholds = lift_curve(y_true, y_score)
+    assert_array_almost_equal(lift, [0, 0, 1])
+    assert_array_almost_equal(percentages, [0, 50, 100])
+    assert_array_almost_equal(thresholds, [2, 1, 0])
+
+    y_true = [1, 0]
+    y_score = [1, 1]
+    lift, percentages, thresholds = lift_curve(y_true, y_score)
+    assert_array_almost_equal(lift, [1, 1])
+    assert_array_almost_equal(percentages, [0, 100])
+    assert_array_almost_equal(thresholds, [2, 1])
+
+    y_true = [1, 0]
+    y_score = [1, 0]
+    lift, percentages, thresholds = lift_curve(y_true, y_score)
+    assert_array_almost_equal(lift, [2, 2, 1])
+    assert_array_almost_equal(percentages, [0, 50, 100])
+    assert_array_almost_equal(thresholds, [2, 1, 0])
+
+    y_true = [1, 0]
+    y_score = [0.5, 0.5]
+    lift, percentages, thresholds = lift_curve(y_true, y_score)
+    assert_array_almost_equal(lift, [1, 1])
+    assert_array_almost_equal(percentages, [0, 100])
+    assert_array_almost_equal(thresholds, [1.5, 0.5])
+
+    y_true = [0, 0]
+    y_score = [0.25, 0.75]
+    # assert UndefinedMetricWarning because of no positive sample in y_true
+    expected_message = (
+        "No positive samples in y_true, lift is meaningless"
+    )
+    with pytest.warns(UndefinedMetricWarning, match=expected_message):
+        _, _, _ = lift_curve(y_true, y_score)
+
+
+def test_lift_curve_sample_weight():
+    # With weights
+    y_true = [0, 1, 0, 1, 1, 0, 1, 1]
+    y_score = [0, 1, 0.5, 0.6, 0.4, 0.1, 0.7, 0.4]
+    weights = [1, 1, 2, 2, 2, 0, 0, 4]
+    lift1, percentages1, thresholds1 = lift_curve(y_true, y_score,
+                                                  sample_weight=weights)
+
+    # With repeats
+    y_true = 
[0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] + y_score = [0, 1, 0.5, 0.5, 0.6, 0.6, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4] + lift2, percentages2, thresholds2 = lift_curve(y_true, y_score) + + assert_array_almost_equal(lift1, lift2) + assert_array_almost_equal(percentages1, percentages2) + assert_array_almost_equal(thresholds1, thresholds2) + + def test_auc(): # Test Area Under Curve (AUC) computation x = [0, 1] From cd4853d9f5bd6d16f6a1d42c2dcc50e72ec62e0b Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Mon, 11 Oct 2021 23:19:13 +0200 Subject: [PATCH 02/31] added first version of lift curve metric function. --- sklearn/metrics/__init__.py | 2 + sklearn/metrics/_ranking.py | 147 ++++++++++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+) diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index e4339229c5b64..8c215fae5b6e7 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -15,6 +15,7 @@ from ._ranking import precision_recall_curve from ._ranking import roc_auc_score from ._ranking import roc_curve +from ._ranking import lift_curve from ._ranking import top_k_accuracy_score from ._classification import accuracy_score @@ -131,6 +132,7 @@ "jaccard_score", "label_ranking_average_precision_score", "label_ranking_loss", + "lift_curve", "log_loss", "make_scorer", "nan_euclidean_distances", diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index badfff094f7fa..240693c34b80a 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -1007,6 +1007,153 @@ def roc_curve( return fpr, tpr, thresholds +def lift_curve( + y_true, y_score, *, pos_label=None, sample_weight=None): + """Compute lift for each percent coverage of the sample. + + Lift is the ratio of positive treatment responses to treatments on a + specific subset of a population, relative to the ratio of positive + treatment responses to treatments on a random subset of the population. + + This metric is only for binary classification. + + Parameters + ---------- + y_true : ndarray of shape (n_samples,) + True binary labels. If labels are not either {-1, 1} or {0, 1}, then + pos_label should be explicitly given. + + y_score : ndarray of shape (n_samples,) + Target scores, can either be probability estimates of the positive + class, confidence values, or non-thresholded measure of decisions + (as returned by "decision_function" on some classifiers). + + pos_label : int or str, default=None + The label of the positive class. + When ``pos_label=None``, if `y_true` is in {-1, 1} or {0, 1}, + ``pos_label`` is set to 1, otherwise an error will be raised. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. + + Returns + ------- + lift : ndarray of shape (>2,) + Decreasing lift values such that element i is the lift a treatment on + a subset with score >= `thresholds[i]`. + + percentages : ndarray of shape = (n_thresholds,) + Increasing percentages on population included in the treatment. + `percentages[0]` represents no instances being predicted + and is arbitrarily set to `0`. + + thresholds : ndarray of shape = (n_thresholds,) + Decreasing thresholds on the decision function used to compute + lift. `thresholds[0]` represents no instances being predicted + and is arbitrarily set to `max(y_score) + 1`. + + References + ---------- + .. [1] `Wikipedia entry for the lift metric + `_ + + .. 
[2] `IBM's SPSS page on gain and lift + `_ + + Examples + -------- + >>> import numpy as np + >>> from sklearn import metrics + >>> y = np.array([1, 2, 1, 2]) + >>> scores = np.array([0.1, 0.4, 0.3, 0.8]) + >>> lift, percentages, threshs = metrics.lift_curve(y, scores, pos_label=2) + >>> lift + array([0. , 2. , 2. , 1.33333333, 1.]) + >>> percentages + array([0., 25., 50., 75., 100.]) + >>> threshs + array([1., 0.8, 0.4, 0.3, 0.1]) + """ + + fps, tps, thresholds = _binary_clf_curve( + y_true, y_score, + pos_label=pos_label, + sample_weight=sample_weight + ) + + # False negatives + fns = tps[-1] - tps + # Sample counts + n_samples = fps[-1] + tps[-1] + + # Lift & percentages + lift = n_samples * tps / ((fps + tps) * (fns + tps)) + percentages = 100 * (fps + tps) / n_samples + + # Insert a 0 percentage point + lift = np.insert(lift, 0, [lift[0]]) + percentages = np.insert(percentages, 0, [0]) + thresholds = np.insert(thresholds, 0, [thresholds.max() + 1]) + + return lift, percentages, thresholds + # Check to make sure y_true is valid + y_type = type_of_target(y_true) + if not (y_type == "binary" or (y_type == "multiclass" and pos_label is not None)): + raise ValueError("{0} format is not supported".format(y_type)) + + # Make sure the arrays are of the same length + check_consistent_length(y_true, y_score) + y_true = column_or_1d(y_true) + y_score = column_or_1d(y_score) + assert_all_finite(y_true) + assert_all_finite(y_score) + + pos_label = _check_pos_label_consistency(pos_label, y_true) + + # Filter out zero-weighted samples, as they should not impact the result + if sample_weight is not None: + sample_weight = column_or_1d(sample_weight) + sample_weight = _check_sample_weight(sample_weight, y_true) + nonzero_weight_mask = sample_weight != 0 + y_true = y_true[nonzero_weight_mask] + y_score = y_score[nonzero_weight_mask] + sample_weight = sample_weight[nonzero_weight_mask] + + # make y_true a boolean vector + y_true = y_true == pos_label + + # sort scores and corresponding truth values + desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1] + y_score = y_score[desc_score_indices] + y_true = y_true[desc_score_indices] + if sample_weight is not None: + weight = sample_weight[desc_score_indices] + weight = weight / weight.sum() + else: + weight = 1.0 + + # Thresholds + thresholds = np.insert(y_score, 0, [y_score.max() + 1]) + + # Percentages + percentages = np.arange(start=1, stop=len(y_true) + 1) + percentages = 100 * percentages * weight / float(len(y_true)) + percentages = np.insert(percentages, 0, [0]) + + # Lift + gain = stable_cumsum(y_true * weight) / np.arange(1, len(y_true) + 1) + lift = gain / (y_true * weight).mean() + lift = np.insert(lift, 0, [lift[0]]) + + if np.any(lift <= 0): + warnings.warn( + "No positive samples in y_true, lift is meaningless", + UndefinedMetricWarning, + ) + + return lift, percentages, thresholds + + def label_ranking_average_precision_score(y_true, y_score, *, sample_weight=None): """Compute ranking-based average precision. 
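A quick way to sanity-check the vectorized computation above: at each real
threshold, the lift should equal the precision of the subset
``score >= threshold`` divided by the base positive rate of the whole sample,
matching ``lift = n_samples * tps / ((fps + tps) * (fns + tps))``. A minimal
sketch, assuming a scikit-learn build with this patch applied (so that
``sklearn.metrics.lift_curve`` is importable)::

    import numpy as np
    from sklearn.metrics import lift_curve

    y_true = np.array([1, 2, 1, 2])
    y_score = np.array([0.1, 0.4, 0.3, 0.8])
    lift, percentages, thresholds = lift_curve(y_true, y_score, pos_label=2)

    base_rate = np.mean(y_true == 2)
    # Skip the first point, which is the artificial 0% coverage entry.
    for t, value in zip(thresholds[1:], lift[1:]):
        subset = y_score >= t
        precision = np.mean(y_true[subset] == 2)
        assert np.isclose(precision / base_rate, value)
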
From 7e1cdcca0f44726f11b7ff001495e5b8004a5b08 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Mon, 11 Oct 2021 23:20:09 +0200 Subject: [PATCH 03/31] removed old code after return statement of lift curve function --- sklearn/metrics/_ranking.py | 56 ------------------------------------- 1 file changed, 56 deletions(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 240693c34b80a..efda40b035345 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -1096,62 +1096,6 @@ def lift_curve( thresholds = np.insert(thresholds, 0, [thresholds.max() + 1]) return lift, percentages, thresholds - # Check to make sure y_true is valid - y_type = type_of_target(y_true) - if not (y_type == "binary" or (y_type == "multiclass" and pos_label is not None)): - raise ValueError("{0} format is not supported".format(y_type)) - - # Make sure the arrays are of the same length - check_consistent_length(y_true, y_score) - y_true = column_or_1d(y_true) - y_score = column_or_1d(y_score) - assert_all_finite(y_true) - assert_all_finite(y_score) - - pos_label = _check_pos_label_consistency(pos_label, y_true) - - # Filter out zero-weighted samples, as they should not impact the result - if sample_weight is not None: - sample_weight = column_or_1d(sample_weight) - sample_weight = _check_sample_weight(sample_weight, y_true) - nonzero_weight_mask = sample_weight != 0 - y_true = y_true[nonzero_weight_mask] - y_score = y_score[nonzero_weight_mask] - sample_weight = sample_weight[nonzero_weight_mask] - - # make y_true a boolean vector - y_true = y_true == pos_label - - # sort scores and corresponding truth values - desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1] - y_score = y_score[desc_score_indices] - y_true = y_true[desc_score_indices] - if sample_weight is not None: - weight = sample_weight[desc_score_indices] - weight = weight / weight.sum() - else: - weight = 1.0 - - # Thresholds - thresholds = np.insert(y_score, 0, [y_score.max() + 1]) - - # Percentages - percentages = np.arange(start=1, stop=len(y_true) + 1) - percentages = 100 * percentages * weight / float(len(y_true)) - percentages = np.insert(percentages, 0, [0]) - - # Lift - gain = stable_cumsum(y_true * weight) / np.arange(1, len(y_true) + 1) - lift = gain / (y_true * weight).mean() - lift = np.insert(lift, 0, [lift[0]]) - - if np.any(lift <= 0): - warnings.warn( - "No positive samples in y_true, lift is meaningless", - UndefinedMetricWarning, - ) - - return lift, percentages, thresholds def label_ranking_average_precision_score(y_true, y_score, *, sample_weight=None): From e0061e94cdc482de91865cb1bb0441e385172a4d Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Mon, 11 Oct 2021 23:25:56 +0200 Subject: [PATCH 04/31] fixed outdated example in docs of lift curve --- sklearn/metrics/_ranking.py | 4 ++-- sklearn/metrics/tests/test_ranking.py | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index efda40b035345..ee8b91f160b66 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -1068,11 +1068,11 @@ def lift_curve( >>> scores = np.array([0.1, 0.4, 0.3, 0.8]) >>> lift, percentages, threshs = metrics.lift_curve(y, scores, pos_label=2) >>> lift - array([0. , 2. , 2. , 1.33333333, 1.]) + array([2. , 2. , 2. 
, 1.33333333, 1.])
     >>> percentages
     array([0., 25., 50., 75., 100.])
     >>> threshs
-    array([1., 0.8, 0.4, 0.3, 0.1])
+    array([1.8, 0.8, 0.4, 0.3, 0.1])
     """

diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py
index a45a0c2e202f6..5024b23580431 100644
--- a/sklearn/metrics/tests/test_ranking.py
+++ b/sklearn/metrics/tests/test_ranking.py
@@ -456,6 +456,13 @@
 def test_lift_curve_toydata():
     # Binary classification
+    y_true = np.array([1, 2, 1, 2])
+    y_score = np.array([0.1, 0.4, 0.3, 0.8])
+    lift, percentages, thresholds = lift_curve(y_true, y_score, pos_label=2)
+    assert_array_almost_equal(lift, [2, 2, 2, 1.333333, 1])
+    assert_array_almost_equal(percentages, [0, 25, 50, 75, 100])
+    assert_array_almost_equal(thresholds, [1.8, 0.8, 0.4, 0.3, 0.1])
+
     y_true = [0, 1]
     y_score = [0, 1]
     lift, percentages, thresholds = lift_curve(y_true, y_score)

From b04181d463b6df3f240a1580c9738785e7850668 Mon Sep 17 00:00:00 2001
From: Nawar Halabi
Date: Mon, 11 Oct 2021 23:46:26 +0200
Subject: [PATCH 05/31] added testing pos_label of lift curve

---
 sklearn/metrics/tests/test_ranking.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py
index 5024b23580431..534224f3daa49 100644
--- a/sklearn/metrics/tests/test_ranking.py
+++ b/sklearn/metrics/tests/test_ranking.py
@@ -454,6 +454,24 @@ def test_lift_curve():
     assert np.all(np.diff(thresholds, 1) <= 0)


+def test_lift_curve_pos_label():
+    # Binary classification
+    y_true = np.array([1, 2, 1, 2])
+    y_score = np.array([0.1, 0.4, 0.3, 0.8])
+    lift, percentages, thresholds = lift_curve(y_true, y_score, pos_label=2)
+    assert_array_almost_equal(lift, [2, 2, 2, 1.333333, 1])
+    assert_array_almost_equal(percentages, [0, 25, 50, 75, 100])
+    assert_array_almost_equal(thresholds, [1.8, 0.8, 0.4, 0.3, 0.1])
+
+    # Binary classification
+    y_true = np.array([1, 2, 1, 2])
+    y_score = np.array([0.1, 0.4, 0.3, 0.8])
+    lift, percentages, thresholds = lift_curve(y_true, y_score, pos_label=1)
+    assert_array_almost_equal(lift, [0, 0, 0, 0.666666, 1])
+    assert_array_almost_equal(percentages, [0, 25, 50, 75, 100])
+    assert_array_almost_equal(thresholds, [1.8, 0.8, 0.4, 0.3, 0.1])
+
+
 def test_lift_curve_toydata():
     # Binary classification
     y_true = np.array([1, 2, 1, 2])

From 0e171aea2635175a5fd68191417782e396abb58e Mon Sep 17 00:00:00 2001
From: Nawar Halabi
Date: Tue, 12 Oct 2021 09:38:08 +0200
Subject: [PATCH 06/31] added see also section of the lift curve function

---
 sklearn/metrics/_ranking.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index ee8b91f160b66..5a6f1b841069c 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -1007,8 +1007,7 @@ def roc_curve(
     return fpr, tpr, thresholds


-def lift_curve(
-    y_true, y_score, *, pos_label=None, sample_weight=None):
+def lift_curve(y_true, y_score, *, pos_label=None, sample_weight=None):
     """Compute lift for each percent coverage of the sample.

     Lift is the ratio of positive treatment responses to treatments on a
@@ -1051,7 +1050,14 @@ def lift_curve(
         Decreasing thresholds on the decision function used to compute
         lift. `thresholds[0]` represents no instances being predicted
         and is arbitrarily set to `max(y_score) + 1`.
-
+
+    See Also
+    --------
+    det_curve : Compute error rates for different probability thresholds.
+ roc_auc_score : Compute the area under the ROC curve. + roc_curve : Compute Receiver operating characteristic (ROC) curve. + precision_recall_curve : Compute precision-recall curve. + References ---------- .. [1] `Wikipedia entry for the lift metric From 425f15c2c2c252f2b867de5138ac44d32b29ab1e Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Tue, 12 Oct 2021 14:59:22 +0200 Subject: [PATCH 07/31] added lift_score tests --- sklearn/metrics/tests/test_classification.py | 54 ++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 992fb99e8c0e0..c7da5bc8df917 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -34,6 +34,7 @@ from sklearn.metrics import hamming_loss from sklearn.metrics import hinge_loss from sklearn.metrics import jaccard_score +from sklearn.metrics import lift_score from sklearn.metrics import log_loss from sklearn.metrics import matthews_corrcoef from sklearn.metrics import precision_recall_fscore_support @@ -418,6 +419,59 @@ def test_precision_recall_f_unused_pos_label(): ) +def test_lift_score(): + y_true = [0, 2, 2, 0, 0, 2, 2] + y_pred = [0, 2, 2, 0, 2, 0, 2] + assert_almost_equal(1.3125, lift_score(y_true, y_pred, pos_label=2)) + + assert 0.0 == lift_score([], []) + assert 1.0 == lift_score([1, 1], [1, 1]) + assert 0.0 == lift_score([1, 1], [0, 0]) + assert 1.0 == lift_score([1, 1], [1, 0]) + assert 1.0 == lift_score([1, 1], [0, 1]) + assert 2.0 == lift_score([1, 0], [1, 0]) + assert 0.0 == lift_score([1, 0], [0, 1]) + assert 0.0 == lift_score([0, 0], [1, 1]) + + assert 1.0 == lift_score([1, 1], [1, 1], zero_division=1) + assert 1.0 == lift_score([1, 1], [0, 0], zero_division=1) + assert 1.0 == lift_score([1, 1], [1, 0], zero_division=1) + assert 1.0 == lift_score([1, 1], [0, 1], zero_division=1) + assert 2.0 == lift_score([1, 0], [1, 0], zero_division=1) + assert 0.0 == lift_score([1, 0], [0, 1], zero_division=1) + assert 1.0 == lift_score([0, 0], [1, 1], zero_division=1) + + assert 0.0 == lift_score([1, 1], [1, 1], pos_label=0) + assert 0.0 == lift_score([1, 1], [0, 0], pos_label=0) + assert 0.0 == lift_score([1, 1], [1, 0], pos_label=0) + assert 0.0 == lift_score([1, 1], [0, 1], pos_label=0) + assert 2.0 == lift_score([1, 0], [1, 0], pos_label=0) + assert 0.0 == lift_score([1, 0], [0, 1], pos_label=0) + assert 0.0 == lift_score([0, 0], [1, 1], pos_label=0) + + +def test_lift_score_sample_weight(): + # With weights + y_true = [0, 1, 0, 1, 1, 0, 1, 1] + y_pred = [0, 1, 1, 1, 0, 0, 1, 1] + weights = [1, 1, 2, 2, 2, 0, 0, 4] + lift1 = lift_score(y_true, y_pred, sample_weight=weights) + + # With repeats + y_true = [0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] + y_pred = [0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1] + lift2 = lift_score(y_true, y_pred) + + assert_array_almost_equal(lift1, lift2) + + +def test_lift_score_warning(): + with pytest.warns(UndefinedMetricWarning): + lift_score( + [1, 1, 1], [0, 0, 0], zero_division="warn" + ) + + def test_confusion_matrix_binary(): # Test confusion matrix - binary classification case y_true, y_pred, _ = make_prediction(binary=True) From 72423a67481f325f779e7034b08b1697a549adbc Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Tue, 12 Oct 2021 14:59:44 +0200 Subject: [PATCH 08/31] added lift_score function --- sklearn/metrics/__init__.py | 1 + sklearn/metrics/_classification.py | 108 +++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+) diff --git 
a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py
index 8c215fae5b6e7..12ca9255798ca 100644
--- a/sklearn/metrics/__init__.py
+++ b/sklearn/metrics/__init__.py
@@ -29,6 +29,7 @@
 from ._classification import hinge_loss
 from ._classification import jaccard_score
 from ._classification import log_loss
+from ._classification import lift_score
 from ._classification import matthews_corrcoef
 from ._classification import precision_recall_fscore_support
 from ._classification import precision_score
diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index b4316053c0f74..c478407f170b3 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -1748,6 +1748,114 @@ def precision_score(
     return p


+def lift_score(
+    y_true,
+    y_pred,
+    *,
+    pos_label=1,
+    sample_weight=None,
+    zero_division="warn",
+):
+    """Compute the lift.
+
+    The lift is the ratio ``(n * tp) / ((tp + fp) * (tp + fn))`` where ``tp``
+    is the number of true positives, ``fp`` the number of false positives and
+    ``fn`` the number of false negatives. The lift is intuitively the relative
+    positive class precision improvement over selecting a random subset and
+    labeling it positive.
+
+    Another way to think of lift is the ratio ``precision / pr`` where ``pr``
+    is the positive rate in the true set.
+
+    The worst value is 0 but lift does not have an upper bound.
+
+    Read more in the :ref:`User Guide `.
+
+    Parameters
+    ----------
+    y_true : 1d array-like, or label indicator array / sparse matrix
+        Ground truth (correct) target values.
+
+    y_pred : 1d array-like, or label indicator array / sparse matrix
+        Estimated targets as returned by a classifier.
+
+    pos_label : str or int, default=1
+        The label of the positive class.
+
+    sample_weight : array-like of shape (n_samples,), default=None
+        Sample weights.
+
+    zero_division : "warn", 0 or 1, default="warn"
+        Sets the value to return when there is a zero division. If set to
+        "warn", this acts as 0, but warnings are also raised.
+
+    Returns
+    -------
+    lift : float
+        Lift of the positive class in binary classification.
+
+    See Also
+    --------
+    lift_curve, precision_recall_curve
+
+    Notes
+    -----
+    When ``true positive + false positive == 0``, lift returns 0 and
+    raises ``UndefinedMetricWarning``. This behavior can be
+    modified with ``zero_division``, which is passed to the precision
+    function.
+    Lift is only possible on binary data.
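+
+    For example, with ``y_true = [0, 1, 1, 0]`` and ``y_pred = [0, 1, 1, 1]``
+    we get ``tp = 2``, ``fp = 1``, ``fn = 0`` and ``n = 4``, so the lift is
+    ``(4 * 2) / ((2 + 1) * (2 + 0)) = 1.33...``: the precision ``2 / 3``
+    divided by the positive rate ``2 / 4``.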
+ + Examples + -------- + >>> from sklearn.metrics import lift_score + >>> y_true = [0, 2, 2, 0, 0, 2, 2] + >>> y_pred = [0, 2, 2, 0, 2, 0, 2] + >>> lift_score(y_true, y_pred) + 1.3125 + """ + # Precision + p, _, _, _ = precision_recall_fscore_support( + y_true, + y_pred, + average="binary", + pos_label=pos_label, + warn_for=("precision",), + sample_weight=sample_weight, + zero_division=zero_division, + ) + + # True labels + y_true = column_or_1d(y_true) + y_true = y_true == pos_label + + # Sample weights + if sample_weight is None: + sample_weight = np.ones(y_true.shape[0], dtype=np.int64) + else: + sample_weight = column_or_1d(sample_weight) + check_consistent_length(y_true, sample_weight) + + # Positive rate and lift + pr = (y_true * sample_weight).sum() / sample_weight.sum() + + # Lift + if pr == 0 or np.isnan(pr): + zero_division_value = np.float64(1.0) + if zero_division in ["warn", 0]: + zero_division_value = np.float64(0.0) + lift = zero_division_value + else: + lift = p / pr + + return lift + + def recall_score( y_true, y_pred, From 5fbc53c3eb25585bce47b242d0f382bb1199b7e1 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Tue, 12 Oct 2021 18:07:05 +0200 Subject: [PATCH 09/31] bug and doc fixes --- sklearn/metrics/__init__.py | 1 + sklearn/metrics/_classification.py | 4 ++-- sklearn/metrics/_ranking.py | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index 12ca9255798ca..d1471d12545a5 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -134,6 +134,7 @@ "label_ranking_average_precision_score", "label_ranking_loss", "lift_curve", + "lift_score", "log_loss", "make_scorer", "nan_euclidean_distances", diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index c478407f170b3..985abdebbd70f 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -1814,8 +1814,8 @@ def lift_score( Examples -------- >>> from sklearn.metrics import lift_score - >>> y_true = [0, 2, 2, 0, 0, 2, 2] - >>> y_pred = [0, 2, 2, 0, 2, 0, 2] + >>> y_true = [0, 1, 1, 0, 0, 1, 1] + >>> y_pred = [0, 1, 1, 0, 1, 0, 1] >>> lift_score(y_true, y_pred) 1.3125 """ diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 5a6f1b841069c..98455c1b8e9f4 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -1037,7 +1037,7 @@ def lift_curve(y_true, y_score, *, pos_label=None, sample_weight=None): Returns ------- - lift : ndarray of shape (>2,) + lift : ndarray of shape (n_thresholds,) Decreasing lift values such that element i is the lift a treatment on a subset with score >= `thresholds[i]`. @@ -1074,7 +1074,7 @@ def lift_curve(y_true, y_score, *, pos_label=None, sample_weight=None): >>> scores = np.array([0.1, 0.4, 0.3, 0.8]) >>> lift, percentages, threshs = metrics.lift_curve(y, scores, pos_label=2) >>> lift - array([2. , 2. , 2. , 1.33333333, 1.]) + array([2. , 2. , 2. 
, 1.33333333 , 1.]) >>> percentages array([0., 25., 50., 75., 100.]) >>> threshs From 3d7813675e29d9251496f08c52952d2ec16a5dd6 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Tue, 12 Oct 2021 18:07:28 +0200 Subject: [PATCH 10/31] added first documentation of lift_score and lift_curve --- doc/modules/classes.rst | 2 ++ doc/modules/model_evaluation.rst | 33 ++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index b7000bcf7cbb2..6592db93cb1dc 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -960,6 +960,8 @@ details. metrics.hamming_loss metrics.hinge_loss metrics.jaccard_score + metrics.lift_curve + metrics.lift_score metrics.log_loss metrics.matthews_corrcoef metrics.multilabel_confusion_matrix diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index 3dd7313bb9b59..371ee2c25e8d7 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -68,6 +68,7 @@ Scoring Function 'f1_macro' :func:`metrics.f1_score` macro-averaged 'f1_weighted' :func:`metrics.f1_score` weighted average 'f1_samples' :func:`metrics.f1_score` by multilabel sample +'lift' :func:`metrics.lift_score` 'neg_log_loss' :func:`metrics.log_loss` requires ``predict_proba`` support 'precision' etc. :func:`metrics.precision_score` suffixes apply as with 'f1' 'recall' etc. :func:`metrics.recall_score` suffixes apply as with 'f1' @@ -308,6 +309,7 @@ Some of these are restricted to the binary classification case: precision_recall_curve roc_curve det_curve + lift_curve Others also work in the multiclass case: @@ -334,6 +336,7 @@ Some also work in the multilabel case: hamming_loss jaccard_score log_loss + lift_score multilabel_confusion_matrix precision_recall_fscore_support precision_score @@ -739,6 +742,36 @@ In the multilabel case with binary label indicators:: or superset of the true labels will give a Hamming loss between zero and one, exclusive. +.. _lift_score: + +Lift +---- + +`Lift `_ can be understood in +different ways. One way is as the ratio of the positive responses of a targeted +treatment of a subset of the dataset relative to the ratio of positive responses +in the dataset as a whole. + +Lift can also be understood as a kind of normalised precision of the positive class. + +.. math:: + + Lift = \frac{n \times tp}{(tp + fp) \times (tp + fn)}, + +.. math:: + + Lift = \frac{Precision}{pr} + +where :math:`tp`, :math:`fp`, :math:`fn`, :math:`n` and :math:`pr` are the true positive count, false positive count, false negative count, dataset size and positive rate respectively. + +Here is an example showing how to calculate lift:: + + >>> from sklearn.metrics import lift_score + >>> y_pred = [1, 1, 1, 1, 1, 2, 2, 2] + >>> y_true = [1, 1, 1, 2, 2, 1, 2, 2] + >>> lift_score(y_true, y_pred) + 1.2 + .. _precision_recall_f_measure_metrics: Precision, recall and F-measures From 942dbd4ece48c92b515ccea930ccae36d217a44c Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Wed, 13 Oct 2021 10:58:58 +0200 Subject: [PATCH 11/31] Added LiftCurveDisplay to plot lift curves. 
Same structure as other Display classes

---
 sklearn/metrics/_plot/lift_curve.py | 473 ++++++++++++++++++++++++++++
 1 file changed, 473 insertions(+)
 create mode 100644 sklearn/metrics/_plot/lift_curve.py

diff --git a/sklearn/metrics/_plot/lift_curve.py b/sklearn/metrics/_plot/lift_curve.py
new file mode 100644
index 0000000000000..2e874ebfdb55a
--- /dev/null
+++ b/sklearn/metrics/_plot/lift_curve.py
@@ -0,0 +1,473 @@
+import scipy as sp
+
+from .base import _get_response
+
+from .. import lift_curve
+from .._base import _check_pos_label_consistency
+
+from ...utils import check_matplotlib_support
+from ...utils import deprecated
+
+
+class LiftCurveDisplay:
+    """Lift curve visualization.
+
+    It is recommended to use :func:`~sklearn.metrics.LiftCurveDisplay.from_estimator`
+    or :func:`~sklearn.metrics.LiftCurveDisplay.from_predictions` to create a
+    visualizer. All parameters are stored as attributes.
+
+    Read more in the :ref:`User Guide `.
+
+    Parameters
+    ----------
+    lift : ndarray
+        Lift values.
+
+    percentages : ndarray
+        Percentage of population treated (classified positive).
+
+    estimator_name : str, default=None
+        Name of estimator. If None, the estimator name is not shown.
+
+    pos_label : str or int, default=None
+        The label of the positive class.
+
+    Attributes
+    ----------
+    line_ : matplotlib Artist
+        Lift Curve.
+
+    ax_ : matplotlib Axes
+        Axes with Lift Curve.
+
+    figure_ : matplotlib Figure
+        Figure containing the curve.
+
+    See Also
+    --------
+    lift_curve : Compute lift scores for different percentage of population
+        treated (classified positive).
+    LiftCurveDisplay.from_estimator : Plot lift curve given an estimator and
+        some data.
+    LiftCurveDisplay.from_predictions : Plot lift curve given the true and
+        predicted values.
+
+    Examples
+    --------
+    >>> import matplotlib.pyplot as plt
+    >>> from sklearn.datasets import make_classification
+    >>> from sklearn.metrics import lift_curve, LiftCurveDisplay
+    >>> from sklearn.model_selection import train_test_split
+    >>> from sklearn.linear_model import LogisticRegression
+    >>> X, y = make_classification(n_samples=1000, random_state=0)
+    >>> X_train, X_test, y_train, y_test = train_test_split(
+    ...     X, y, test_size=0.4, random_state=0)
+    >>> clf = LogisticRegression(random_state=0).fit(X_train, y_train)
+    >>> y_prob = clf.decision_function(X_test)
+    >>> lift, percentages, _ = lift_curve(y_test, y_prob)
+    >>> display = LiftCurveDisplay(
+    ...     lift=lift, percentages=percentages, estimator_name="LogisticRegression"
+    ... )
+    >>> display.plot()
+    <...>
+    >>> plt.show()
+    """
+
+    def __init__(self, *, lift, percentages, estimator_name=None, pos_label=None):
+        self.lift = lift
+        self.percentages = percentages
+        self.estimator_name = estimator_name
+        self.pos_label = pos_label
+
+    @classmethod
+    def from_estimator(
+        cls,
+        estimator,
+        X,
+        y,
+        *,
+        sample_weight=None,
+        response_method="auto",
+        pos_label=None,
+        name=None,
+        ax=None,
+        **kwargs,
+    ):
+        """Plot lift curve given an estimator and data.
+
+        Read more in the :ref:`User Guide `.
+
+        Parameters
+        ----------
+        estimator : estimator instance
+            Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
+            in which the last estimator is a classifier.
+
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            Input values.
+
+        y : array-like of shape (n_samples,)
+            Target values.
+
+        sample_weight : array-like of shape (n_samples,), default=None
+            Sample weights.
+ + response_method : {'predict_proba', 'decision_function', 'auto'} \ + default='auto' + Specifies whether to use :term:`predict_proba` or + :term:`decision_function` as the predicted target response. If set + to 'auto', :term:`predict_proba` is tried first and if it does not + exist :term:`decision_function` is tried next. + + pos_label : str or int, default=None + The label of the positive class. When `pos_label=None`, if `y_true` + is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an + error will be raised. + + name : str, default=None + Name of lift curve for labeling. If `None`, use the name of the + estimator. + + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + **kwargs : dict + Additional keywords arguments passed to matplotlib `plot` function. + + Returns + ------- + display : :class:`~sklearn.metrics.LiftCurveDisplay` + Object that stores computed values. + + See Also + -------- + lift_curve : Compute lift scores for different treatment percentages + (percent of positively classified data points). + LiftCurveDisplay.from_predictions : Plot lift curve given the true and + predicted values. + plot_roc_curve : Plot Receiver operating characteristic (ROC) curve. + + Examples + -------- + >>> import matplotlib.pyplot as plt + >>> from sklearn.datasets import make_classification + >>> from sklearn.metrics import LiftCurveDisplay + >>> from sklearn.model_selection import train_test_split + >>> from sklearn.linear_model import LogisticRegression + >>> X, y = make_classification(n_samples=1000, random_state=0) + >>> X_train, X_test, y_train, y_test = train_test_split( + ... X, y, test_size=0.4, random_state=0) + >>> clf = LogisticRegression(random_state=0).fit(X_train, y_train) + >>> LiftCurveDisplay.from_estimator( + ... clf, X_test, y_test) + <...> + >>> plt.show() + """ + check_matplotlib_support(f"{cls.__name__}.from_estimator") + + name = estimator.__class__.__name__ if name is None else name + + y_pred, pos_label = _get_response( + X, + estimator, + response_method, + pos_label=pos_label, + ) + + return cls.from_predictions( + y_true=y, + y_pred=y_pred, + sample_weight=sample_weight, + name=name, + ax=ax, + pos_label=pos_label, + **kwargs, + ) + + @classmethod + def from_predictions( + cls, + y_true, + y_pred, + *, + sample_weight=None, + pos_label=None, + name=None, + ax=None, + **kwargs, + ): + """Plot lift curve given the true and + predicted values. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + y_true : array-like of shape (n_samples,) + True labels. + + y_pred : array-like of shape (n_samples,) + Target scores, can either be probability estimates of the positive + class, confidence values, or non-thresholded measure of decisions + (as returned by `decision_function` on some classifiers). + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. + + pos_label : str or int, default=None + The label of the positive class. When `pos_label=None`, if `y_true` + is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an + error will be raised. + + name : str, default=None + Name of lift curve for labeling. If `None`, name will be set to + `"Classifier"`. + + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + **kwargs : dict + Additional keywords arguments passed to matplotlib `plot` function. 
+ + Returns + ------- + display : :class:`~sklearn.metrics.LiftCurveDisplay` + Object that stores computed values. + + See Also + -------- + lift_curve : Compute lift scores for different treatment percentages + (percent of positively classified data points). + LiftCurveDisplay.from_estimator : Plot lift curve given an estimator and + some data. + plot_roc_curve : Plot Receiver operating characteristic (ROC) curve. + + Examples + -------- + >>> import matplotlib.pyplot as plt + >>> from sklearn.datasets import make_classification + >>> from sklearn.metrics import LiftCurveDisplay + >>> from sklearn.model_selection import train_test_split + >>> from sklearn.linear_model import LogisticRegression + >>> X, y = make_classification(n_samples=1000, random_state=0) + >>> X_train, X_test, y_train, y_test = train_test_split( + ... X, y, test_size=0.4, random_state=0) + >>> clf = LogisticRegression(random_state=0).fit(X_train, y_train) + >>> y_pred = clf.decision_function(X_test) + >>> LiftCurveDisplay.from_predictions( + ... y_test, y_pred) + <...> + >>> plt.show() + """ + check_matplotlib_support(f"{cls.__name__}.from_predictions") + lift, percentages, _ = lift_curve( + y_true, + y_pred, + pos_label=pos_label, + sample_weight=sample_weight, + ) + + pos_label = _check_pos_label_consistency(pos_label, y_true) + name = "Classifier" if name is None else name + + viz = LiftCurveDisplay( + lift=lift, + percentages=percentages, + estimator_name=name, + pos_label=pos_label, + ) + + return viz.plot(ax=ax, name=name, **kwargs) + + def plot(self, ax=None, *, name=None, **kwargs): + """Plot visualization. + + Parameters + ---------- + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + name : str, default=None + Name of lift curve for labeling. If `None`, use `estimator_name` if + it is not `None`, otherwise no labeling is shown. + + **kwargs : dict + Additional keywords arguments passed to matplotlib `plot` function. + + Returns + ------- + display : :class:`~sklearn.metrics.plot.LiftCurveDisplay` + Object that stores computed values. + """ + check_matplotlib_support("LiftCurveDisplay.plot") + + name = self.estimator_name if name is None else name + line_kwargs = {} if name is None else {"label": name} + line_kwargs.update(**kwargs) + + import matplotlib.pyplot as plt + + if ax is None: + _, ax = plt.subplots() + + (self.line_,) = ax.plot( + sp.stats.norm.ppf(self.percentages), + sp.stats.norm.ppf(self.lift), + **line_kwargs, + ) + info_pos_label = ( + f" (Positive label: {self.pos_label})" if self.pos_label is not None else "" + ) + + xlabel = "Percentage" + info_pos_label + ylabel = "Lift" + info_pos_label + ax.set(xlabel=xlabel, ylabel=ylabel) + + if "label" in line_kwargs: + ax.legend(loc="lower right") + + ticks = [0.001, 0.01, 0.05, 0.20, 0.5, 0.80, 0.95, 0.99, 0.999] + tick_locations = sp.stats.norm.ppf(ticks) + tick_labels = [ + "{:.0%}".format(s) if (100 * s).is_integer() else "{:.1%}".format(s) + for s in ticks + ] + ax.set_xticks(tick_locations) + ax.set_xticklabels(tick_labels) + ax.set_xlim(-3, 3) + ax.set_yticks(tick_locations) + ax.set_yticklabels(tick_labels) + ax.set_ylim(-3, 3) + + self.ax_ = ax + self.figure_ = ax.figure + return self + + +@deprecated( + "Function plot_lift_curve is deprecated in 1.0 and will be " + "removed in 1.2. Use one of the class methods: " + "LiftCurveDisplay.from_predictions or " + "LiftCurveDisplay.from_estimator." 
+) +def plot_lift_curve( + estimator, + X, + y, + *, + sample_weight=None, + response_method="auto", + name=None, + ax=None, + pos_label=None, + **kwargs, +): + """Plot lift curve. + + Extra keyword arguments will be passed to matplotlib's `plot`. + + Read more in the :ref:`User Guide `. + + .. deprecated:: 1.0 + `plot_lift_curve` is deprecated in 1.0 and will be removed in + 1.2. Use one of the following class methods: + :func:`~sklearn.metrics.LiftCurveDisplay.from_predictions` or + :func:`~sklearn.metrics.LiftCurveDisplay.from_estimator`. + + Parameters + ---------- + estimator : estimator instance + Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline` + in which the last estimator is a classifier. + + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Input values. + + y : array-like of shape (n_samples,) + Target values. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. + + response_method : {'predict_proba', 'decision_function', 'auto'} \ + default='auto' + Specifies whether to use :term:`predict_proba` or + :term:`decision_function` as the predicted target response. If set to + 'auto', :term:`predict_proba` is tried first and if it does not exist + :term:`decision_function` is tried next. + + name : str, default=None + Name of lift curve for labeling. If `None`, use the name of the + estimator. + + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is created. + + pos_label : str or int, default=None + The label of the positive class. + When `pos_label=None`, if `y_true` is in {-1, 1} or {0, 1}, + `pos_label` is set to 1, otherwise an error will be raised. + + **kwargs : dict + Additional keywords arguments passed to matplotlib `plot` function. + + Returns + ------- + display : :class:`~sklearn.metrics.LiftCurveDisplay` + Object that stores computed values. + + See Also + -------- + lift_curve : Compute lift scores for different treatment percentages + (percent of positively classified data points). + LiftCurveDisplay : lift curve visualization. + LiftCurveDisplay.from_estimator : Plot lift curve given an estimator and + some data. + LiftCurveDisplay.from_predictions : Plot lift curve given the true and + predicted labels. + RocCurveDisplay.from_estimator : Plot Receiver Operating Characteristic + (ROC) curve given an estimator and some data. + RocCurveDisplay.from_predictions : Plot Receiver Operating Characteristic + (ROC) curve given the true and predicted values. + + Examples + -------- + >>> import matplotlib.pyplot as plt + >>> from sklearn.datasets import make_classification + >>> from sklearn.metrics import plot_lift_curve + >>> from sklearn.model_selection import train_test_split + >>> from sklearn.linear_regression import LogisticRegression + >>> X, y = make_classification(n_samples=1000, random_state=0) + >>> X_train, X_test, y_train, y_test = train_test_split( + ... 
X, y, test_size=0.4, random_state=0) + >>> clf = LogisticRegression(random_state=0).fit(X_train, y_train) + >>> plot_lift_curve(clf, X_test, y_test) # doctest: +SKIP + <...> + >>> plt.show() + """ + check_matplotlib_support("plot_lift_curve") + + y_pred, pos_label = _get_response( + X, estimator, response_method, pos_label=pos_label + ) + + lift, percentages, _ = lift_curve( + y, + y_pred, + pos_label=pos_label, + sample_weight=sample_weight, + ) + + name = estimator.__class__.__name__ if name is None else name + + viz = LiftCurveDisplay( + lift=lift, + percentages=percentages, + estimator_name=name, + pos_label=pos_label, + ) + + return viz.plot(ax=ax, name=name, **kwargs) From 209ccf51ec5cd782ab7673a1a7048e59ae580c3c Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Wed, 13 Oct 2021 14:04:16 +0200 Subject: [PATCH 12/31] added lift display function for plotting with the tests --- sklearn/metrics/__init__.py | 4 + sklearn/metrics/_plot/lift_curve.py | 23 ++-- .../_plot/tests/test_lift_curve_display.py | 108 ++++++++++++++++++ .../_plot/tests/test_plot_lift_curve.py | 84 ++++++++++++++ 4 files changed, 203 insertions(+), 16 deletions(-) create mode 100644 sklearn/metrics/_plot/tests/test_lift_curve_display.py create mode 100644 sklearn/metrics/_plot/tests/test_plot_lift_curve.py diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index d1471d12545a5..03bfd97a6572f 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -88,6 +88,8 @@ from ._plot.det_curve import plot_det_curve from ._plot.det_curve import DetCurveDisplay +from ._plot.lift_curve import plot_lift_curve +from ._plot.lift_curve import LiftCurveDisplay from ._plot.roc_curve import plot_roc_curve from ._plot.roc_curve import RocCurveDisplay from ._plot.precision_recall_curve import plot_precision_recall_curve @@ -133,6 +135,7 @@ "jaccard_score", "label_ranking_average_precision_score", "label_ranking_loss", + "liftCurveDisplay", "lift_curve", "lift_score", "log_loss", @@ -161,6 +164,7 @@ "pairwise_kernels", "plot_confusion_matrix", "plot_det_curve", + "plot_lift_curve", "plot_precision_recall_curve", "plot_roc_curve", "PrecisionRecallDisplay", diff --git a/sklearn/metrics/_plot/lift_curve.py b/sklearn/metrics/_plot/lift_curve.py index 2e874ebfdb55a..7fa3b3c06b19a 100644 --- a/sklearn/metrics/_plot/lift_curve.py +++ b/sklearn/metrics/_plot/lift_curve.py @@ -1,5 +1,3 @@ -import scipy as sp - from .base import _get_response from .. 
import lift_curve @@ -314,33 +312,26 @@ def plot(self, ax=None, *, name=None, **kwargs): _, ax = plt.subplots() (self.line_,) = ax.plot( - sp.stats.norm.ppf(self.percentages), - sp.stats.norm.ppf(self.lift), + self.percentages, + self.lift, **line_kwargs, ) info_pos_label = ( f" (Positive label: {self.pos_label})" if self.pos_label is not None else "" ) - xlabel = "Percentage" + info_pos_label + xlabel = "Positive Rate" + info_pos_label ylabel = "Lift" + info_pos_label ax.set(xlabel=xlabel, ylabel=ylabel) if "label" in line_kwargs: ax.legend(loc="lower right") - ticks = [0.001, 0.01, 0.05, 0.20, 0.5, 0.80, 0.95, 0.99, 0.999] - tick_locations = sp.stats.norm.ppf(ticks) - tick_labels = [ - "{:.0%}".format(s) if (100 * s).is_integer() else "{:.1%}".format(s) - for s in ticks - ] + tick_locations = [0, 20, 40, 60, 80, 100] + tick_labels = ["{}%".format(s) for s in tick_locations] ax.set_xticks(tick_locations) ax.set_xticklabels(tick_labels) - ax.set_xlim(-3, 3) - ax.set_yticks(tick_locations) - ax.set_yticklabels(tick_labels) - ax.set_ylim(-3, 3) + ax.set_xlim(0, 100) self.ax_ = ax self.figure_ = ax.figure @@ -439,7 +430,7 @@ def plot_lift_curve( >>> from sklearn.datasets import make_classification >>> from sklearn.metrics import plot_lift_curve >>> from sklearn.model_selection import train_test_split - >>> from sklearn.linear_regression import LogisticRegression + >>> from sklearn.linear_model import LogisticRegression >>> X, y = make_classification(n_samples=1000, random_state=0) >>> X_train, X_test, y_train, y_test = train_test_split( ... X, y, test_size=0.4, random_state=0) diff --git a/sklearn/metrics/_plot/tests/test_lift_curve_display.py b/sklearn/metrics/_plot/tests/test_lift_curve_display.py new file mode 100644 index 0000000000000..3a20c63231f4b --- /dev/null +++ b/sklearn/metrics/_plot/tests/test_lift_curve_display.py @@ -0,0 +1,108 @@ +import pytest +import numpy as np +from numpy.testing import assert_allclose + +from sklearn.datasets import load_iris +from sklearn.linear_model import LogisticRegression + +from sklearn.metrics import lift_curve +from sklearn.metrics import LiftCurveDisplay + + +@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) +@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) +@pytest.mark.parametrize("with_sample_weight", [True, False]) +@pytest.mark.parametrize("with_strings", [True, False]) +def test_lift_curve_display( + pyplot, constructor_name, response_method, with_sample_weight, with_strings +): + X, y = load_iris(return_X_y=True) + # Binarize the data with only the two first classes + X, y = X[y < 2], y[y < 2] + + pos_label = None + if with_strings: + y = np.array(["c", "b"])[y] + pos_label = "c" + + if with_sample_weight: + rng = np.random.RandomState(42) + sample_weight = rng.randint(1, 4, size=(X.shape[0])) + else: + sample_weight = None + + lr = LogisticRegression() + lr.fit(X, y) + y_pred = getattr(lr, response_method)(X) + if y_pred.ndim == 2: + y_pred = y_pred[:, 1] + + # safe guard for the binary if/else construction + assert constructor_name in ("from_estimator", "from_predictions") + + common_kwargs = { + "name": lr.__class__.__name__, + "alpha": 0.8, + "sample_weight": sample_weight, + "pos_label": pos_label, + } + if constructor_name == "from_estimator": + disp = LiftCurveDisplay.from_estimator(lr, X, y, **common_kwargs) + else: + disp = LiftCurveDisplay.from_predictions(y, y_pred, **common_kwargs) + + lift, percentages, _ = lift_curve( + y, + y_pred, + 
sample_weight=sample_weight, + pos_label=pos_label, + ) + + assert_allclose(disp.lift, lift) + assert_allclose(disp.percentages, percentages) + + assert disp.estimator_name == "LogisticRegression" + + # cannot fail thanks to pyplot fixture + import matplotlib as mpl # noqal + + assert isinstance(disp.line_, mpl.lines.Line2D) + assert disp.line_.get_alpha() == 0.8 + assert isinstance(disp.ax_, mpl.axes.Axes) + assert isinstance(disp.figure_, mpl.figure.Figure) + assert disp.line_.get_label() == "LogisticRegression" + + expected_pos_label = 1 if pos_label is None else pos_label + expected_ylabel = f"Lift (Positive label: {expected_pos_label})" + expected_xlabel = f"Positive Rate (Positive label: {expected_pos_label})" + assert disp.ax_.get_ylabel() == expected_ylabel + assert disp.ax_.get_xlabel() == expected_xlabel + + +@pytest.mark.parametrize( + "constructor_name, expected_clf_name", + [ + ("from_estimator", "LogisticRegression"), + ("from_predictions", "Classifier"), + ], +) +def test_lift_curve_display_default_name( + pyplot, + constructor_name, + expected_clf_name, +): + # Check the default name display in the figure when `name` is not provided + X, y = load_iris(return_X_y=True) + # Binarize the data with only the two first classes + X, y = X[y < 2], y[y < 2] + + lr = LogisticRegression().fit(X, y) + y_pred = lr.predict_proba(X)[:, 1] + + if constructor_name == "from_estimator": + disp = LiftCurveDisplay.from_estimator(lr, X, y) + else: + disp = LiftCurveDisplay.from_predictions(y, y_pred) + + assert disp.estimator_name == expected_clf_name + assert disp.line_.get_label() == expected_clf_name diff --git a/sklearn/metrics/_plot/tests/test_plot_lift_curve.py b/sklearn/metrics/_plot/tests/test_plot_lift_curve.py new file mode 100644 index 0000000000000..f29b1f956ad43 --- /dev/null +++ b/sklearn/metrics/_plot/tests/test_plot_lift_curve.py @@ -0,0 +1,84 @@ +# TODO: remove this file when plot_lift_curve will be deprecated in 1.2 +import pytest +import numpy as np +from numpy.testing import assert_allclose + +from sklearn.datasets import load_iris +from sklearn.linear_model import LogisticRegression + +from sklearn.metrics import lift_curve +from sklearn.metrics import plot_lift_curve + + +@pytest.fixture(scope="module") +def data(): + return load_iris(return_X_y=True) + + +@pytest.fixture(scope="module") +def data_binary(data): + X, y = data + return X[y < 2], y[y < 2] + + +@pytest.mark.filterwarnings("ignore: Function plot_lift_curve is deprecated") +@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) +@pytest.mark.parametrize("with_sample_weight", [True, False]) +@pytest.mark.parametrize("with_strings", [True, False]) +def test_plot_lift_curve( + pyplot, response_method, data_binary, with_sample_weight, with_strings +): + X, y = data_binary + + pos_label = None + if with_strings: + y = np.array(["c", "b"])[y] + pos_label = "c" + + if with_sample_weight: + rng = np.random.RandomState(42) + sample_weight = rng.randint(1, 4, size=(X.shape[0])) + else: + sample_weight = None + + lr = LogisticRegression() + lr.fit(X, y) + + viz = plot_lift_curve( + lr, + X, + y, + alpha=0.8, + sample_weight=sample_weight, + ) + + y_pred = getattr(lr, response_method)(X) + if y_pred.ndim == 2: + y_pred = y_pred[:, 1] + + lift, percentages, _ = lift_curve( + y, + y_pred, + sample_weight=sample_weight, + pos_label=pos_label, + ) + + assert_allclose(viz.lift, lift) + assert_allclose(viz.percentages, percentages) + + assert viz.estimator_name == "LogisticRegression" + + # cannot fail 
thanks to pyplot fixture + import matplotlib as mpl # noqal + + assert isinstance(viz.line_, mpl.lines.Line2D) + assert viz.line_.get_alpha() == 0.8 + assert isinstance(viz.ax_, mpl.axes.Axes) + assert isinstance(viz.figure_, mpl.figure.Figure) + assert viz.line_.get_label() == "LogisticRegression" + + expected_pos_label = 1 if pos_label is None else pos_label + expected_ylabel = f"Lift (Positive label: {expected_pos_label})" + expected_xlabel = f"Positive Rate (Positive label: {expected_pos_label})" + assert viz.ax_.get_ylabel() == expected_ylabel + assert viz.ax_.get_xlabel() == expected_xlabel From 26c3ff3cafad7c62cececbc991e4cc4bb29c4ebd Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Wed, 13 Oct 2021 14:05:22 +0200 Subject: [PATCH 13/31] added lift display functions to the documentation --- doc/modules/classes.rst | 2 ++ doc/visualizations.rst | 2 ++ 2 files changed, 4 insertions(+) diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 6592db93cb1dc..e7882bfb5e726 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -1125,6 +1125,7 @@ See the :ref:`visualizations` section of the user guide for further details. metrics.plot_confusion_matrix metrics.plot_det_curve + metrics.plot_lift_curve metrics.plot_precision_recall_curve metrics.plot_roc_curve @@ -1134,6 +1135,7 @@ See the :ref:`visualizations` section of the user guide for further details. metrics.ConfusionMatrixDisplay metrics.DetCurveDisplay + metrics.LiftCurveDisplay metrics.PrecisionRecallDisplay metrics.RocCurveDisplay calibration.CalibrationDisplay diff --git a/doc/visualizations.rst b/doc/visualizations.rst index dd2e1379c14d8..e738f7193c75b 100644 --- a/doc/visualizations.rst +++ b/doc/visualizations.rst @@ -87,6 +87,7 @@ Functions inspection.plot_partial_dependence metrics.plot_confusion_matrix metrics.plot_det_curve + metrics.plot_lift_curve metrics.plot_precision_recall_curve metrics.plot_roc_curve @@ -102,5 +103,6 @@ Display Objects inspection.PartialDependenceDisplay metrics.ConfusionMatrixDisplay metrics.DetCurveDisplay + metrics.LiftCurveDisplay metrics.PrecisionRecallDisplay metrics.RocCurveDisplay From bb7febfc21e4a8ce0c796e50dc5f5b5e62675cf4 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Wed, 13 Oct 2021 15:56:16 +0200 Subject: [PATCH 14/31] extended documentation and added reference to the lift curve --- doc/modules/model_evaluation.rst | 38 ++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index 371ee2c25e8d7..f34e64f6e765f 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -747,12 +747,12 @@ In the multilabel case with binary label indicators:: Lift ---- -`Lift `_ can be understood in -different ways. One way is as the ratio of the positive responses of a targeted -treatment of a subset of the dataset relative to the ratio of positive responses -in the dataset as a whole. +Lift [WikipediaLift2021]_ can be understood in different ways. One way is as +the ratio of the positive responses of a targeted treatment of a subset of the +dataset relative to the ratio of positive responses in the dataset as a whole. -Lift can also be understood as a kind of normalised precision of the positive class. +Lift can also be understood as a kind of normalised precision of the positive +class. .. 
math::

   Lift = \frac{n \times tp}{(tp + fp) \times (tp + fn)},

.. math::

   Lift = \frac{Precision}{pr}

where :math:`tp`, :math:`fp`, :math:`fn`, :math:`n` and :math:`pr` are the
true positive count, false positive count, false negative count, dataset size
and positive classification rate respectively.

:func:`lift_score` in scikit-learn is an implementation of lift.

Here is an example showing how to calculate lift::

    >>> from sklearn.metrics import lift_score
    >>> y_pred = [1, 1, 1, 1, 1, 2, 2, 2]
    >>> y_true = [1, 1, 1, 2, 2, 1, 2, 2]
    >>> lift_score(y_true, y_pred)
    1.2

Related to :func:`lift_score` is the :func:`lift_curve`. The lift curve
shows the lift on the y-axis against the percentage of the population
classified as the positive class on the x-axis.

Intuitively, the lift curve shows the precision/effectiveness of a
treatment on a subset of the population as the size of the subset grows,
relative to the effectiveness of a random treatment. It is closely related
to the :func:`precision_recall_curve`.

:class:`LiftCurveDisplay` can be used to visually represent a lift curve. See
:class:`LiftCurveDisplay` and :func:`lift_curve` for examples and instructions.

.. topic:: References:

  .. [WikipediaLift2021] Wikipedia contributors. Lift (data mining). Wikipedia,
     October 13, 2021, 21:00 UTC. Available at:
     https://en.wikipedia.org/wiki/Lift_(data_mining).
     Accessed October 13, 2021.

.. 
_precision_recall_f_measure_metrics: Precision, recall and F-measures From aeecef13191cd34649a14ea11fbc94221b0b01f6 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Wed, 13 Oct 2021 16:28:45 +0200 Subject: [PATCH 15/31] ran black for code formatting and committing changes --- sklearn/metrics/_ranking.py | 6 ++---- sklearn/metrics/tests/test_classification.py | 4 +--- sklearn/metrics/tests/test_ranking.py | 5 +++-- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 98455c1b8e9f4..95167e4f4afd3 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -1082,16 +1082,14 @@ def lift_curve(y_true, y_score, *, pos_label=None, sample_weight=None): """ fps, tps, thresholds = _binary_clf_curve( - y_true, y_score, - pos_label=pos_label, - sample_weight=sample_weight + y_true, y_score, pos_label=pos_label, sample_weight=sample_weight ) # False negatives fns = tps[-1] - tps # Sample counts n_samples = fps[-1] + tps[-1] - + # Lift & percentages lift = n_samples * tps / ((fps + tps) * (fns + tps)) percentages = 100 * (fps + tps) / n_samples diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index c7da5bc8df917..a8dda8a35cc3a 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -467,9 +467,7 @@ def test_lift_score_sample_weight(): def test_lift_score_warning(): with pytest.warns(UndefinedMetricWarning): - lift_score( - [1, 1, 1], [0, 0, 0], zero_division="warn" - ) + lift_score([1, 1, 1], [0, 0, 0], zero_division="warn") def test_confusion_matrix_binary(): diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py index 534224f3daa49..fb3c6ab08f363 100644 --- a/sklearn/metrics/tests/test_ranking.py +++ b/sklearn/metrics/tests/test_ranking.py @@ -531,8 +531,9 @@ def test_lift_curve_sample_weight(): y_true = [0, 1, 0, 1, 1, 0, 1, 1] y_score = [0, 1, 0.5, 0.6, 0.4, 0.1, 0.7, 0.4] weights = [1, 1, 2, 2, 2, 0, 0, 4] - lift1, percentages1, thresholds1 = lift_curve(y_true, y_score, - sample_weight=weights) + lift1, percentages1, thresholds1 = lift_curve( + y_true, y_score, sample_weight=weights + ) # With repeats y_true = [0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] From 88f11c7197c0215311ce503c4acadc373073a271 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Wed, 13 Oct 2021 16:37:36 +0200 Subject: [PATCH 16/31] Fixed bug after running flake8 --- sklearn/metrics/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index 03bfd97a6572f..0febd75e32747 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -135,7 +135,7 @@ "jaccard_score", "label_ranking_average_precision_score", "label_ranking_loss", - "liftCurveDisplay", + "LiftCurveDisplay", "lift_curve", "lift_score", "log_loss", From d3db8e7e3e043d6c2f80a999eab2ed0a3aa5d154 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Wed, 13 Oct 2021 17:00:53 +0200 Subject: [PATCH 17/31] correct example in lift_curve function docs --- sklearn/metrics/_ranking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 95167e4f4afd3..3584feed6696a 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -1074,7 +1074,7 @@ def lift_curve(y_true, y_score, *, pos_label=None, sample_weight=None): >>> scores = np.array([0.1, 0.4, 0.3, 0.8]) >>> lift, 
percentages, threshs = metrics.lift_curve(y, scores, pos_label=2) >>> lift - array([2. , 2. , 2. , 1.33333333 , 1.]) + array([2., 2., 2., 1.33333333, 1.]) >>> percentages array([0., 25., 50., 75., 100.]) >>> threshs From b018fbd6086e924d3830d2813f1e8662d29b0417 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Wed, 13 Oct 2021 17:28:12 +0200 Subject: [PATCH 18/31] added changelog --- doc/whats_new/v1.1.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index a473908d8f1e7..4754069f96100 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -77,6 +77,14 @@ Changelog backward compatibility, but this alias will be removed in 1.3. :pr:`21177` by :user:`Julien Jerphanion `. +- |Feature| :func:`lift_score` to calculate lift score. + +- |Feature| :func:`lift_curve` to calculate lift curve, basically lift values for + different positive classification rates (percentage of data points classified + positive) + +- |Feature| :class:`LiftCurveDisplay` to plot the :func:`lift_curve`. + :mod:`sklearn.model_selection` .............................. From ceb51b3c9f09ce273c1807f2a9584bc18da5d7b1 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Wed, 13 Oct 2021 17:34:25 +0200 Subject: [PATCH 19/31] added changelog reference to PR --- doc/whats_new/v1.1.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 4754069f96100..bc5cc81d57043 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -77,13 +77,14 @@ Changelog backward compatibility, but this alias will be removed in 1.3. :pr:`21177` by :user:`Julien Jerphanion `. -- |Feature| :func:`lift_score` to calculate lift score. +- |Feature| :func:`lift_score` to calculate lift score. In :pr:`21320` by `Nawar Halabi`_. - |Feature| :func:`lift_curve` to calculate lift curve, basically lift values for different positive classification rates (percentage of data points classified - positive) + positive). In :pr:`21320` by `Nawar Halabi`_. - |Feature| :class:`LiftCurveDisplay` to plot the :func:`lift_curve`. + In :pr:`21320` by `Nawar Halabi`_. :mod:`sklearn.model_selection` .............................. From 68121b54eaa527c24439e219209673efd1ea86cf Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Wed, 13 Oct 2021 17:45:01 +0200 Subject: [PATCH 20/31] corrected bug in whats_new --- doc/whats_new/v1.1.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index bc5cc81d57043..ebdd07950bc7b 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -77,13 +77,13 @@ Changelog backward compatibility, but this alias will be removed in 1.3. :pr:`21177` by :user:`Julien Jerphanion `. -- |Feature| :func:`lift_score` to calculate lift score. In :pr:`21320` by `Nawar Halabi`_. +- |Feature| :func:`metrics.lift_score` to calculate lift score. In :pr:`21320` by `Nawar Halabi`_. -- |Feature| :func:`lift_curve` to calculate lift curve, basically lift values for +- |Feature| :func:`metrics.lift_curve` to calculate lift curve, basically lift values for different positive classification rates (percentage of data points classified positive). In :pr:`21320` by `Nawar Halabi`_. -- |Feature| :class:`LiftCurveDisplay` to plot the :func:`lift_curve`. +- |Feature| :class:`metrics.LiftCurveDisplay` to plot the :func:`metrics.lift_curve`. In :pr:`21320` by `Nawar Halabi`_. 
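As a reading aid for the entry above: the lift values it describes can be sketched from scratch by accumulating true and false positives in order of decreasing score, mirroring the cumulative form ``lift = n_samples * tps / ((fps + tps) * (fns + tps))`` used in ``_ranking.py`` earlier in this series. A minimal sketch with illustrative data, ignoring the tie handling and the extra initial point of the full implementation::

    import numpy as np

    y_true = np.array([1, 1, 0, 1, 0, 0, 0, 1])
    y_score = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2])

    order = np.argsort(-y_score)   # rank samples by decreasing score
    y_sorted = y_true[order]

    tps = np.cumsum(y_sorted)      # true positives captured at each cutoff
    fps = np.cumsum(1 - y_sorted)  # false positives at each cutoff
    fns = tps[-1] - tps            # positives not yet captured
    n = y_true.size

    lift = n * tps / ((fps + tps) * (fns + tps))
    percentages = 100 * (fps + tps) / n
    # lift starts at 2.0 for the top-ranked sample and decays to 1.0 once
    # the entire population is classified as positive.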
:mod:`sklearn.model_selection` From 65684687cbb608dff530f9534817e95014575cc0 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Wed, 13 Oct 2021 17:51:53 +0200 Subject: [PATCH 21/31] fixed bug in whats new --- doc/whats_new/v1.1.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index ebdd07950bc7b..80295b7268ef1 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -77,14 +77,14 @@ Changelog backward compatibility, but this alias will be removed in 1.3. :pr:`21177` by :user:`Julien Jerphanion `. -- |Feature| :func:`metrics.lift_score` to calculate lift score. In :pr:`21320` by `Nawar Halabi`_. +- |Feature| :func:`metrics.lift_score` to calculate lift score. In :pr:`21320` by :user:`Nawar Halabi `. - |Feature| :func:`metrics.lift_curve` to calculate lift curve, basically lift values for different positive classification rates (percentage of data points classified - positive). In :pr:`21320` by `Nawar Halabi`_. + positive). In :pr:`21320` by :user:`Nawar Halabi `. - |Feature| :class:`metrics.LiftCurveDisplay` to plot the :func:`metrics.lift_curve`. - In :pr:`21320` by `Nawar Halabi`_. + In :pr:`21320` by :user:`Nawar Halabi `. :mod:`sklearn.model_selection` .............................. From f31bbf0d416272e08429da58e21696f5cca523a4 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Wed, 13 Oct 2021 18:19:05 +0200 Subject: [PATCH 22/31] fixed example of lift_curve not matching output value in doctest --- sklearn/metrics/_ranking.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 3584feed6696a..4119220c99b6e 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -1074,11 +1074,11 @@ def lift_curve(y_true, y_score, *, pos_label=None, sample_weight=None): >>> scores = np.array([0.1, 0.4, 0.3, 0.8]) >>> lift, percentages, threshs = metrics.lift_curve(y, scores, pos_label=2) >>> lift - array([2., 2., 2., 1.33333333, 1.]) + array([2. , 2. , 2. , 1.33333333, 1. ]) >>> percentages - array([0., 25., 50., 75., 100.]) + array([ 0., 25., 50., 75., 100.]) >>> threshs - array([1.8, 0.8, 0.4, 0.3, 0.1] + array([1.8, 0.8, 0.4, 0.3, 0.1]) From eb1235715a55ce2bf43284f9e8e8677928a0755d Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Sun, 17 Oct 2021 20:48:20 +0200 Subject: [PATCH 23/31] fixed documentation according to new standards --- sklearn/metrics/_classification.py | 7 +++++-- sklearn/metrics/_plot/lift_curve.py | 10 +++++----- sklearn/metrics/_plot/precision_recall_curve.py | 1 + 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 985abdebbd70f..2715f8ee041da 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -1796,12 +1796,15 @@ def lift_score( ------- lift : float (if average is not None) or array of float of shape \ (n_unique_labels,) - lift of the positive class in binary classification or weighted + Lift of the positive class in binary classification or weighted average of the lift of each class for the multiclass task. See Also -------- - lift_curve, precision_recall_curve + lift_curve : Calculate the lift for different positive rates. + + precision_recall_curve : Calculate precision and recall for different + classification thresholds. 
Notes ----- diff --git a/sklearn/metrics/_plot/lift_curve.py b/sklearn/metrics/_plot/lift_curve.py index 7fa3b3c06b19a..3ffd0775f95cf 100644 --- a/sklearn/metrics/_plot/lift_curve.py +++ b/sklearn/metrics/_plot/lift_curve.py @@ -358,16 +358,16 @@ def plot_lift_curve( ): """Plot lift curve. - Extra keyword arguments will be passed to matplotlib's `plot`. - - Read more in the :ref:`User Guide `. - .. deprecated:: 1.0 `plot_lift_curve` is deprecated in 1.0 and will be removed in 1.2. Use one of the following class methods: :func:`~sklearn.metrics.LiftCurveDisplay.from_predictions` or :func:`~sklearn.metrics.LiftCurveDisplay.from_estimator`. + Extra keyword arguments will be passed to matplotlib's `plot`. + + Read more in the :ref:`User Guide `. + Parameters ---------- estimator : estimator instance @@ -414,7 +414,7 @@ def plot_lift_curve( -------- lift_curve : Compute lift scores for different treatment percentages (percent of positively classified data points). - LiftCurveDisplay : lift curve visualization. + LiftCurveDisplay : Lift curve visualization. LiftCurveDisplay.from_estimator : Plot lift curve given an estimator and some data. LiftCurveDisplay.from_predictions : Plot lift curve given the true and diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index eaf8240062174..db87e44b73893 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -337,6 +337,7 @@ def from_predictions( name = name if name is not None else "Classifier" + viz = PrecisionRecallDisplay( precision=precision, recall=recall, From 47607960774e0faed66b66c86913d9a0876cc7ba Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Sun, 17 Oct 2021 21:22:25 +0200 Subject: [PATCH 24/31] fixed doc bug added in last commit by accident in precision_recall_curve --- sklearn/metrics/_plot/precision_recall_curve.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index db87e44b73893..eaf8240062174 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -337,7 +337,6 @@ def from_predictions( name = name if name is not None else "Classifier" - viz = PrecisionRecallDisplay( precision=precision, recall=recall, From 597cc76dc19e12c9dce4d33a55cd21c0d8d05779 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Sun, 17 Oct 2021 22:53:42 +0200 Subject: [PATCH 25/31] adjusted doc of lift_curve function to match standards --- sklearn/metrics/_plot/lift_curve.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_plot/lift_curve.py b/sklearn/metrics/_plot/lift_curve.py index 3ffd0775f95cf..ab1408e5df3ee 100644 --- a/sklearn/metrics/_plot/lift_curve.py +++ b/sklearn/metrics/_plot/lift_curve.py @@ -358,16 +358,16 @@ def plot_lift_curve( ): """Plot lift curve. + Extra keyword arguments will be passed to matplotlib's `plot`. + + Read more in the :ref:`User Guide `. + .. deprecated:: 1.0 `plot_lift_curve` is deprecated in 1.0 and will be removed in 1.2. Use one of the following class methods: :func:`~sklearn.metrics.LiftCurveDisplay.from_predictions` or - :func:`~sklearn.metrics.LiftCurveDisplay.from_estimator`. - Extra keyword arguments will be passed to matplotlib's `plot`. - - Read more in the :ref:`User Guide `. 
- Parameters ---------- estimator : estimator instance From 783346bc767959ecd9f88db4b1dfe0ff48799bf1 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Sun, 17 Oct 2021 23:54:22 +0200 Subject: [PATCH 26/31] Expanded the description in the lift_curve plotting function's docstring --- sklearn/metrics/_plot/lift_curve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/lift_curve.py b/sklearn/metrics/_plot/lift_curve.py index ab1408e5df3ee..a0e8affce4be9 100644 --- a/sklearn/metrics/_plot/lift_curve.py +++ b/sklearn/metrics/_plot/lift_curve.py @@ -356,7 +356,7 @@ def plot_lift_curve( pos_label=None, **kwargs, ): - """Plot lift curve. + """Plot the lift curve for binary classifiers. Extra keyword arguments will be passed to matplotlib's `plot`. From 6ce74039b134adea7f8a5ffae3d856e688eef580 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Mon, 18 Oct 2021 00:14:51 +0200 Subject: [PATCH 27/31] added the newly added functions to the docstring-test ignore list. Fails for no clear reason at the moment. To be changed in the future --- maint_tools/test_docstrings.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/maint_tools/test_docstrings.py b/maint_tools/test_docstrings.py index f52bbb70e5e99..00d4ac6cd9cb5 100644 --- a/maint_tools/test_docstrings.py +++ b/maint_tools/test_docstrings.py @@ -127,6 +127,7 @@ "sklearn.metrics._classification.hamming_loss", "sklearn.metrics._classification.hinge_loss", "sklearn.metrics._classification.jaccard_score", + "sklearn.metrics._classification.lift_score", "sklearn.metrics._classification.log_loss", "sklearn.metrics._classification.precision_recall_fscore_support", "sklearn.metrics._classification.precision_score", @@ -134,6 +135,7 @@ "sklearn.metrics._classification.zero_one_loss", "sklearn.metrics._plot.confusion_matrix.plot_confusion_matrix", "sklearn.metrics._plot.det_curve.plot_det_curve", + "sklearn.metrics._plot.lift_curve.plot_lift_curve", "sklearn.metrics._plot.precision_recall_curve.plot_precision_recall_curve", "sklearn.metrics._plot.roc_curve.plot_roc_curve", "sklearn.metrics._ranking.auc", @@ -142,6 +144,7 @@ "sklearn.metrics._ranking.dcg_score", "sklearn.metrics._ranking.label_ranking_average_precision_score", "sklearn.metrics._ranking.label_ranking_loss", + "sklearn.metrics._ranking.lift_curve", "sklearn.metrics._ranking.ndcg_score", "sklearn.metrics._ranking.precision_recall_curve", "sklearn.metrics._ranking.roc_auc_score", From 03d5cb32627031b5a292e8be072cda735d552eb6 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Mon, 18 Oct 2021 00:25:18 +0200 Subject: [PATCH 28/31] fixed plot_lift_curve function to comply with docstrings --- sklearn/metrics/_plot/lift_curve.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sklearn/metrics/_plot/lift_curve.py b/sklearn/metrics/_plot/lift_curve.py index a0e8affce4be9..9dea3ec6f989d 100644 --- a/sklearn/metrics/_plot/lift_curve.py +++ b/sklearn/metrics/_plot/lift_curve.py @@ -362,12 +362,6 @@ def plot_lift_curve( Read more in the :ref:`User Guide `. - .. deprecated:: 1.0 `plot_lift_curve` is deprecated in 1.0 and will be removed in 1.2. Use one of the following class methods: :func:`~sklearn.metrics.LiftCurveDisplay.from_predictions` or :func:`~sklearn.metrics.LiftCurveDisplay.from_estimator`. 
- Parameters ---------- estimator : estimator instance From 4f296ffdd6f24cbad238a136359150400ed2a645 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Mon, 18 Oct 2021 00:27:25 +0200 Subject: [PATCH 29/31] fixed docstrings of lift_curve and lift_score --- maint_tools/test_docstrings.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/maint_tools/test_docstrings.py b/maint_tools/test_docstrings.py index 00d4ac6cd9cb5..f52bbb70e5e99 100644 --- a/maint_tools/test_docstrings.py +++ b/maint_tools/test_docstrings.py @@ -127,7 +127,6 @@ "sklearn.metrics._classification.hamming_loss", "sklearn.metrics._classification.hinge_loss", "sklearn.metrics._classification.jaccard_score", - "sklearn.metrics._classification.lift_score", "sklearn.metrics._classification.log_loss", "sklearn.metrics._classification.precision_recall_fscore_support", "sklearn.metrics._classification.precision_score", @@ -134,7 +133,6 @@ "sklearn.metrics._classification.zero_one_loss", "sklearn.metrics._plot.confusion_matrix.plot_confusion_matrix", "sklearn.metrics._plot.det_curve.plot_det_curve", - "sklearn.metrics._plot.lift_curve.plot_lift_curve", "sklearn.metrics._plot.precision_recall_curve.plot_precision_recall_curve", "sklearn.metrics._plot.roc_curve.plot_roc_curve", "sklearn.metrics._ranking.auc", @@ -142,7 +140,6 @@ "sklearn.metrics._ranking.dcg_score", "sklearn.metrics._ranking.label_ranking_average_precision_score", "sklearn.metrics._ranking.label_ranking_loss", - "sklearn.metrics._ranking.lift_curve", "sklearn.metrics._ranking.ndcg_score", "sklearn.metrics._ranking.precision_recall_curve", "sklearn.metrics._ranking.roc_auc_score", From cac14e65a18f16b529de95fee074f71308bacfa8 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Thu, 11 Nov 2021 12:30:36 +0100 Subject: [PATCH 30/31] fixed documentation missing new line in v1.1 --- doc/whats_new/v1.1.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 927898582fc19..d4b877186fe4a 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -147,6 +147,7 @@ Changelog - |Feature| :class:`metrics.LiftCurveDisplay` to plot the :func:`metrics.lift_curve`. In :pr:`21320` by :user:`Nawar Halabi `. + :mod:`sklearn.manifold` ....................... From a0508c3dc9528b46d506975b50dda4390270f487 Mon Sep 17 00:00:00 2001 From: Nawar Halabi Date: Thu, 25 Nov 2021 16:50:35 +0100 Subject: [PATCH 31/31] added online references for the lift_score function --- sklearn/metrics/_classification.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index ff7e992d8144b..17efa065dad15 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -1833,6 +1833,15 @@ def lift_score( function. Lift is only possible on binary data. + References + ---------- + .. [1] `Wikipedia entry for lift + `_. + .. [2] `Example of lift in practice + `_. + .. [3] `Lift curve in machine learning + `_. + Examples -------- >>> from sklearn.metrics import lift_score