-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG] Test __array_function__ not called in non-estimator API (#15865) #18292
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…ot activated in sklearn flows(scikit-learn#15865)
Hi @alexshacked, thanks for your pull request and for your patience! If you haven't lost hope... do you mind modifying the title of your PR from [WIP] to [MRG]? This will (hopefully) bring some attention back... together with the "Waiting for reviewer" label. Thanks again! |
Hi @cmarmo. Thank you for noticing this issue. Moved from [WIP] to [MRG]. I will be happy to implement comments and bring this PR to completion. |
The failing checks are related to the renaming of the main branch. I'm closing and reopening the PR to re-trigger the checks. |
Hi @alexshacked thank you for your patience! If you are still interested, do you mind synchronizing to upstream/main? Then perhaps @thomasjpfan or @jnothman (who opened the related issue) could have a look at this PR? Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR @alexshacked
I think it's okay to have one-of tests for __array_function__
+ functions for now.
|
||
|
||
def test_array_function_not_called(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To provide context for future maintainers that are not familiar with NEP18:
def test_array_function_not_called(): | |
def test_array_function_not_called(): | |
"""Check that `__array_function__` (NEP18) is not called.""" |
Small nit: may this test be moved to the end of the file?
@@ -226,6 +227,17 @@ def get_params(self, deep=False): | |||
return {'a': self.a, 'allow_nd': self.allow_nd} | |||
|
|||
|
|||
def test_array_function_not_called(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment here and nit about moving the test to the end of the file.
def test_array_function_not_called(): | |
def test_array_function_not_called(): | |
"""Check that `__array_function__` (NEP18) is not called.""" |
estimator = LogisticRegression() | ||
estimator.fit(X, y) | ||
rng = np.random.RandomState(42) | ||
permutation_importance(estimator, X, y, n_repeats=5, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the test is based on _NotAnArray
not raising an error, we do not need to repeat too much:
permutation_importance(estimator, X, y, n_repeats=5, | |
permutation_importance(estimator, X, y, n_repeats=2, |
estimator.fit(X, y) | ||
rng = np.random.RandomState(42) | ||
permutation_importance(estimator, X, y, n_repeats=5, | ||
random_state=rng, n_jobs=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We do not need to include n_jobs=1
since it is the default when n_jobs=None
.
random_state=rng, n_jobs=1) | |
random_state=rng) |
grid = GridSearchCV(estimator, param_grid={'C': [1, 10]}) | ||
cross_validate(grid, X, y, n_jobs=2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do not think we need nested cross validation to check that cross_validate
itself does not use __array_function__
:
grid = GridSearchCV(estimator, param_grid={'C': [1, 10]}) | |
cross_validate(grid, X, y, n_jobs=2) | |
cross_validate(estimator, X, y, cv=2) |
Also cv=2
to speed the test up and leaving n_jobs=None
as default.
In NEP18 numpy created a dispatch mechanism that enables data types to override the implementation of numpy functions.
The principle is that when a numpy function receives input parameters that implement __array_function__ it will delegate the execution to the __array_function__ of this parameters.
So if we pass to a numpy function parameters of type dask.array (that implements __array_function__) the result will not be a numpy.ndarray as before but a dask.array instead. This can cause troubles to the scikit-learn code.
in PR #14702 it was made sure that all scikit-learn estimators are protected from such a scenario.
In this PR we will ensure that other functions in scikit-learn like
cross_validate
orpermutation_importance
are also protected.