
FIX sparse arrays not included in array API dispatch #29476


Merged
13 commits merged into scikit-learn:main on Sep 11, 2024

Conversation

adrinjalali
Member

Alternative to #29466 and #29470

Closes #29452

This version makes it explicit if sparse arrays should be excluded, and raises by default if sparse arrays are given to get_namespace
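As a rough illustration of the intended behavior, here is a minimal sketch (the helper name and the remove_sparse parameter are hypothetical and only meant to illustrate the idea, not the code added by this PR):

```python
import scipy.sparse as sp


def _filter_sparse(arrays, remove_sparse=False):
    """Drop scipy sparse containers when explicitly allowed, otherwise raise."""
    filtered = []
    for a in arrays:
        if sp.issparse(a):
            if not remove_sparse:
                # Default behavior described in this PR: refuse sparse inputs
                # unless the caller opted out explicitly.
                raise ValueError(
                    "Sparse inputs are not supported by get_namespace; pass "
                    "remove_sparse=True to exclude them from namespace inference."
                )
            continue  # explicitly excluded: skip the sparse input
        filtered.append(a)
    return filtered
```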

cc @ogrisel @OmarManzoor @betatim


github-actions bot commented Jul 12, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 03d6372.

Member

@ogrisel ogrisel left a comment


Thanks for the PR. Overall, I am not opposed to making sparse handling explicit, but I wonder if that won't be too invasive: I assume that whenever we call get_namespace on the inputs of a function that accepts sparse inputs, we will always have to pass remove_sparse=True, which might make our code base unnecessarily verbose / complex compared to the alternative PR.

We can probably detect many of those cases by expanding our estimator checks so that _check_estimator_sparse_container exercises the scikit-learn estimators on sparse data both with array API dispatching enabled and disabled.
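For instance, the extended coverage could look roughly like this (a sketch only; Ridge is just an arbitrary sparse-friendly estimator picked for illustration, and enabling the dispatch requires array-api-compat to be installed):

```python
import numpy as np
import pytest
import scipy.sparse as sp
from sklearn import config_context
from sklearn.linear_model import Ridge


@pytest.mark.parametrize("array_api_dispatch", [False, True])
def test_sparse_input_with_and_without_dispatch(array_api_dispatch):
    # Fit an estimator that accepts sparse input with array API dispatching
    # both disabled and enabled.
    X = sp.csr_matrix(np.eye(4))
    y = np.array([0.0, 1.0, 2.0, 3.0])
    with config_context(array_api_dispatch=array_api_dispatch):
        Ridge().fit(X, y)
```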

Aside from this, here is some more specific inline feedback:

@@ -140,7 +141,9 @@ def device(*array_list, remove_none=True, remove_types=(str,)):
*array_list, remove_none=remove_none, remove_types=remove_types
)

# Note that _remove_non_arrays ensures that array_list is not empty.
if not array_list:
return None
Member


I think the Returns section of the docstring needs to be updated to mention this case.

Member

@ogrisel ogrisel Jul 26, 2024


Actually, we might argue that returning None in this case can be quite confusing.

When reviewing this PR, I found out that get_namespace_and_device is not even consistent with device (independently of the changes in this PR).

Consider the following code on a single non-array input that is not filtered out by any of the remove_ arguments (note that array API dispatch is never enabled).

```python
>>> from sklearn.utils._array_api import device, get_namespace, get_namespace_and_device
>>> some_list = [0, 1, 2]
>>> get_namespace_and_device(some_list)
(<sklearn.utils._array_api._NumPyAPIWrapper object at 0x10ff31750>, False, None)
>>> xp, is_enabled, device_ = get_namespace_and_device(some_list)
>>> device(some_list)
'cpu'
>>> device(xp.asarray(some_list))
'cpu'
```

I think we should at least make the output of device(some_list) and get_namespace_and_device(some_list) consistent: either return None in both cases, or consistently return "cpu" as is done for NumPy arrays, whether via the sklearn.utils._array_api._NumPyAPIWrapper when array API dispatch is disabled, via the array_api_compat.numpy wrapper for numpy 1.x, or via numpy 2 arrays, which naturally expose a .device == "cpu" attribute (at least in the dev version).

Returning None for non-array inputs (such as lists or scipy sparse matrices, when accepted) would be the cleanest. However, that might make our code more verbose by scattering if device_ is not None checks in various places. So maybe always returning "cpu" in both device(...) and get_namespace_and_device(...) for non-array inputs would be more pragmatic.
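A minimal sketch of the "always return 'cpu'" option (a simplified, hypothetical stand-in for device, not the actual scikit-learn implementation):

```python
def device_sketch(*arrays):
    """Return "cpu" when no array API style input with a .device remains."""
    # Keep only inputs that expose a `.device` attribute (array API arrays).
    api_arrays = [a for a in arrays if hasattr(a, "device")]
    if not api_arrays:
        # Lists, scipy sparse matrices and other host-memory objects end up
        # here, so "cpu" is the pragmatic, consistent answer.
        return "cpu"
    return api_arrays[0].device
```

With that convention, device_sketch([0, 1, 2]) and a consistent get_namespace_and_device would both report "cpu" instead of None, avoiding the scattered if device_ is not None checks mentioned above.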

@betatim
Member

betatim commented Aug 7, 2024

I'm trying to get back into understanding the issue that led to this PR and #29466.

One thing I can't work out is why in this PR we make it optional (with default False) to remove sparse arrays from the computation of what the array namespace is. Wouldn't you always want to set it to True, in which case why make it an option?

The reason I'm thinking this is that, in a way, a sparse array is similar to when a str is passed in to get_namespace: it doesn't influence the decision.


Trying to get some things straight in my head, I ended up making this table for a function that takes two arguments A and B, showing the different values of set_config(array_api_dispatch=...) and what the output of get_namespace should be.

| A | B | array API | get_namespace output (xp, compliant) |
| --- | --- | --- | --- |
| numpy | sparse | off | np, False |
| numpy | sparse | on | np, True |
| sparse | sparse | off | np, False |
| sparse | sparse | on | np, True |
| jax | sparse | off | np, False (or should this error? what happens today?) |
| jax | sparse | on | jax, True |

For the third row I think we should do what we do today, which I think is to use np in combination with if sp.issparse guards.

For the fourth row we would need to make sure that we can detect that "list of arrays passed to get_namespace is empty because it only contained sparse arrays" and instead of raising an error, return np, True(?). I think this should effectively go down the same code path as row three. The fact that array API dispatching is turned on is irrelevant.

For the second to last row I don't know what should happen. Probably the same that happens today (an error?).

For the last row I think you should probably return jax as the xp. But I am not sure if all the code that we have today that uses np.foo on a sparse array will work if we rewrite it as xp.foo, which means we would have to add additional if sp.issparse guards so that we can explicitly use np.foo. Does someone know the answer to this?

Are there cases missed in the table? I think trying to list them all helps (me) trying to come up with a consistent picture and then solution.
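To make the np plus if sp.issparse code path from rows three and four concrete, here is a rough sketch of that pattern (the helper is hypothetical and only meant to illustrate the guard):

```python
import numpy as np
import scipy.sparse as sp
from sklearn.utils._array_api import get_namespace


def _column_means(X):
    # Sparse inputs stay on the existing NumPy/SciPy code path, regardless of
    # whether array API dispatching is enabled.
    if sp.issparse(X):
        return np.asarray(X.mean(axis=0)).ravel()
    # Dense inputs go through whatever namespace get_namespace resolves to
    # (the NumPy wrapper when dispatch is off, e.g. torch when it is on).
    xp, _ = get_namespace(X)
    return xp.mean(X, axis=0)
```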

@ogrisel
Member

ogrisel commented Aug 7, 2024

jax | sparse | off | np, False (or should this error? what happens today?)

Today (and before we started integrating array API support into scikit-learn) we would typically call np.asarray on the jax input as part of the check_array calls in _validate_data, and that should work as long as the input data structure implements __array__.

For pytorch inputs it would work silently for "cpu"-allocated tensors but would fail for device-allocated tensors because pytorch conservatively refuses to perform implicit memory transfers from device to host. But this is not mandated by the array API spec itself. I have not checked what jax does. We don't even test with jax at the moment.

@ogrisel
Member

ogrisel commented Aug 7, 2024

sparse | sparse | on | np, True

I think I would rather return the _NUMPY_API_WRAPPER_INSTANCE, False as currently done in this PR instead.

@betatim
Member

betatim commented Aug 7, 2024

Today (and before we started integrating array API support into scikit-learn) we would typically call np.asarray on the jax input as part of the check_array calls in _validate_data, and that should work as long as the input data structure implements __array__.

For pytorch inputs it would work silently for "cpu"-allocated tensors but would fail for device-allocated tensors because pytorch conservatively refuses to perform implicit memory transfers from device to host. But this is not mandated by the array API spec itself. I have not checked what jax does. We don't even test with jax at the moment.

Maybe we can summarise today's status as "might work or maybe not, undefined"?

I used jax as a random array library, not because I wanted to think about jax in particular :-/

@ogrisel
Member

ogrisel commented Aug 8, 2024

One thing I can't work out is why in this PR we make it optional (with default False) to remove sparse arrays from the computation of what the array namespace is. Wouldn't you always want to set it to True, in which case why make it an option?

I agree. I am in favor of filtering sparse inputs by default (and probably also removing the option to not filter) while keeping the behavior of this PR.

Contributor

@OmarManzoor OmarManzoor left a comment


Would it be worth adding a test to check that the device is None when we only have scipy sparse arrays, or a combination of sparse arrays and types which we want to remove? Also maybe a test for when we have a simple numpy array and a sparse array; I think it might return "cpu" in that case. These are tests with array API dispatch disabled.
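Perhaps something along these lines (a sketch only; the expected values just mirror the behavior suggested in this comment and would need to match whatever the PR finally settles on):

```python
import numpy as np
import scipy.sparse as sp
from sklearn.utils._array_api import device


def test_device_with_sparse_inputs():
    # Array API dispatch stays disabled (the default) in this sketch.
    sparse_a = sp.csr_matrix(np.array([[1.0, 0.0], [0.0, 1.0]]))
    sparse_b = sp.csr_matrix(np.array([[0.0, 2.0], [2.0, 0.0]]))

    # Only sparse arrays plus a removable type: expected to report no device.
    assert device(sparse_a, "some removable string", sparse_b) is None

    # A plain numpy array mixed with sparse arrays: expected to report "cpu".
    assert device(sparse_a, np.asarray([1.0, 2.0]), sparse_b) == "cpu"
```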

@adrinjalali
Member Author

Added the test for the behavior.

adrinjalali and others added 3 commits August 23, 2024 12:34
Co-authored-by: Omar Salman <omar.salman2007@gmail.com>
Co-authored-by: Omar Salman <omar.salman2007@gmail.com>
Contributor

@OmarManzoor OmarManzoor left a comment


LGTM. Thank you @adrinjalali

@adrinjalali
Member Author

@ogrisel @betatim this seems ready for a review from your side.

@adrinjalali
Member Author

@ogrisel @betatim another ping here.

Member

@ogrisel ogrisel left a comment


Not so sure about what device should be returned for edge cases where all inputs are filtered out, but I think it should not matter much, and since this is internal to scikit-learn we can always change our mind later if the choice in this PR proves suboptimal for one reason or another.

LGTM, thanks very much @adrinjalali!

@ogrisel ogrisel merged commit 153205a into scikit-learn:main Sep 11, 2024
30 checks passed
@adrinjalali adrinjalali deleted the array_api/sparse branch September 11, 2024 16:35
Successfully merging this pull request may close these issues.

Array API regression in homogeneity_completeness_v_measure?