Skip to content

ENH Add Array API compatibility tests for *SearchCV classes #27096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 7, 2024

Conversation

betatim
Copy link
Member

@betatim betatim commented Aug 18, 2023

RandomizedSearchCV and GridSearchCV appear to just work with Array API inputs.

This adds a test that makes sure that they will keep working.

For the common tests to pass we need Ridge to support the Array API.

@github-actions
Copy link

github-actions bot commented Aug 18, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: aebb43e. Link to the linter CI: here

RandomizedSearchCV and GridSearchCV appear to just work with Array API
inputs.
@betatim betatim force-pushed the array-api-randomsearch branch from 970c612 to 6f74abb Compare August 18, 2023 10:01
Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Can you please add those estimators to the list of estimators in the array api section of the user guide?

@ogrisel
Copy link
Member

ogrisel commented May 24, 2024

For the common tests to pass we need Ridge to support the Array API.

This is now the case. Let's see if this can work.

@ogrisel
Copy link
Member

ogrisel commented May 24, 2024

I started some of the failures reported by our test suite after the merge with main but it's still WIP. I have to tune out for today.

@ogrisel ogrisel marked this pull request as draft May 24, 2024 16:30
Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Marking this PR as not accepted because there is still work to do.

train=None,
test=None,
train=train,
test=test,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why this did no fail in main but based on the docstrings _fit_and_score was never supposed to accept None for its train and test arguments so better pass valid integer arrays in this test instead.

@@ -332,7 +332,9 @@ def _generate_search_cv_instances():
extra_params = (
{"min_resources": "smallest"} if "min_resources" in init_params else {}
)
search_cv = SearchCV(Estimator(), param_grid, cv=2, **extra_params)
search_cv = SearchCV(
Estimator(), param_grid, cv=2, error_score="raise", **extra_params
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting error_score="raise" makes pytest traceback reading much more direct, especially in CI logs.

@ogrisel ogrisel added this to the 1.6 milestone May 27, 2024
@ogrisel ogrisel marked this pull request as ready for review May 27, 2024 10:00
Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is +1 on my side. The fact that the CV splitters are currently NumPy only makes the conversion for stratifification a bit ugly.

We probably need a follow-up PR to check that all CV splitters work when splitting array API inputs with train_test_split or other tools that accept array API inputs and cv= parameter but I would rather do that in a dedicated PR and focus this one on what is necessary for *SearchCV themselves.

/cc @OmarManzoor and @betatim.

@ogrisel
Copy link
Member

ogrisel commented May 27, 2024

BTW, I tested this PR on a cuda host and on an MPS host and all tests pass.

Copy link
Member Author

@betatim betatim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Not sure I should merge it because I opened the PR. But voting to merge it anyway :D

Copy link
Contributor

@OmarManzoor OmarManzoor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# we need the following explicit conversion:
xp, is_array_api = get_namespace(y)
if is_array_api:
y = _convert_to_numpy(y, xp)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to add a test to cover this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we had coverage for it via the test I removed in the last commit. But I thought that we tested the same thing via a common test. So I am a bit puzzled why that doesn't happen :-/

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand what's going on:

  • the common tests run the search cv meta estimator using LogisticRegression and Ridge as base-estimator.
  • the LogisticRegression test is skipped because LogisticRegression does not have the array API estimator tag hence the wrapping search cv estimators ain't either;
  • the stratified k-fold CV splitter is only used for classification problem, hence the Ridge-based common test does not cover it;
  • the previous hard-coded test of this PR used LinearDiscriminantAnalysis which is a classifier and supports array API, hence could cover this line.

Since adding array API support to LogisticRegression might take a bit of time, I would be in favor of re-adding the previous non-common test that was removed from this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that seems right. Thanks for the explanation! I think the test seems valid for now and not redundant.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ogrisel I added the test back. Can you have a look and then we can probably merge.

@ogrisel
Copy link
Member

ogrisel commented Jun 7, 2024

BTW, I tested on CUDA with google colab and https://gist.github.com/EdAbati/ff3bdc06bafeb92452b3740686cc8d7c and tests are green.

Copy link
Contributor

@OmarManzoor OmarManzoor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks @betatim @ogrisel

@ogrisel ogrisel enabled auto-merge (squash) June 7, 2024 09:46
@ogrisel
Copy link
Member

ogrisel commented Jun 7, 2024

I marked this PR for auto-merge. Thanks all!

@ogrisel ogrisel merged commit 5692e59 into scikit-learn:main Jun 7, 2024
28 checks passed
@jeremiedbb jeremiedbb mentioned this pull request Jul 2, 2024
11 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

3 participants