-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
FEA Add a private check_array
with additional parameters
#25617
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
To allow plugins or other users to pass in a custom `asarray` function an additional parameter is added to a private version of `check_array`. The reason for the indirection is to avoid increasing the number of public parameters that need to then follow the deprecation policy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For GPUS, can one use the __cuda_array_interface__
to have a zero-copy view of the array in CuPy
and then make use check_array
's ArrayAPI support?
@@ -876,7 +1012,7 @@ def check_array( | |||
) | |||
array = xp.astype(array, dtype, copy=False) | |||
else: | |||
array = _asarray_with_order(array, order=order, dtype=dtype, xp=xp) | |||
array = asarray(array, order=order, dtype=dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the third party array libraries you are considering, do they usually work with NumPy dtypes, such as np.float32
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe so indeed. It's the case for CuPy
and dpctl.tensor
which are our main use cases as part of the plugin API at the moment but it's probably the problem of the implementer of the custom asarray
to make sure that they understand the usual dtypes we use in scikit-learn.
array. Useful when the input array is not a Numpy array or when the | ||
converted array should be a ndarray from a differnt library. The callable | ||
should have the same signature as `np.asarray` and in addition support | ||
they `copy` keyword argument. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can the function signature be written out here? asarray(a, dtype=None, order=None, copy=False)
(np.asarray has a like
kwarg which is not required)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I'm still thinking about what the best way is to describe how this asarray
-like function should behave. Both in terms of its signature and what it can and can not return (your later comment).
@@ -947,13 +1083,19 @@ def check_array( | |||
if xp.__name__ in {"numpy", "numpy.array_api"}: | |||
# only make a copy if `array` and `array_orig` may share memory` | |||
if np.may_share_memory(array, array_orig): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If array
is a GPU array, I suspect np.may_share_memory
will not work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But then we won't go down this path no? At least I'd assume xp.__name__
would not be "numpy" in that case.
@@ -728,6 +852,15 @@ def check_array( | |||
|
|||
.. versionadded:: 1.1.0 | |||
|
|||
asarray : callable, default=None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The array return by asarray
is required to have an interface that works with check_array
. Specifically, the array object needs:
array.ndim
array.dtype.kind
array.shape
- Work with
_assert_all_finite
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For reference on dtype.kind
: Not all array libraries define a dtype.kind
attribute with the same semantics as NumPy. The Array API spec now has a isdtype, which is used to determine the "kind" of a dtype.
Yes and no. Right now with @fcharras had an additional use-case around It feels like a lot of the time what I want as a plugin author is "Dear input validation, please do everything like before, but when you would call |
I was thinking: cupy_array_api = cupy.array_api.asarray(custom_library_gpu_array)
with config_context(array_api_dispatch=True):
validated_cupy_array = check_array(cupy_array_api)
regular_cupy_array = validated_cupy_array._array
validated_custom_library_gpu_array = your_gpu_library.asarray(regular_cupy_array) I suspect the other use case is order, which is not part of the Array API specification. In any case, I see the motivation for allowing a custom |
To allow plugins or other users to pass in a custom
asarray
function an additional parameter is added to a private version ofcheck_array
. If the caller passes something, that callable is used instead of the defaultnp.asarray
. This is useful if you want to directly convert to a, say, cupy array. Without this parameter first a numpy array would be created which you then convert to a cupy array. This also makes it possible for cupy arrays to be passed in to_validate_data
. This is useful for plugin authors who want to re-use the validation methods of scikit-learn, instead of having to maintain a copy of them.The reason for the indirection is to avoid increasing the number of public parameters that need to then follow the deprecation policy.
Reference Issues/PRs
This is an attempt to closes #25433
What does this implement/fix? Explain your changes.
This implements the "add a private version with this additional parameter" idea from #25433.