ENH Array API support for f1_score and multilabel_confusion_matrix #27369
Conversation
@ogrisel Could you kindly have a look at this PR?
Overall this looks good to me. I am surprised it works without being very specific about devices and dtypes, but as long as the tests pass (and they do), I am happy.
sklearn/metrics/_classification.py
Outdated
tp = np.array(tp)
fp = np.array(fp)
fn = np.array(fn)
sample_weight = xp.asarray(sample_weight)
We should probably make sure that it matches the device of the inputs, no? It's curious that the existing tests do not fail with PyTorch on MPS (or CUDA) devices.
I am also wondering whether we should convert to a specific dtype. Looking at the tests, however, I never see a case where we pass non-integer sample weights, and even the integer weights are only used to check an error message, not an actual computation. So I am not sure our sample_weight support is correct, even outside of array API concerns.
I guess this is only indirectly tested by the classification metrics that rely on multilabel_confusion_matrix internally. But then the array API compliance tests for F1 score do not fail with floating point weights (I just checked), and I am not sure why.
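For illustration, a minimal sketch of the device-aware conversion being discussed (the helper name _asarray_like is hypothetical; get_namespace and device are the utilities in sklearn.utils._array_api, and the merged code may differ):

# Hypothetical helper: convert sample_weight to the namespace, dtype and
# device of a reference array, so e.g. CUDA inputs keep weights on CUDA.
from sklearn.utils._array_api import get_namespace, device

def _asarray_like(sample_weight, reference):
    xp, _ = get_namespace(reference)
    return xp.asarray(
        sample_weight,
        dtype=reference.dtype,     # avoid silent int/float mixing
        device=device(reference),  # match the device of the inputs
    )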
Here is the output of my CUDA run on this PR (updated to check that boolean array indexing also works, though that should be orthogonal):
$ pytest -vlx -k "array_api and f1_score" sklearn/
================================================================================================== test session starts ===================================================================================================
platform linux -- Python 3.10.12, pytest-7.4.2, pluggy-1.3.0
collected 34881 items / 34863 deselected / 2 skipped / 18 selected
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-numpy-None-None] PASSED [ 5%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-array_api_strict-None-None] PASSED [ 11%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-cupy-None-None] PASSED [ 16%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-cupy.array_api-None-None] PASSED [ 22%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-cpu-float64] PASSED [ 27%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-cpu-float32] PASSED [ 33%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-cuda-float64] PASSED [ 38%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-cuda-float32] PASSED [ 44%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_binary_classification_metric-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK...) [ 50%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-numpy-None-None] PASSED [ 55%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-array_api_strict-None-None] PASSED [ 61%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-cupy-None-None] PASSED [ 66%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-cupy.array_api-None-None] PASSED [ 72%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-cpu-float64] PASSED [ 77%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-cpu-float32] PASSED [ 83%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-cuda-float64] PASSED [ 88%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-cuda-float32] PASSED [ 94%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[f1_score-check_array_api_multiclass_classification_metric-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALL...) [100%]
============================================================================= 16 passed, 4 skipped, 34863 deselected, 105 warnings in 15.59s =============================================================================
I do not have time to finish the review today but here is some quick feedback:
I merged. EDIT: tests are green.
@adrinjalali Could you kindly have a look at this PR now?
sklearn/metrics/_classification.py
Outdated
if _is_numpy_namespace(xp=xp):
    true_and_pred = y_true.multiply(y_pred)
else:
    true_and_pred = xp.multiply(y_true, y_pred)
Why the difference?
The first branch is the case where we are multiplying two sparse matrices together.
Then we should check for sparse input, not for the numpy namespace. As written, it looks like we'd also take that branch for a plain np.ndarray. A sketch of the fix is below.
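A minimal sketch of that suggestion, assuming scipy.sparse.issparse is used for the check (the merged code may differ slightly):

# Branch on sparsity so dense NumPy arrays also go through xp.multiply.
from scipy.sparse import issparse

if issparse(y_true):
    # scipy sparse matrices implement element-wise products via .multiply
    true_and_pred = y_true.multiply(y_pred)
else:
    true_and_pred = xp.multiply(y_true, y_pred)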
Done.
@adrinjalali Does this look okay now?
Otherwise LGTM.
precision = _nanaverage(precision, weights=weights)
recall = _nanaverage(recall, weights=weights)
f_score = _nanaverage(f_score, weights=weights)
assert average != "binary" or precision.shape[0] == 1
I'm okay leaving this as is in this PR since it's existing code, but we really shouldn't be asserting here. If this can never happen, the line shouldn't be here; if it can happen, we should raise a meaningful error.
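A hedged sketch of that suggestion (the error message wording is illustrative, not the project's):

# Replace the bare assert with an explicit, meaningful error.
if average == "binary" and precision.shape[0] != 1:
    raise ValueError(
        "Expected exactly one precision value for average='binary', "
        f"got shape {precision.shape}; this indicates an internal "
        "inconsistency in label selection."
    )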
Other than the version question, LGTM.
`device` object.

See the :ref:`Array API User Guide <array_api>` for more details.

.. versionadded:: 1.6
@glemaitre @jeremiedbb should we change this or backport?
Let's backport since we have the experimental guardrail.
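For context, the "experimental guardrail" is that array API dispatch is off by default and must be enabled via sklearn.set_config, so a backport only affects users who opt in. A usage sketch, assuming PyTorch and scikit-learn's array API dependencies are installed:

# Opt in to array API dispatch, then call f1_score on torch tensors.
import sklearn
import torch
from sklearn.metrics import f1_score

sklearn.set_config(array_api_dispatch=True)

y_true = torch.tensor([0, 1, 1, 0, 1])
y_pred = torch.tensor([0, 1, 0, 0, 1])
print(f1_score(y_true, y_pred))  # computed through the torch namespace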
Thanks @OmarManzoor
…cikit-learn#27369) Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
…27369) Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Reference Issues/PRs
Towards #26024
What does this implement/fix? Explain your changes.
Adds array API support for f1_score and multilabel_confusion_matrix.
Any other comments?
CC: @ogrisel @betatim