
ENH Array API support for f1_score and multilabel_confusion_matrix #27369


Merged
merged 48 commits into scikit-learn:main on Nov 25, 2024

Conversation

OmarManzoor
Contributor

@OmarManzoor OmarManzoor commented Sep 14, 2023

Reference Issues/PRs

Towards #26024

What does this implement/fix? Explain your changes.

Any other comments?

CC: @ogrisel @betatim
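For context, a minimal usage sketch of what this PR enables (not part of the diff; assumes scikit-learn 1.6+ with array-api-compat installed and PyTorch available):

import torch

import sklearn
from sklearn.metrics import f1_score

y_true = torch.tensor([0, 1, 1, 0, 1])
y_pred = torch.tensor([0, 1, 0, 0, 1])

# With array API dispatch enabled, the metric is computed with the input's own
# namespace (and stays on the input's device) instead of converting to NumPy.
with sklearn.config_context(array_api_dispatch=True):
    score = f1_score(y_true, y_pred)

print(score)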

@github-actions

github-actions bot commented Sep 14, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 88a48e0. Link to the linter CI: here

@OmarManzoor OmarManzoor marked this pull request as ready for review May 20, 2024 10:07
@OmarManzoor
Contributor Author

@ogrisel Could you kindly have a look at this PR?

Member

@ogrisel ogrisel left a comment


Overall this looks good to me. I am surprised it works without being very specific about devices and dtypes, but as long as the tests pass (and they do), I am happy.

tp = np.array(tp)
fp = np.array(fp)
fn = np.array(fn)
sample_weight = xp.asarray(sample_weight)
Member


We should probably make sure that it matches the device of the inputs, no? It's curious that existing tests do not fail with PyTorch and MPS device (or cuda devices).

I am also wondering whether we should convert to a specific dtype. However, looking at the tests, I never see a case where we pass non-integer sample weights. And even for integer weights, it's only done to check an error message, not an actual computation. So I am not sure our sample_weight support is correct, even outside of array API concerns.

I guess this is only indirectly tested by classification metrics that rely on multilabel_confusion_matrix internally. But then the array API compliance tests for F1 score do not fail with floating point weights (I just checked) and I am not sure why.
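For illustration, a minimal sketch of the device/dtype handling being discussed, using the private helpers in sklearn.utils._array_api (get_namespace, device); helper names and availability vary between versions, and this is not the code that was merged:

from sklearn.utils._array_api import get_namespace, device as array_device

def _weights_to_namespace(sample_weight, reference):
    """Move sample_weight into the namespace and onto the device of `reference`."""
    xp, _ = get_namespace(reference)
    # Matching the device avoids implicit host<->GPU transfers with torch/CuPy;
    # a floating dtype keeps weighted sums from being silently truncated.
    return xp.asarray(
        sample_weight,
        dtype=xp.float64,  # assumption: float64 supported (MPS would need float32)
        device=array_device(reference),
    )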

Member


Here is the output of my cuda run on this PR (updated to check that boolean array indexing also works, but this should be orthogonal):

$ pytest -vlx -k "array_api and f1_score" sklearn/ 

================================================================================================== test session starts ===================================================================================================
platform linux -- Python 3.10.12, pytest-7.4.2, pluggy-1.3.0
collected 34881 items / 34863 deselected / 2 skipped / 18 selected                                                                                                                                                       

sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-numpy-None-None] PASSED                                                                      [  5%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-array_api_strict-None-None] PASSED                                                           [ 11%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-cupy-None-None] PASSED                                                                       [ 16%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-cupy.array_api-None-None] PASSED                                                             [ 22%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-cpu-float64] PASSED                                                                    [ 27%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-cpu-float32] PASSED                                                                    [ 33%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-cuda-float64] PASSED                                                                   [ 38%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-cuda-float32] PASSED                                                                   [ 44%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK...) [ 50%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-numpy-None-None] PASSED                                                                  [ 55%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-array_api_strict-None-None] PASSED                                                       [ 61%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-cupy-None-None] PASSED                                                                   [ 66%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-cupy.array_api-None-None] PASSED                                                         [ 72%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-cpu-float64] PASSED                                                                [ 77%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-cpu-float32] PASSED                                                                [ 83%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-cuda-float64] PASSED                                                               [ 88%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-cuda-float32] PASSED                                                               [ 94%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALL...) [100%]

============================================================================= 16 passed, 4 skipped, 34863 deselected, 105 warnings in 15.59s =============================================================================

@OmarManzoor
Contributor Author

@ogrisel @betatim Does this look okay now?

Member

@ogrisel ogrisel left a comment


I do not have time to finish the review today but here is some quick feedback:

@ogrisel
Member

ogrisel commented Jun 5, 2024

I merged main to be able to launch the new CUDA GPU CI workflow on this PR.

EDIT: tests are green.

@OmarManzoor
Contributor Author

@adrinjalali Could you kindly have a look at this PR now?

Comment on lines 642 to 645
if _is_numpy_namespace(xp=xp):
    true_and_pred = y_true.multiply(y_pred)
else:
    true_and_pred = xp.multiply(y_true, y_pred)
Member


why the difference?

Contributor Author


The first branch is the case where we are multiplying two sparse matrices together.

Member


Then we should check for sparse input, not for the NumPy namespace. As written, it looks like we'd be using that branch for plain np.ndarray inputs as well.

Contributor Author


Done.
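For reference, a sketch of the shape of the agreed fix (dispatch on sparsity rather than on the NumPy namespace); the helper name is hypothetical and the merged code may differ in details:

from scipy import sparse as sp

def _multiply_true_pred(y_true, y_pred, xp):
    if sp.issparse(y_true):
        # Sparse * sparse: use the matrix method so the product stays sparse.
        return y_true.multiply(y_pred)
    # Dense arrays from any array API namespace (NumPy, torch, CuPy, ...).
    return xp.multiply(y_true, y_pred)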

@OmarManzoor
Contributor Author

OmarManzoor commented Nov 12, 2024

@adrinjalali Does this look okay now?

Member

@adrinjalali adrinjalali left a comment


otherwise LGTM.

precision = _nanaverage(precision, weights=weights)
recall = _nanaverage(recall, weights=weights)
f_score = _nanaverage(f_score, weights=weights)
assert average != "binary" or precision.shape[0] == 1
Member


I'm okay to leave this as is in this PR since it's existing code, but we really shouldn't be `assert`-ing here. If this can never happen, the line shouldn't be here; if it can happen, we should raise a meaningful error.
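For illustration, the kind of explicit error the comment asks for in place of the bare assert (`average` and `precision` as in the hunk quoted above; hypothetical wording, and this PR left the line unchanged):

if average == "binary" and precision.shape[0] != 1:
    # Same condition as the assert above, but with an actionable message.
    raise ValueError(
        "Expected a single precision/recall/F-score value with average='binary', "
        f"got {precision.shape[0]} values instead."
    )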

Member

@adrinjalali adrinjalali left a comment


Other than the version question, LGTM.

`device` object.
See the :ref:`Array API User Guide <array_api>` for more details.

.. versionadded:: 1.6
Member


@glemaitre @jeremiedbb should we change this or backport?

Member


Let's backport since we have the experimental guardrail.

@glemaitre glemaitre merged commit 96b53ad into scikit-learn:main Nov 25, 2024
30 checks passed
@glemaitre glemaitre added the "To backport" label (PR merged in master that needs a backport to a release branch defined based on the milestone) Nov 25, 2024
@glemaitre glemaitre added this to the 1.6 milestone Nov 25, 2024
@glemaitre
Member

Thanks @OmarManzoor

@OmarManzoor OmarManzoor deleted the f1_array_api branch November 25, 2024 09:50
jeremiedbb pushed a commit to jeremiedbb/scikit-learn that referenced this pull request Dec 4, 2024
jeremiedbb pushed a commit that referenced this pull request Dec 6, 2024
…27369)

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
virchan pushed a commit to virchan/scikit-learn that referenced this pull request Dec 9, 2024
Labels
Array API, module:metrics, module:preprocessing, module:utils, To backport (PR merged in master that needs a backport to a release branch defined based on the milestone), Waiting for Second Reviewer (first reviewer is done, need a second one)
Projects
Status: Done

5 participants