-
-
Notifications
You must be signed in to change notification settings - Fork 26.2k
DOC Clarify how mixed array input types handled in array api #31452
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
24d8f1a
e2fc189
5833e9a
af147a5
3cbfeea
cc38465
d4b4423
22774b5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -182,29 +182,89 @@ Tools | |
Coverage is expected to grow over time. Please follow the dedicated `meta-issue on GitHub | ||
<https://github.com/scikit-learn/scikit-learn/issues/22352>`_ to track progress. | ||
|
||
Type of return values and fitted attributes | ||
------------------------------------------- | ||
Input and output array type handling | ||
==================================== | ||
|
||
When calling functions or methods with Array API compatible inputs, the | ||
convention is to return array values of the same array container type and | ||
Estimators and scoring functions are able to accept input arrays | ||
from different array libraries and/or devices. When a mixed set of input arrays is | ||
passed, scikit-learn converts arrays as needed to make them all consistent. | ||
|
||
For estimators, the rule is **"everything follows `X`"** - mixed array inputs are | ||
converted so that they all match the array library and device of `X`. | ||
For scoring functions the rule is **"everything follows `y_pred`"** - mixed array | ||
inputs are converted so that they all match the array library and device of `y_pred`. | ||
|
||
When a function or method has been called with array API compatible inputs, the | ||
convention is to return arrays from the same array library and on the same | ||
device as the input data. | ||
|
||
Similarly, when an estimator is fitted with Array API compatible inputs, the | ||
fitted attributes will be arrays from the same library as the input and stored | ||
on the same device. The `predict` and `transform` method subsequently expect | ||
Estimators | ||
---------- | ||
|
||
When an estimator is fitted with an array API compatible `X`, all other | ||
array inputs, including constructor arguments, (e.g., `y`, `sample_weight`) | ||
will be converted to match the array library and device of `X`, if they do not already. | ||
This behaviour enables switching from processing on the CPU to processing | ||
on the GPU at any point within a pipeline. | ||
|
||
This allows estimators to accept mixed input types, enabling `X` to be moved | ||
to a different device within a pipeline, without explicitly moving `y`. | ||
Note that scikit-learn pipelines do not allow transformation of `y` (to avoid | ||
:ref:`leakage <data_leakage>`). | ||
|
||
Take for example a pipeline where `X` and `y` both start on CPU, and go through | ||
the following three steps: | ||
|
||
* :class:`~sklearn.preprocessing.TargetEncoder`, which will transform categorial | ||
`X` but also requires `y`, meaning both `X` and `y` need to be on CPU. | ||
* :class:`FunctionTransformer(func=partial(torch.asarray, device="cuda")) <sklearn.preprocessing.FunctionTransformer>`, | ||
which moves `X` to GPU, to improve performance in the next step. | ||
* :class:`~sklearn.linear_model.Ridge`, whose performance can be improved when | ||
passed arrays on a GPU, as they can handle large matrix operations very efficiently. | ||
|
||
`X` initially contains categorical string data (thus needs to be on CPU), which is | ||
target encoded to numerical values in :class:`~sklearn.preprocessing.TargetEncoder`. | ||
`X` is then explicitly moved to GPU to improve the performance of | ||
:class:`~sklearn.linear_model.Ridge`. `y` cannot be transformed by the pipeline | ||
(recall scikit-learn pipelines do not allow transformation of `y`) but as | ||
:class:`~sklearn.linear_model.Ridge` is able to accept mixed input types, | ||
this is not a problem and the pipeline is able to be run. | ||
lucyleeow marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
The fitted attributes of an estimator fitted with an array API compatible `X`, will | ||
be arrays from the same library as the input and stored on the same device. | ||
The `predict` and `transform` method subsequently expect | ||
inputs from the same array library and device as the data passed to the `fit` | ||
method. | ||
Comment on lines
+235
to
237
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe mention that as of now, scikit-learn does not interfere with how array libraries handle arrays from a different namespace or device passed to "If arrays from a different library or on a different device are passed, behavior depends on the array library: it may raise an error or silently convert them." There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry I am confused here, why would an array library be doing things within our There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When arrays from different namespaces are passed into a function, then the namespace that the function comes from determines the handling. For instance in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay you've raised some good points 🤔 I'm going to raise these back in the issue |
||
|
||
Note however that scoring functions that return scalar values return Python | ||
scalars (typically a `float` instance) instead of an array scalar value. | ||
Scoring functions | ||
----------------- | ||
|
||
When an array API compatible `y_pred` is passed to a scoring function, | ||
all other array inputs (e.g., `y_true`, `sample_weight`) will be converted | ||
to match the array library and device of `y_pred`, if they do not already. | ||
This allows scoring functions to accept mixed input types, enabling them to be | ||
used within a :term:`meta-estimator` (or function that accepts estimators), with a | ||
pipeline that moves input arrays between devices (e.g., CPU to GPU). | ||
|
||
For example, to be able to use the pipeline described above within e.g., | ||
:func:`~sklearn.model_selection.cross_validate` or | ||
:class:`~sklearn.model_selection.GridSearchCV`, the scoring function internally | ||
called needs to be able to accept mixed input types. | ||
|
||
The output type of scoring functions depends on the number of output values. | ||
When a scoring function returns a scalar value, it will return a Python | ||
scalar (typically a `float` instance) instead of an array scalar value. | ||
For scoring functions that support :term:`multiclass` or :term:`multioutput`, | ||
an array from the same array library and device as `y_pred` will be returned when | ||
multiple values need to be output. | ||
|
||
Common estimator checks | ||
======================= | ||
|
||
Add the `array_api_support` tag to an estimator's set of tags to indicate that | ||
it supports the Array API. This will enable dedicated checks as part of the | ||
it supports the array API. This will enable dedicated checks as part of the | ||
common tests to verify that the estimators' results are the same when using | ||
vanilla NumPy and Array API inputs. | ||
vanilla NumPy and array API inputs. | ||
|
||
To run these checks you need to install | ||
`array-api-strict <https://data-apis.org/array-api-strict/>`_ in your | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is still a bit clunky :-/ I think it is enough to mention the use-case and then show the example below, which contains a longer explanation/reminder about the fact that you can't move
y
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about:
"This behaviour enables pipelines to switch from processing on
the CPU to processing on the GPU within the pipeline."
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe even:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, changed to this!