FEA Implementation of "threshold-dependent metric per threshold value" curve #25639


Open
vitaliset wants to merge 35 commits into base: main

Conversation

vitaliset
Contributor

@vitaliset vitaliset commented Feb 18, 2023

Towards #21391.

Intending to later build the MetricThresholdCurveDisplay following the same structure as other Displays, this PR implements the associated curve. I decided to break the original issue into two parts (curve and Display) for easier review (but I don't mind adding the Display to this PR as well).

[Update 08 June 2024] The code example is outdated after Guillaume Lemaitre's first reviews. For instance, I've moved the code to metrics instead of inspection and changed the parameter names. Leaving it here because the idea is still similar. Will update this later.

A quick example of how to use this implementation:

from functools import partial

import matplotlib.pyplot as plt
from imblearn.datasets import fetch_datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import metric_threshold_curve
from sklearn.metrics import fbeta_score
from sklearn.model_selection import train_test_split

# Binary, imbalanced dataset from imbalanced-learn.
dataset = fetch_datasets()["coil_2000"]
X, y = dataset.data, (dataset.target == 1).astype(int)

X_train_model, X_test, y_train_model, y_test = train_test_split(
    X, y, random_state=0, stratify=y
)

model = RandomForestClassifier(random_state=0).fit(X_train_model, y_train_model)
predict_proba = model.predict_proba(X_test)[:, 1]

# Evaluate the F2 score on a grid of 500 candidate decision thresholds.
f2_values, thresholds = metric_threshold_curve(
    y_test, predict_proba, partial(fbeta_score, beta=2), threshold_grid=500
)

fig, ax = plt.subplots(figsize=(5, 2.4))
ax.plot(thresholds, f2_values)
ax.set_xlabel("thresholds")
ax.set_ylabel("f2 score")
plt.tight_layout()

[Figure: F2 score versus decision threshold on the coil_2000 test set]
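
For intuition on threshold_grid=500: when it is an integer (and the scores have more distinct values than that), the implementation in this PR picks the candidate thresholds as percentiles of the distinct scores, so the grid is denser where the model's scores are denser:

import numpy as np

# Mirrors the grid construction in this PR's implementation: percentiles of
# the distinct predicted scores, evenly spaced in probability mass.
grid = np.percentile(list(set(predict_proba)), np.linspace(0, 100, 500))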

[Update 08 June 2024] Will be using the code from _CurveScorer instead of _binary_clf_curve as soon as I move it to the metrics module.

Most of the code for the metric_threshold_curve function is an adaptation of _binary_clf_curve.

Points of doubt:

  • I thought the inspection module would be suitable for this type of analysis, but it is not 100% clear to me that this curve (and later the Display) should go there - the other current options would be metrics or model_selection._prediction (like the related meta-estimator from [WIP] FEA New meta-estimator to post-tune the decision_function/predict_proba threshold for binary classifiers #16525).
  • I would appreciate some help with test ideas! (One possible direction is sketched right after this list.)
  • I made a first version of the documentation to go along with the function. It's preliminary: I didn't want to invest too much in it before we settle on how the function will look. Ideas for it would be appreciated as well! :)
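
Regarding test ideas, one direction is to compare the vectorized output against a naive per-threshold loop. A minimal sketch, assuming the current (metric_values, thresholds) return signature and the import path from the example above:

from functools import partial

import numpy as np
from sklearn.inspection import metric_threshold_curve  # final import path TBD
from sklearn.metrics import fbeta_score


def test_metric_threshold_curve_matches_naive_loop():
    # Random binary labels and scores.
    rng = np.random.RandomState(0)
    y_true = rng.randint(0, 2, size=200)
    y_score = rng.rand(200)

    f2 = partial(fbeta_score, beta=2)
    metric_values, thresholds = metric_threshold_curve(
        y_true, y_score, f2, threshold_grid=20
    )

    # Recompute each point naively: binarize at each threshold, then score.
    expected = [f2(y_true, (y_score > t).astype(int)) for t in thresholds]
    np.testing.assert_allclose(metric_values, expected)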

[Update 09 June 2024] When I come back to this PR, I need to do the following before asking for a new review:

  • Add the code to the files decided on in MNT Moving _CurveScorer from model_selection to metrics #29216.
  • Refactor _CurveScorer so I can dissociate getting y_score from the scoring itself, such that here we only call the scoring part (a rough sketch follows this list).
  • Update the decision_threshold_curve code to use this new method of _CurveScorer.
  • Bring back the documentation using the new pydata-sphinx-theme. Check commit 4fab2a3 for reference.
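
As a rough sketch of the second item above, the dissociation could look something like this (hypothetical helper name; the real split would live in _CurveScorer):

import numpy as np


def _score_from_thresholds(y_true, y_score, score_func, thresholds, sample_weight=None):
    # Hypothetical helper isolating the scoring half of _CurveScorer._score:
    # it takes a precomputed y_score instead of calling the estimator,
    # binarizes at each candidate threshold, and evaluates the metric.
    return np.array(
        [
            score_func(y_true, (y_score > t).astype(int), sample_weight=sample_weight)
            for t in thresholds
        ]
    )

decision_threshold_curve would then only build the threshold grid and call this helper, while _CurveScorer would keep the responsibility of obtaining y_score from the estimator.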


github-actions bot commented May 20, 2024

❌ Linting issues

This PR is introducing linting issues. Here's a summary of the issues. Note that you can avoid having linting issues by enabling pre-commit hooks. Instructions to enable them can be found here.

You can see the details of the linting issues under the lint job here


ruff

ruff detected issues. Please run ruff check --fix --output-format=full . locally, fix the remaining issues, and push the changes. Here you can see the detected issues. Note that the installed ruff version is ruff=0.5.1.


sklearn/metrics/__init__.py:6:1: I001 [*] Import block is un-sorted or un-formatted
    |
  4 |   # SPDX-License-Identifier: BSD-3-Clause
  5 |   
  6 | / from . import cluster
  7 | | from ._classification import (
  8 | |     accuracy_score,
  9 | |     balanced_accuracy_score,
 10 | |     brier_score_loss,
 11 | |     class_likelihood_ratios,
 12 | |     classification_report,
 13 | |     cohen_kappa_score,
 14 | |     confusion_matrix,
 15 | |     d2_log_loss_score,
 16 | |     f1_score,
 17 | |     fbeta_score,
 18 | |     hamming_loss,
 19 | |     hinge_loss,
 20 | |     jaccard_score,
 21 | |     log_loss,
 22 | |     matthews_corrcoef,
 23 | |     multilabel_confusion_matrix,
 24 | |     precision_recall_fscore_support,
 25 | |     precision_score,
 26 | |     recall_score,
 27 | |     zero_one_loss,
 28 | | )
 29 | | from ._dist_metrics import DistanceMetric
 30 | | from ._plot.confusion_matrix import ConfusionMatrixDisplay
 31 | | from ._plot.det_curve import DetCurveDisplay
 32 | | from ._plot.precision_recall_curve import PrecisionRecallDisplay
 33 | | from ._plot.regression import PredictionErrorDisplay
 34 | | from ._plot.roc_curve import RocCurveDisplay
 35 | | from ._ranking import (
 36 | |     auc,
 37 | |     average_precision_score,
 38 | |     coverage_error,
 39 | |     dcg_score,
 40 | |     det_curve,
 41 | |     label_ranking_average_precision_score,
 42 | |     label_ranking_loss,
 43 | |     ndcg_score,
 44 | |     precision_recall_curve,
 45 | |     roc_auc_score,
 46 | |     roc_curve,
 47 | |     top_k_accuracy_score,
 48 | | )
 49 | | from ._regression import (
 50 | |     d2_absolute_error_score,
 51 | |     d2_pinball_score,
 52 | |     d2_tweedie_score,
 53 | |     explained_variance_score,
 54 | |     max_error,
 55 | |     mean_absolute_error,
 56 | |     mean_absolute_percentage_error,
 57 | |     mean_gamma_deviance,
 58 | |     mean_pinball_loss,
 59 | |     mean_poisson_deviance,
 60 | |     mean_squared_error,
 61 | |     mean_squared_log_error,
 62 | |     mean_tweedie_deviance,
 63 | |     median_absolute_error,
 64 | |     r2_score,
 65 | |     root_mean_squared_error,
 66 | |     root_mean_squared_log_error,
 67 | | )
 68 | | from ._scorer import check_scoring, get_scorer, get_scorer_names, make_scorer
 69 | | from ._decision_threshold import decision_threshold_curve
 70 | | from .cluster import (
 71 | |     adjusted_mutual_info_score,
 72 | |     adjusted_rand_score,
 73 | |     calinski_harabasz_score,
 74 | |     completeness_score,
 75 | |     consensus_score,
 76 | |     davies_bouldin_score,
 77 | |     fowlkes_mallows_score,
 78 | |     homogeneity_completeness_v_measure,
 79 | |     homogeneity_score,
 80 | |     mutual_info_score,
 81 | |     normalized_mutual_info_score,
 82 | |     pair_confusion_matrix,
 83 | |     rand_score,
 84 | |     silhouette_samples,
 85 | |     silhouette_score,
 86 | |     v_measure_score,
 87 | | )
 88 | | from .pairwise import (
 89 | |     euclidean_distances,
 90 | |     nan_euclidean_distances,
 91 | |     pairwise_distances,
 92 | |     pairwise_distances_argmin,
 93 | |     pairwise_distances_argmin_min,
 94 | |     pairwise_distances_chunked,
 95 | |     pairwise_kernels,
 96 | | )
 97 | | 
 98 | | __all__ = [
    | |_^ I001
 99 |       "accuracy_score",
100 |       "adjusted_mutual_info_score",
    |
    = help: Organize imports

Found 1 error.
[*] 1 fixable with the `--fix` option.

Generated for commit: a424c3e. Link to the linter CI: here

@glemaitre
Member

OK, now that we have merged FixedThresholdClassifier and TunedThresholdClassifierCV, I have another perspective on the tool.

I think it is time to review and prioritize this feature.

@vitaliset would you have time to dedicate to working on this feature?

@glemaitre glemaitre added this to the 1.6 milestone May 20, 2024
Member

@glemaitre glemaitre left a comment


Some initial thoughts. I did not look at the documentation or test but it will come later.

def metric_threshold_curve(
    y_true,
    y_score,
    score_func,
Member


We might want to call this scoring as well for consistency. But here, we should only accept a callable.
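
For instance (a sketch of the rename only; the remaining parameters and defaults are shown here purely for illustration):

def metric_threshold_curve(
    y_true,
    y_score,
    scoring,  # renamed from score_func; must be a callable, not a string
    sample_weight=None,
    pos_label=1,
    threshold_grid=None,
):
    ...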

Comment on lines 130 to 177
# Make y_true a boolean vector.
y_true = y_true == pos_label

# Sort scores and corresponding truth values.
desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
y_score = y_score[desc_score_indices]
y_true = y_true[desc_score_indices]
if sample_weight is not None:
    sample_weight = sample_weight[desc_score_indices]

# Logic to see if we need to use all possible thresholds (distinct values).
all_thresholds = False
if threshold_grid is None:
    all_thresholds = True
elif isinstance(threshold_grid, int):
    if len(set(y_score)) < threshold_grid:
        all_thresholds = True

if all_thresholds:
    # y_score typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = np.where(np.diff(y_score))[0]
    threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
    thresholds = y_score[threshold_idxs[::-1]]
elif isinstance(threshold_grid, int):
    # It takes representative score points to calculate the metric
    # with these thresholds.
    thresholds = np.percentile(
        list(set(y_score)), np.linspace(0, 100, threshold_grid)
    )
else:
    # If threshold_grid is an array then run some checks and sort
    # it for consistency.
    threshold_grid = column_or_1d(threshold_grid)
    assert_all_finite(threshold_grid)
    thresholds = np.sort(threshold_grid)

# For each threshold calculates the metric.
metric_values = []
for threshold in thresholds:
    preds_threshold = (y_score > threshold).astype(int)
    metric_values.append(
        score_func(y_true, preds_threshold, sample_weight=sample_weight)
    )
# TODO: should we multithread the metric calculations?

return np.array(metric_values), thresholds
Member


All this code is already implemented by the _score function of the sklearn.model_selection._classification_threshold._CurveScorer class.

I think that we should leverage this code by creating this scorer. We probably need to dissociate getting y_score from the scoring itself such that here we only call the scoring part.

Member


So now, it makes sense to move _CurveScorer into metrics.

Contributor Author


So now, it makes sense to move _CurveScorer into metrics.

Do you want me to do this in this PR or create a separate one?

Member

@glemaitre glemaitre May 21, 2024


It would be better to do it in a separate PR. Depending on the schedule, I might start that PR myself.

Contributor Author


We probably need to dissociate getting y_score from the scoring itself such that here we only call the scoring part.

I think it will make sense to do this in this PR after we move the code to the proper place. Do you agree?

@vitaliset
Contributor Author

I think it is time to review and prioritize this feature.

@vitaliset would you have time to dedicate to working on this feature?

Awesome news! I might need a couple of weeks, but I would love to make this feature available! Will work on your comments as soon as I can, @glemaitre.

@glemaitre glemaitre self-requested a review August 2, 2024 09:20
@glemaitre
Member

It will be tight to get this into 1.6, but it will be one of my prioritized PRs for 1.7.

@glemaitre glemaitre modified the milestones: 1.6, 1.7 Oct 29, 2024
@lucyleeow
Member

@vitaliset, thanks for your patience on this. I discussed this with @glemaitre and we want to try to prioritize it. Are you still interested in working on this?

If not, I am happy to push it forward, and you will still be credited.
