Coverage is expected to grow over time. Please follow the dedicated `meta-issue on GitHub
<https://github.com/scikit-learn/scikit-learn/issues/22352>`_ to track progress.

Input and output array type handling
====================================

Estimators and scoring functions can accept input arrays from different array
libraries and/or devices. When a mixed set of input arrays is passed,
scikit-learn converts arrays as needed to make them all consistent.

For estimators, the rule is **"everything follows `X`"** - mixed array inputs are
converted so that they all match the array library and device of `X`.
For scoring functions the rule is **"everything follows `y_pred`"** - mixed array
inputs are converted so that they all match the array library and device of `y_pred`.

When a function or method is called with array API compatible inputs, the
convention is to return arrays from the same array library and on the same
device as the input data.

Estimators
----------

When an estimator is fitted with an array API compatible `X`, all other
array inputs, including constructor arguments (e.g., `y`, `sample_weight`),
will be converted to match the array library and device of `X`, if they do not
already match.
This behaviour enables switching from processing on the CPU to processing
on the GPU at any point within a pipeline.

Take for example a pipeline where `X` and `y` both start on CPU, and go through
the following three steps:

* :class:`~sklearn.preprocessing.TargetEncoder`, which will transform categorical
  `X` but also requires `y`, meaning both `X` and `y` need to be on CPU.
* :class:`FunctionTransformer(func=partial(torch.asarray, device="cuda")) <sklearn.preprocessing.FunctionTransformer>`,
  which moves `X` to GPU to improve performance in the next step.
* :class:`~sklearn.linear_model.Ridge`, whose performance can be improved when
  passed arrays on a GPU, as GPUs handle large matrix operations very efficiently.

`X` initially contains categorical string data (thus needs to be on CPU), which is
target encoded to numerical values in :class:`~sklearn.preprocessing.TargetEncoder`.
`X` is then explicitly moved to GPU to improve the performance of
:class:`~sklearn.linear_model.Ridge`. `y` cannot be transformed by the pipeline
(recall that scikit-learn pipelines do not allow transformation of `y`), but as
:class:`~sklearn.linear_model.Ridge` accepts mixed input types, this is not a
problem and the pipeline runs successfully.

The fitted attributes of an estimator fitted with an array API compatible `X`
will be arrays from the same library as the input and stored on the same device.
The `predict` and `transform` methods subsequently expect inputs from the same
array library and device as the data passed to the `fit` method.
Scoring functions
-----------------

When an array API compatible `y_pred` is passed to a scoring function,
all other array inputs (e.g., `y_true`, `sample_weight`) will be converted
to match the array library and device of `y_pred`, if they do not already match.
This allows scoring functions to accept mixed input types, enabling their use
within a :term:`meta-estimator` (or a function that accepts estimators) with a
pipeline that moves input arrays between devices (e.g., from CPU to GPU).

For example, to use the pipeline described above within
:func:`~sklearn.model_selection.cross_validate` or
:class:`~sklearn.model_selection.GridSearchCV`, the internally called scoring
function needs to accept mixed input types.

The output type of scoring functions depends on the number of output values.
A scoring function that returns a scalar value returns a Python scalar
(typically a `float` instance) instead of an array scalar value.
Scoring functions that support :term:`multiclass` or :term:`multioutput` return
an array from the same array library and device as `y_pred` when multiple
values need to be output.

Common estimator checks
=======================

Add the `array_api_support` tag to an estimator's set of tags to indicate that
it supports the array API. This will enable dedicated checks as part of the
common tests to verify that the estimators' results are the same when using
vanilla NumPy and array API inputs.

To run these checks you need to install
`array-api-strict <https://data-apis.org/array-api-strict/>`_ in your