ENH Add Array API compatibility tests for *SearchCV classes #27096
RandomizedSearchCV and GridSearchCV appear to just work with Array API inputs.
Nice! Can you please add those estimators to the list of estimators in the array api section of the user guide?
This is now the case. Let's see if this can work.
I started some of the failures reported by our test suite after the merge with |
Marking this PR as not accepted because there is still work to do.
-    train=None,
-    test=None,
+    train=train,
+    test=test,
Not sure why this did not fail in `main`, but based on the docstrings, `_fit_and_score` was never supposed to accept `None` for its `train` and `test` arguments, so it is better to pass valid integer arrays in this test instead.
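For reference, a minimal sketch of what "valid integer arrays" can look like. It uses the public `KFold` splitter rather than the private `_fit_and_score` helper, and the toy data and variable names are illustrative assumptions, not the test changed in this PR.

```python
import numpy as np
from sklearn.model_selection import KFold

# Toy data (assumption, for illustration only).
X = np.arange(20, dtype=float).reshape(10, 2)
y = np.arange(10, dtype=float)

# Take the first split: both members are 1-D integer index arrays
# that select rows of X and y.
train, test = next(KFold(n_splits=2).split(X, y))

X_train, X_test = X[train], X[test]
```

Index arrays like these are what a splitter hands to the scoring machinery, which is why `None` is not a meaningful stand-in for them.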
@@ -332,7 +332,9 @@ def _generate_search_cv_instances():
     extra_params = (
         {"min_resources": "smallest"} if "min_resources" in init_params else {}
     )
-    search_cv = SearchCV(Estimator(), param_grid, cv=2, **extra_params)
+    search_cv = SearchCV(
+        Estimator(), param_grid, cv=2, error_score="raise", **extra_params
Setting `error_score="raise"` makes pytest traceback reading much more direct, especially in CI logs.
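A small sketch of the effect, assuming a hypothetical `AlwaysFails` estimator that is not part of scikit-learn or of this PR. With `error_score="raise"`, the exception raised inside `fit` propagates out of the search instead of being turned into a warning and a placeholder score.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.model_selection import GridSearchCV


class AlwaysFails(RegressorMixin, BaseEstimator):
    """Hypothetical estimator whose fit always raises (illustration only)."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def fit(self, X, y):
        raise ValueError("fit failed on purpose")

    def predict(self, X):
        return np.zeros(len(X))


# Toy data (assumption, for illustration only).
X = np.arange(20, dtype=float).reshape(10, 2)
y = np.arange(10, dtype=float)

search = GridSearchCV(
    AlwaysFails(), {"alpha": [0.1, 1.0]}, cv=2, error_score="raise"
)

try:
    search.fit(X, y)
    raised, message = False, ""
except ValueError as exc:
    # The original exception from the failing fit surfaces directly.
    raised, message = True, str(exc)
```

In a test suite, this means the traceback points at the failing `fit` call rather than at a later, less informative error.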
I think this is +1 on my side. The fact that the CV splitters are currently NumPy only makes the conversion for stratification a bit ugly.
We probably need a follow-up PR to check that all CV splitters work when splitting array API inputs with `train_test_split` or other tools that accept array API inputs and a `cv=` parameter, but I would rather do that in a dedicated PR and focus this one on what is necessary for `*SearchCV` themselves.
/cc @OmarManzoor and @betatim.
BTW, I tested this PR on a CUDA host and on an MPS host and all tests pass.
LGTM.
Not sure I should merge it because I opened the PR. But voting to merge it anyway :D
# we need the following explicit conversion:
xp, is_array_api = get_namespace(y)
if is_array_api:
    y = _convert_to_numpy(y, xp)
Would it be possible to add a test to cover this?
I think we had coverage for it via the test I removed in the last commit. But I thought that we tested the same thing via a common test. So I am a bit puzzled why that doesn't happen :-/
I understand what's going on:
- the common tests run the search CV meta-estimator using `LogisticRegression` and `Ridge` as base estimators;
- the `LogisticRegression` test is skipped because `LogisticRegression` does not have the array API estimator tag, hence the wrapping search CV estimators do not have it either;
- the stratified k-fold CV splitter is only used for classification problems, hence the `Ridge`-based common test does not cover it;
- the previous hard-coded test of this PR used `LinearDiscriminantAnalysis`, which is a classifier and supports array API, hence could cover this line.
Since adding array API support to `LogisticRegression` might take a bit of time, I would be in favor of re-adding the previous non-common test that was removed from this PR.
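The point about which splitter gets used can be checked with the public `check_cv` helper. This is only a sketch with toy targets (an assumption, not data from this PR): an integer `cv` resolves to `StratifiedKFold` for a classifier with binary or multiclass targets, and to plain `KFold` otherwise.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold, check_cv

# Toy targets (assumption, for illustration only).
y_class = np.array([0, 1] * 5)       # binary classification labels
y_reg = np.linspace(0.0, 1.0, 10)    # continuous regression target

# Classifier + class labels: the stratified splitter is selected.
clf_cv = check_cv(2, y_class, classifier=True)

# Regressor: the plain k-fold splitter is selected, so code that is
# specific to stratification is never exercised.
reg_cv = check_cv(2, y_reg, classifier=False)
```

This is why a regressor-based common test cannot reach the conversion line, while a test built on an array API compatible classifier can.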
Yeah that seems right. Thanks for the explanation! I think the test seems valid for now and not redundant.
@ogrisel I added the test back. Can you have a look? Then we can probably merge.
BTW, I tested on CUDA with Google Colab and https://gist.github.com/EdAbati/ff3bdc06bafeb92452b3740686cc8d7c and tests are green.
I marked this PR for auto-merge. Thanks all!
RandomizedSearchCV and GridSearchCV appear to just work with Array API inputs.
This adds a test that makes sure that they will keep working.
For the common tests to pass we need `Ridge` to support the Array API.