-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
ENH Array API for check_consistent_length
#29519
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
||
with pytest.raises(TypeError): | ||
check_consistent_length([1, 2], np.array(1)) | ||
check_consistent_length(xp.asarray([1, 2], device=device), np.array(1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this wouldn't raise when xp==np
, would it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have just tried it. I raises with TypeError: Singleton array array(1) cannot be considered a valid collection.
from within _num_samples()
in sklearn/utils/validation.py:376.
The last four pytest.raises checks are not really testing check_consistent_length()
, but some other error mechanism outside of it. It feels not very clean. What would be a good way to handle this? Separate those tests?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your questions @adrinjalali! I have answered them and maybe we can refactor the tests to make these things more obvious.
|
||
with pytest.raises(TypeError): | ||
check_consistent_length([1, 2], np.array(1)) | ||
check_consistent_length(xp.asarray([1, 2], device=device), np.array(1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have just tried it. I raises with TypeError: Singleton array array(1) cannot be considered a valid collection.
from within _num_samples()
in sklearn/utils/validation.py:376.
The last four pytest.raises checks are not really testing check_consistent_length()
, but some other error mechanism outside of it. It feels not very clean. What would be a good way to handle this? Separate those tests?
Also, I don't know why the CI tests "scikit-learn.scikit-learn Expected — Waiting for status to be reported" don't run in my PR? |
I've fixed the PR, @adrinjalali, though now there is another issue:
Which comes from within Then, the sparse multilabel Do I assume right that the PRs that will fix #29452 will also fix this here? |
Hi both :) if you are interested in the discussion around that comment, please see #29336 and #29321 (comment) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR.
My main suggestion would be to rather keep test_check_consistent_length
unchanged to make sure that adding array API support does not introduce any behavioral change: in particular, we still want to test that check_consistent_length
can work on Python lists, including nested lists with in-homogeneous lengths such as [[1, 2], [[1, 2]]]
for which no equivalent NumPy array exist.
Instead of changing the existing test_check_consistent_length
function, I would suggest to add a new test named test_check_consistent_length
with the yield_namespace_device_dtype_combinations
parametrization to check that calling check_consistent_length
on valid array API inputs under the sklearn.config_context(array_api_dispatch=True)
context.
Do I assume right that the PRs that will fix #29452 will also fix this here?
Indeed, the TypeError: csr_matrix is not a supported array type
failures shall be addressed by #29476 which I am concurrently reviewing.
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Thank you, @ogrisel! I have refactored the tests as you had suggested (though ruff didn't allow me to have them both have the same function name). Please let me know if there is anything else I should change. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is another pass of feedback:
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Thanks for your review and the suggestions, @ogrisel. I have committed what you had suggested and the tests (including the cuda test on Colab) all pass. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @StefanieSenger! LGTM with small details.
@@ -1184,7 +1184,7 @@ def fowlkes_mallows_score(labels_true, labels_pred, *, sparse=False): | |||
|
|||
.. versionadded:: 0.18 | |||
|
|||
The Fowlkes-Mallows index (FMI) is defined as the geometric mean between of | |||
The Fowlkes-Mallows index (FMI) is defined as the geometric mean of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with the change, but this seems unrelated to the scope of the PR. Maybe next time open small quick dedicated PRs for such quick documentation fixes.
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Merging, thanks! |
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Reference Issues/PRs
towards #26024
What does this implement/fix? Explain your changes.
This PR adds Array API for
check_consistent_length
.Some little doc enhancements/typo corrections along the way.
I didn't run the MPS test.