FIX Fix array API train_test_split
#28407
5d062ee to 126c434
I quite like this. Thanks @betatim, but I'm a bit worried about how complicated things become in methods that used to seem quite short and easy to read.
I tried to look at the train_test_split specific part of the PR but I did not find any non-regression test for this part.
Also, I assume that we might have a similar problem in cross_val_score, cross_validate and in the *SearchCV meta estimators.
126c434 to a796f33
LGTM once the above comments are addressed.
Implemented
I didn't add one because
... when running the tests with cupy installed, which is not yet the case on our current CI (see: #24491).
I pushed a new test to explicitly cover the missing lines in _determine_key_type, but I think we should also add a test for train_test_split itself that checks that the returned arrays are actually of the expected container type and device.
EDIT: those tests already exist but I don't understand why codecov was complaining then... I checked that they are not always skipped on my local laptop...
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-None-torch-cpu-float64] PASSED [ 3%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-None-torch-mps-float32] PASSED [ 7%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-stratify1-cupy.array_api-None-None] SKIPPED (cupy.array_api is not installed: not checking array_api input) [ 11%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[False-None-torch-mps-float32] PASSED [ 14%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-stratify1-torch-cuda-float32] SKIPPED (PyTorch test requires cuda, which is not available) [ 18%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[False-None-torch-cpu-float64] PASSED [ 22%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[False-None-cupy.array_api-None-None] SKIPPED (cupy.array_api is not installed: not checking array_api input) [ 25%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-None-torch-cuda-float32] SKIPPED (PyTorch test requires cuda, which is not available) [ 29%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[False-None-torch-cpu-float32] PASSED [ 33%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-stratify1-torch-mps-float32] PASSED [ 37%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-None-torch-cpu-float32] PASSED [ 40%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-None-torch-cuda-float64] SKIPPED (PyTorch test requires cuda, which is not available) [ 44%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-None-numpy-None-None] PASSED [ 48%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-stratify1-torch-cuda-float64] SKIPPED (PyTorch test requires cuda, which is not available) [ 51%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-stratify1-torch-cpu-float64] PASSED [ 55%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[False-None-torch-cuda-float32] SKIPPED (PyTorch test requires cuda, which is not available) [ 59%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-None-numpy.array_api-None-None] PASSED [ 62%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[False-None-torch-cuda-float64] SKIPPED (PyTorch test requires cuda, which is not available) [ 66%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[False-None-numpy.array_api-None-None] PASSED [ 70%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-stratify1-torch-cpu-float32] PASSED [ 74%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[False-None-numpy-None-None] PASSED [ 77%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-stratify1-numpy.array_api-None-None] PASSED [ 81%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-None-cupy.array_api-None-None] SKIPPED (cupy.array_api is not installed: not checking array_api input) [ 85%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-None-cupy-None-None] SKIPPED (cupy is not installed: not checking array_api input) [ 88%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[False-None-cupy-None-None] SKIPPED (cupy is not installed: not checking array_api input) [ 92%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-stratify1-numpy-None-None] PASSED [ 96%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-stratify1-cupy-None-None] SKIPPED (cupy is not installed: not checking array_api input) [100%]
We also need a changelog entry along the lines of "fix train_test_split on CuPy arrays".
@betatim I pushed an extra test case to fix the coverage. I'll let you fix the conflicts.
c71fddd to e99be1a
I pushed my test case again after the move of the test function to the new file introduced in
The line not covered by our tests, as reported by codecov, should be executed when running Array API tests with libraries that do not support complex dtypes, such as cupy.
@adrinjalali I think this PR is ready for a second review.
For information, I re-ran the current state of this PR with cupy and torch/cuda and all tests are green.
Otherwise LGTM.
try:
    complex_array_key = xp.asarray([1 + 1j, 2 + 2j, 3 + 3j])
except TypeError:
    # Complex numbers are not supported by all Array API libraries.
    complex_array_key = None
seems like a legit codecov complaint.
This line is covered when running the tests with cupy; however, this requires the CUDA CI being designed at #24491.
But we won't be able to get coverage data on a weekly run.
So shall we ignore the codecov comment for now or do we need to add a comment to mark it as ignored or something?
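For readers following the coverage discussion, here is a minimal sketch of why the `except TypeError` branch is only reached with a namespace that rejects complex values. The two namespace classes and the `make_complex_array_key` helper are hypothetical stand-ins written for illustration. They are not part of scikit-learn or of this PR, and the stub only mimics the `asarray` call seen in the snippet above.

```python
class NoComplexNamespace:
    """Hypothetical stand-in for an array namespace without complex dtypes."""

    @staticmethod
    def asarray(values):
        # Mimic a library that refuses complex input by raising TypeError.
        if any(isinstance(v, complex) for v in values):
            raise TypeError("complex dtypes are not supported")
        return list(values)


class ListNamespace:
    """Hypothetical stand-in for a namespace that accepts complex values."""

    @staticmethod
    def asarray(values):
        return list(values)


def make_complex_array_key(xp):
    """Return a complex-valued key, or None if `xp` rejects complex numbers."""
    try:
        return xp.asarray([1 + 1j, 2 + 2j, 3 + 3j])
    except TypeError:
        # Complex numbers are not supported by all Array API libraries.
        return None
```

With `ListNamespace` the `try` body succeeds and the fallback line never runs, which matches what a CI without cupy would report. Only a namespace that raises `TypeError`, here `NoComplexNamespace`, exercises the fallback.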
Reference Issues/PRs
Follow up to #26855
What does this implement/fix? Explain your changes.
This fixes the array API implementation of train_test_split. There were a few parts of train_test_split that appeared to work but didn't actually.
Any other comments?
This includes all of #27904. Once it is merged this PR needs rebasing to remove those changes. The relevant changes are in the final commit of this PR.