[MRG] Test __array_function__ not called in non-estimator API (#15865) #18292

Open
wants to merge 2 commits into
base: main
from

Conversation

alexshacked
Contributor

@alexshacked alexshacked commented Aug 29, 2020

In NEP 18, NumPy introduced a dispatch mechanism that lets data types override the implementation of NumPy functions.
The principle is that when a NumPy function receives input arguments that implement __array_function__, it delegates execution to the __array_function__ of those arguments.
So if we pass arguments of type dask.array (which implements __array_function__) to a NumPy function, the result will no longer be a numpy.ndarray but a dask.array instead. This can cause problems in the scikit-learn code.
PR #14702 made sure that all scikit-learn estimators are protected from such a scenario.
This PR ensures that other functions in scikit-learn, such as cross_validate or permutation_importance, are also protected.
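To make the dispatch described above concrete, here is a minimal sketch. `TaggedArray` is a hypothetical class invented for illustration (it is not part of NumPy, dask, or scikit-learn), and it assumes a NumPy version where NEP 18 dispatch is enabled by default (1.17 or later).

```python
import numpy as np


class TaggedArray:
    """Hypothetical duck array that intercepts NumPy functions via NEP 18."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        # Unwrap TaggedArray positional arguments, run the real NumPy
        # function on plain ndarrays, then wrap the result again.
        unwrapped = tuple(
            a.data if isinstance(a, TaggedArray) else a for a in args
        )
        return TaggedArray(func(*unwrapped, **kwargs))


x = TaggedArray([1.0, 2.0, 3.0])

# np.cumsum sees that x implements __array_function__ and delegates to it,
# so the caller gets a TaggedArray back instead of a numpy.ndarray.
out = np.cumsum(x)
print(type(out).__name__)
```

Code that expects `np.cumsum` to return a `numpy.ndarray` would receive a `TaggedArray` here, which is the kind of surprise the description refers to.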

@cmarmo
Contributor

cmarmo commented Jan 21, 2021

Hi @alexshacked, thanks for your pull request and for your patience! If you haven't lost hope... do you mind modifying the title of your PR from [WIP] to [MRG]? This will (hopefully) bring some attention back... together with the "Waiting for reviewer" label. Thanks again!

Base automatically changed from master to main January 22, 2021 10:53
@alexshacked alexshacked changed the title [WIP] Test __array_function__ not called in non-estimator API (#15865) [MRG] Test __array_function__ not called in non-estimator API (#15865) Jan 24, 2021
@alexshacked
Contributor Author

Hi @cmarmo. Thank you for noticing this issue. Moved from [WIP] to [MRG]. I will be happy to implement comments and bring this PR to completion.

@cmarmo
Contributor

cmarmo commented Jan 25, 2021

The failing checks are related to the renaming of the main branch. I'm closing and reopening the PR to re-trigger the checks.

@cmarmo cmarmo closed this Jan 25, 2021
@cmarmo cmarmo reopened this Jan 25, 2021
@cmarmo
Contributor

cmarmo commented Jan 24, 2022

Hi @alexshacked thank you for your patience! If you are still interested, do you mind synchronizing to upstream/main? Then perhaps @thomasjpfan or @jnothman (who opened the related issue) could have a look at this PR? Thanks!

Member

@thomasjpfan thomasjpfan left a comment

Thanks for the PR @alexshacked

I think it's okay to have one-off tests for __array_function__ + functions for now.



def test_array_function_not_called():
Member

To provide context for future maintainers that are not familiar with NEP18:

Suggested change
- def test_array_function_not_called():
+ def test_array_function_not_called():
+     """Check that `__array_function__` (NEP18) is not called."""

Small nit: could this test be moved to the end of the file?

@@ -226,6 +227,17 @@ def get_params(self, deep=False):
        return {'a': self.a, 'allow_nd': self.allow_nd}


def test_array_function_not_called():
Member

Same comment here and nit about moving the test to the end of the file.

Suggested change
- def test_array_function_not_called():
+ def test_array_function_not_called():
+     """Check that `__array_function__` (NEP18) is not called."""

    estimator = LogisticRegression()
    estimator.fit(X, y)
    rng = np.random.RandomState(42)
    permutation_importance(estimator, X, y, n_repeats=5,
Member

Since the test is based on _NotAnArray not raising an error, we do not need many repeats:

Suggested change
-     permutation_importance(estimator, X, y, n_repeats=5,
+     permutation_importance(estimator, X, y, n_repeats=2,

    estimator.fit(X, y)
    rng = np.random.RandomState(42)
    permutation_importance(estimator, X, y, n_repeats=5,
                           random_state=rng, n_jobs=1)
Member

We do not need to include n_jobs=1, since the default n_jobs=None already means a single job (unless the call runs inside a joblib.parallel_backend context).

Suggested change
-                            random_state=rng, n_jobs=1)
+                            random_state=rng)

Comment on lines +237 to +238
    grid = GridSearchCV(estimator, param_grid={'C': [1, 10]})
    cross_validate(grid, X, y, n_jobs=2)
Member

I do not think we need nested cross validation to check that cross_validate itself does not use __array_function__:

Suggested change
-     grid = GridSearchCV(estimator, param_grid={'C': [1, 10]})
-     cross_validate(grid, X, y, n_jobs=2)
+     cross_validate(estimator, X, y, cv=2)

Also, cv=2 speeds the test up, and n_jobs is left at its default of None.
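The testing idea behind this thread (pass in an object whose `__array_function__` raises, and check that nothing blows up) can be sketched roughly as follows. `GuardedArray` is a hypothetical stand-in written for this illustration, not the actual `_NotAnArray` helper, and the two `column_means_*` functions are invented examples rather than scikit-learn code.

```python
import numpy as np


class GuardedArray:
    """Hypothetical guard object: convertible to an array, but it
    refuses to take part in NEP 18 dispatch."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array__(self, dtype=None, copy=None):
        # Explicit conversion path, used by np.asarray(obj).
        return self.data if dtype is None else self.data.astype(dtype)

    def __array_function__(self, func, types, args, kwargs):
        # Reached whenever a NumPy function is called directly on the object.
        raise TypeError(f"__array_function__ called for {func.__name__}")


def column_means_safe(X):
    # Convert first, so later NumPy calls only ever see an ndarray.
    X = np.asarray(X)
    return np.mean(X, axis=0)


def column_means_unsafe(X):
    # Calls a NumPy function on the raw input, which triggers dispatch.
    return np.mean(X, axis=0)


X = GuardedArray([[1.0, 3.0], [3.0, 5.0]])
means = column_means_safe(X)  # passes: the guard is never dispatched to

try:
    column_means_unsafe(X)
    dispatched = False
except TypeError:
    dispatched = True  # the guard caught a direct NumPy call
```

A test built this way only needs the function under test to run to completion, which is why small values such as `n_repeats=2` and `cv=2` are enough.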


4 participants