
ENH: Add Dask Array API support #28588


Draft: lithomas1 wants to merge 45 commits into scikit-learn:main from lithomas1:wip-dask-array-api

Commits (45)
dd239fc
add all files
lithomas1 Feb 19, 2024
c1a7522
don't use np.asarray to force computation
lithomas1 Feb 19, 2024
696ed09
some mods
lithomas1 Feb 21, 2024
d0ea36f
Merge branch 'main' of github.com:scikit-learn/scikit-learn into wip-…
lithomas1 Mar 2, 2024
ab75783
update
lithomas1 Mar 2, 2024
2d6d2ca
change back sign coercion stuff
lithomas1 Mar 2, 2024
584c13a
update
lithomas1 Mar 3, 2024
164a066
avoid hang, but why?
lithomas1 Mar 3, 2024
634a228
fixes and remove linear model support
lithomas1 Mar 4, 2024
11d1fcd
remove test notebook
lithomas1 Mar 7, 2024
cc6cc4b
skip some tests for PCA
lithomas1 Mar 11, 2024
514abb3
update
lithomas1 Mar 13, 2024
34d1b65
Merge branch 'main' of github.com:scikit-learn/scikit-learn into wip-…
lithomas1 Mar 21, 2024
9637b8e
add min version check for array-api-compat
lithomas1 Mar 21, 2024
683e95a
patches
lithomas1 Mar 21, 2024
9fd95db
bump array-api-compat
lithomas1 Mar 21, 2024
263b0e6
fix r2_score
lithomas1 Mar 22, 2024
c9e789c
fix tests
lithomas1 Mar 24, 2024
fb78976
Merge branch 'main' of github.com:scikit-learn/scikit-learn into wip-…
lithomas1 Apr 1, 2024
cfcae45
remove dask specific changes
lithomas1 Apr 2, 2024
ff66577
Merge branch 'main' into wip-dask-array-api
lithomas1 Apr 2, 2024
f82b375
modify test machinery
lithomas1 Apr 3, 2024
b416c5e
add whatsnew note
lithomas1 Apr 3, 2024
5f56c79
more test changes
lithomas1 Apr 4, 2024
eefb016
revert more changes
lithomas1 Apr 4, 2024
8c069f6
fix last test
lithomas1 Apr 4, 2024
e1fd16f
Apply suggestions from code review
lithomas1 Apr 10, 2024
8f8a9ee
rest of suggestions
lithomas1 Apr 10, 2024
63ad0b4
address rest of comments
lithomas1 Apr 10, 2024
929ac80
Merge branch 'main' of github.com:scikit-learn/scikit-learn into wip-…
lithomas1 Apr 10, 2024
ba212a7
add array-api-compat to pyproject.toml
lithomas1 Apr 10, 2024
69c0318
fix linter
lithomas1 Apr 10, 2024
4f8e7bf
Fix doc
ogrisel Apr 10, 2024
543ade4
Merge branch 'main' of github.com:scikit-learn/scikit-learn into wip-…
lithomas1 Apr 26, 2024
f757c55
fix ridgeregression
lithomas1 May 16, 2024
854544d
Merge branch 'main' of github.com:scikit-learn/scikit-learn into wip-…
lithomas1 May 16, 2024
dbb8673
Remove conflict resolution marker
ogrisel May 23, 2024
58b5ad1
Merge main and regenerate lock files
ogrisel Aug 21, 2024
b6af823
Merge branch 'main' of github.com:scikit-learn/scikit-learn into wip-…
lithomas1 Aug 28, 2024
041c80f
Merge branch 'main' of github.com:scikit-learn/scikit-learn into wip-…
lithomas1 Mar 26, 2025
a31765f
more fixes
lithomas1 Mar 30, 2025
388dc94
Merge branch 'main' into wip-dask-array-api
lithomas1 Mar 30, 2025
2bca647
WIP for LDA working
lithomas1 Mar 30, 2025
d736fd4
Merge branch 'wip-dask-array-api' of github.com:lithomas1/scikit-lear…
lithomas1 Apr 13, 2025
bd5c6e4
Merge branch 'main' of github.com:scikit-learn/scikit-learn into wip-…
lithomas1 Apr 13, 2025
85 changes: 58 additions & 27 deletions build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock

Large diffs are not rendered by default.

@@ -23,6 +23,7 @@ dependencies:
   - pytest-cov
   - coverage
   - ccache
+  - dask
   - pytorch
   - pytorch-cpu
   - polars
1 change: 1 addition & 0 deletions build_tools/update_environments_and_lock_files.py
@@ -118,6 +118,7 @@ def remove_from(alist, to_remove):
"conda_dependencies": common_dependencies
+ [
"ccache",
"dask",
"pytorch",
"pytorch-cpu",
"polars",
4 changes: 4 additions & 0 deletions doc/modules/array_api.rst
@@ -36,6 +36,10 @@ explicitly as explained in the following.
Currently, only `array-api-strict`, `cupy`, and `PyTorch` are known to work
with scikit-learn's estimators.

+`dask.array` support is incomplete at the time of writing: some methods and
+estimators may not work while we work out a way to handle Dask's lazy evaluation
+semantics in a library-agnostic way.

The following video provides an overview of the standard's design principles
and how it facilitates interoperability between array libraries:

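For context on the note added above: dask's lazy evaluation makes some shapes value-dependent, which is the root of the incompatibility. A minimal sketch (assuming only `dask.array` and NumPy are installed; not taken from the PR):

import dask.array as da
import numpy as np

x = da.from_array(np.arange(10), chunks=5)

# Boolean masking selects a value-dependent number of elements, so dask
# reports an unknown (NaN) shape until compute time:
y = x[x > 4]
print(y.shape)  # (nan,)

# Shape-dependent code paths only work after forcing chunk-size computation:
y = y.compute_chunk_sizes()
print(y.shape)  # (5,)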
5 changes: 5 additions & 0 deletions doc/whats_new/v1.5.rst
@@ -211,6 +211,11 @@ See :ref:`array_api` for more details.

**Classes:**

+**Libraries:**
+
+- ``dask.array`` is now experimentally supported as an array API backend.
Review comment (Member), suggested change:
  Old: - ``dask.array`` is now experimentally supported as an array API backend.
  New: - ``dask.array`` is now experimentally supported as an Array API backend.

Review comment (Member): Looks like Olivier commented on a few of these before as well. I won't
comment on all that I find, but it would be great to have "Array API" everywhere.

Review comment (Contributor): Not to be pedantic, but I would recommend not capitalising this, as
per data-apis/array-api#778!

Review comment (Member): Thanks for the feedback, I was not aware of this. I guess we have a lot
of docs/comments to update... Let's fix that in a dedicated PR.

+  Some methods and estimators may not work with dask.
+  :pr:`28588` by :user:`Thomas Li <lithomas1>`
- :class:`linear_model.Ridge` now supports the Array API for the `svd` solver.
See :ref:`array_api` for more details.
:pr:`27800` by :user:`Franck Charras <fcharras>`, :user:`Olivier Grisel <ogrisel>`
1 change: 1 addition & 0 deletions sklearn/_min_dependencies.py
@@ -19,6 +19,7 @@
# It will NOT be included in setup's extras_require
# The values are (version_spec, comma separated tags)
dependent_packages = {
"array-api-compat": ("1.5.1", "tests"),
"numpy": (NUMPY_MIN_VERSION, "build, install"),
"scipy": (SCIPY_MIN_VERSION, "build, install"),
"joblib": (JOBLIB_MIN_VERSION, "install"),
21 changes: 21 additions & 0 deletions sklearn/conftest.py
@@ -363,6 +363,27 @@ def print_changed_only_false():
set_config(print_changed_only=True) # reset to default


+@pytest.fixture
+def skip_dask_array_api_compliance(request):
+    """
+    Skips an array API compliance test for dask
+    (i.e. when the array_namespace fixture yields 'dask.array').
+
+    When using this fixture, please insert a comment to explain what particular
+    aspect of Dask is preventing the test from passing, e.g. a missing
+    module-level function or array-level method, value-dependent shapes, or
+    array assignment with a value-dependent boolean mask.
+    """
+    array_namespace = request.getfixturevalue("array_namespace")
+    if array_namespace == "dask.array":
+        pytest.skip(
+            reason=(
+                "Estimator/method does not work because of dask array API compliance"
+                " issues"
Review comment on lines +381 to +382 (Member), suggested change:
  Old: "Estimator/method does not work because of dask array API compliance"
       " issues"
  New: "Estimator/method does not work because of missing dask Array API compliance."

WDYT?

+            )
+        )


if dt_config is not None:
# Strict mode to differentiate between 3.14 and np.float64(3.14)
dt_config.strict_check = True
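A hypothetical usage sketch for the fixture above (the test name, parametrization, and skip reason are illustrative, not from the PR); simply requesting the fixture skips the test whenever the parametrized array_namespace is 'dask.array':

import pytest

@pytest.mark.parametrize("array_namespace", ["numpy", "array_api_strict", "dask.array"])
def test_pca_score_array_api(skip_dask_array_api_compliance, array_namespace):
    # Runs for numpy/array_api_strict; skipped for dask.array by the
    # fixture, e.g. because dask lacks xp.linalg.slogdet (used by score).
    ...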
6 changes: 3 additions & 3 deletions sklearn/decomposition/_base.py
@@ -47,7 +47,7 @@ def get_covariance(self):
xp.asarray(0.0, device=device(exp_var), dtype=exp_var.dtype),
)
cov = (components_.T * exp_var_diff) @ components_
-        _fill_or_add_to_diagonal(cov, self.noise_variance_, xp)
+        cov = _fill_or_add_to_diagonal(cov, self.noise_variance_, xp)
return cov

def get_precision(self):
@@ -89,10 +89,10 @@ def get_precision(self):
xp.asarray(0.0, device=device(exp_var)),
)
precision = components_ @ components_.T / self.noise_variance_
-        _fill_or_add_to_diagonal(precision, 1.0 / exp_var_diff, xp)
+        precision = _fill_or_add_to_diagonal(precision, 1.0 / exp_var_diff, xp)
precision = components_.T @ linalg_inv(precision) @ components_
precision /= -(self.noise_variance_**2)
-        _fill_or_add_to_diagonal(precision, 1.0 / self.noise_variance_, xp)
+        precision = _fill_or_add_to_diagonal(precision, 1.0 / self.noise_variance_, xp)
return precision

@abstractmethod
9 changes: 8 additions & 1 deletion sklearn/decomposition/tests/test_pca.py
@@ -13,6 +13,7 @@
from sklearn.decomposition import PCA
from sklearn.decomposition._pca import _assess_dimension, _infer_dimension
from sklearn.utils._array_api import (
+    _array_api_skips,
_atol_for_type,
_convert_to_numpy,
_get_namespace_device_dtype_ids,
@@ -1036,8 +1037,14 @@ def check_array_api_get_precision(name, estimator, array_namespace, device, dtyp
def test_pca_array_api_compliance(
estimator, check, array_namespace, device, dtype_name
):
+    skip_methods = _array_api_skips["PCA"]
name = estimator.__class__.__name__
-    check(name, estimator, array_namespace, device=device, dtype_name=dtype_name)
+    kwargs = {}
+    if check is check_array_api_input_and_values:
+        kwargs = {"skip_methods": skip_methods}
+    check(
+        name, estimator, array_namespace, device=device, dtype_name=dtype_name, **kwargs
+    )


@pytest.mark.parametrize(
36 changes: 23 additions & 13 deletions sklearn/discriminant_analysis.py
@@ -20,7 +20,7 @@
from .covariance import empirical_covariance, ledoit_wolf, shrunk_covariance
from .linear_model._base import LinearClassifierMixin
from .preprocessing import StandardScaler
-from .utils._array_api import _expit, device, get_namespace, size
+from .utils._array_api import _expit, device, get_namespace, size, xpx
from .utils._param_validation import HasMethods, Interval, StrOptions
from .utils.extmath import softmax
from .utils.multiclass import check_classification_targets, unique_labels
@@ -106,11 +106,12 @@ def _class_means(X, y):
Class means.
"""
xp, is_array_api_compliant = get_namespace(X)
-    classes, y = xp.unique_inverse(y)
-    means = xp.zeros((classes.shape[0], X.shape[1]), device=device(X), dtype=X.dtype)
+    n_classes = xpx.nunique(y)
+    _, y = xp.unique_inverse(y)
+    means = xp.zeros((n_classes, X.shape[1]), device=device(X), dtype=X.dtype)

if is_array_api_compliant:
-        for i in range(classes.shape[0]):
+        for i in range(n_classes):
means[i, :] = xp.mean(X[y == i], axis=0)
else:
# TODO: Explore the choice of using bincount + add.at as it seems sub optimal
@@ -577,20 +578,29 @@ def _solve_svd(self, X, y):
svd = scipy.linalg.svd

n_samples, n_features = X.shape
-        n_classes = self.classes_.shape[0]
+        # TODO: this is a duplicate computation
+        # (already done in _class_means)
+        n_classes = int(xpx.nunique(y))

self.means_ = _class_means(X, y)
if self.store_covariance:
self.covariance_ = _class_cov(X, y, self.priors_)

-        Xc = []
-        for idx, group in enumerate(self.classes_):
-            Xg = X[y == group]
-            Xc.append(Xg - self.means_[idx, :])
-
-        self.xbar_ = self.priors_ @ self.means_
-
-        Xc = xp.concat(Xc, axis=0)
+        def calc_xc(classes):
+            Xc = []
+            for idx, group in enumerate(classes):
+                Xg = X[y == group]
+                Xc.append(Xg - self.means_[idx, :])
+            Xc = xp.concat(Xc, axis=0)
+            return Xc
+
+        Xc = xpx.lazy_apply(calc_xc, self.classes_, shape=(n_samples, n_features))
+        self.xbar_ = xpx.lazy_apply(
+            lambda priors, means: priors @ means,
+            self.priors_,
+            self.means_,
+            shape=(n_classes, n_features),
+        )

# 1) within (univariate) scaling by with classes std-dev
std = xp.std(Xc, axis=0)
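For readers unfamiliar with `xpx.lazy_apply` (vendored from array-api-extra): it defers an eager function into the task graph of lazy backends such as dask, with `shape`/`dtype` declaring the expected output metadata. A minimal sketch, assuming array-api-extra and dask are installed (the exact keyword set is an assumption; check the array-api-extra docs):

import numpy as np
import dask.array as da
import array_api_extra as xpx

def center(a):
    # Eager NumPy-style function; lazy_apply materializes the inputs
    # before calling it and wraps the result back as a lazy array.
    return a - a.mean(axis=0)

x = da.from_array(np.arange(12.0).reshape(4, 3), chunks=(2, 3))
y = xpx.lazy_apply(center, x, shape=x.shape, dtype=x.dtype)
print(y.compute())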
7 changes: 5 additions & 2 deletions sklearn/linear_model/_ridge.py
@@ -287,10 +287,13 @@ def _solve_svd(X, y, alpha, xp=None):
xp, _ = get_namespace(X, xp=xp)
U, s, Vt = xp.linalg.svd(X, full_matrices=False)
idx = s > 1e-15 # same default value as scipy.linalg.pinv
-    s_nnz = s[idx][:, None]
+    s = s[:, None]
UTy = U.T @ y
d = xp.zeros((s.shape[0], alpha.shape[0]), dtype=X.dtype, device=device(X))
-    d[idx] = s_nnz / (s_nnz**2 + alpha)
+
+    # Use where to do the equivalent of boolean indexing for
+    # lazy arrays that do not support boolean mask assignment (e.g. dask).
+    d = xp.where(idx[:, None], s / (s**2 + alpha), d)

d_UT_y = d * UTy
return (Vt.T @ d_UT_y).T

Expand Down
4 changes: 2 additions & 2 deletions sklearn/metrics/_regression.py
@@ -951,8 +951,8 @@ def _assemble_r2_explained_variance(
# Non-zero Numerator and Non-zero Denominator: use the formula
valid_score = nonzero_denominator & nonzero_numerator

-    output_scores[valid_score] = 1 - (
-        numerator[valid_score] / denominator[valid_score]
+    output_scores = xp.where(
+        valid_score, 1 - numerator / denominator, output_scores
     )

# Non-zero Numerator and Zero Denominator:
4 changes: 2 additions & 2 deletions sklearn/metrics/pairwise.py
@@ -433,7 +433,7 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared
# Ensure that distances between vectors and themselves are set to 0.0.
# This may not be the case due to floating point rounding errors.
if X is Y:
-        _fill_or_add_to_diagonal(distances, 0, xp=xp, add_value=False)
+        distances = _fill_or_add_to_diagonal(distances, 0, xp=xp, add_value=False)

if squared:
return distances
@@ -1171,7 +1171,7 @@ def cosine_distances(X, Y=None):
if X is Y or Y is None:
# Ensure that distances between vectors and themselves are set to 0.0.
# This may not be the case due to floating point rounding errors.
-        _fill_or_add_to_diagonal(S, 0.0, xp, add_value=False)
+        S = _fill_or_add_to_diagonal(S, 0.0, xp, add_value=False)
return S


7 changes: 7 additions & 0 deletions sklearn/metrics/tests/test_common.py
@@ -72,6 +72,7 @@
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils import shuffle
from sklearn.utils._array_api import (
+    _array_api_skips,
_atol_for_type,
_convert_to_numpy,
_get_namespace_device_dtype_ids,
@@ -1843,6 +1844,12 @@ def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):
def check_array_api_metric(
metric, array_namespace, device, dtype_name, a_np, b_np, **metric_kwargs
):
+    func_name = metric.func.__name__ if isinstance(metric, partial) else metric.__name__
+    if _array_api_skips.get(func_name, {}).get(array_namespace) == "all":
+        pytest.skip(
+            f"{array_namespace} is not Array API compliant for {metric.__name__}"
Review comment (Member), suggested change:
  Old: f"{array_namespace} is not Array API compliant for {metric.__name__}"
  New: f"Skipping {metric.__name__} because of missing Array API compliance in {array_namespace}"

It feels like the current text is backwards. What do you think of this change? Trying to make it
clear that "we are skipping testing for X because something in {array_namespace}'s Array API
support is missing" - which is what I think we are doing.
+        )

xp = _array_api_for_tests(array_namespace, device)

a_xp = xp.asarray(a_np, device=device)
34 changes: 32 additions & 2 deletions sklearn/utils/_array_api.py
@@ -16,12 +16,27 @@
from .._config import get_config
from ..externals import array_api_compat
from ..externals import array_api_extra as xpx
+from ..externals.array_api_compat import is_dask_namespace, is_lazy_array
from ..externals.array_api_compat import numpy as np_compat
from .fixes import parse_version

# TODO: complete __all__
__all__ = ["xpx"] # we import xpx here just to re-export it, need this to appease ruff

+# Dictionary listing the methods/estimators to skip testing for a certain
+# array API namespace.
+#
+# Use "all" to skip all testing for an estimator/method for a namespace
+# (not just specific methods); see the dask.array skip in
+# LinearDiscriminantAnalysis for an example of this.
+_array_api_skips = {
+    # Dask doesn't implement slogdet from the Array API,
+    # which is used in score/score_samples
+    "PCA": {"dask.array": ["score", "score_samples"]},
+    # Lazy evaluation semantics: value-dependent shape item value (nan):
+    # "LinearDiscriminantAnalysis": {"dask.array": "all"},
+}

_NUMPY_NAMESPACE_NAMES = {"numpy", "sklearn.externals.array_api_compat.numpy"}


@@ -50,6 +65,7 @@ def yield_namespaces(include_numpy_namespaces=True):
"array_api_strict",
"cupy",
"torch",
"dask.array",
]:
if not include_numpy_namespaces and array_namespace in _NUMPY_NAMESPACE_NAMES:
continue
@@ -543,6 +559,10 @@ def _fill_or_add_to_diagonal(array, value, xp, add_value=True, wrap=False):
)

value = xp.asarray(value, dtype=array.dtype, device=device(array))

+    # TODO: candidate for upstreaming to array-api-extra
+    # This can't work in-place for dask/jax

end = None
# Explicit, fast formula for the common case. For 2-d arrays, we
# accept rectangular ones.
@@ -555,6 +575,10 @@ def _fill_or_add_to_diagonal(array, value, xp, add_value=True, wrap=False):
array_flat[:end:step] += value
else:
array_flat[:end:step] = value
+    if is_dask_namespace(xp):
+        # hack to make sure the correct value is returned for dask
+        return xp.reshape(array_flat, array.shape)
+    return array


def _is_xp_namespace(xp, name):
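As background for the dask branch above (a minimal sketch, assuming dask.array is installed): writing through a reshaped view mutates the original array for NumPy, but for dask the assignment only rebinds the flattened array's graph, so the function has to return the result instead of relying on mutation:

import numpy as np
import dask.array as da

# NumPy: reshape returns a view; strided assignment fills the diagonal
# of the original array in place.
a_np = np.zeros((3, 3))
flat = np.reshape(a_np, (-1,))
flat[:: a_np.shape[1] + 1] = 1.0
print(a_np[0, 0])  # 1.0

# dask: the same assignment leaves the original array untouched.
a_da = da.zeros((3, 3), chunks=3)
flat_da = da.reshape(a_da, (-1,))
flat_da[:: a_da.shape[1] + 1] = 1.0
print(a_da[0, 0].compute())  # still 0.0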
@@ -621,13 +645,15 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None):
weights = xp.asarray(weights, device=device_)

if weights is not None and a.shape != weights.shape:
-        if axis is None:
+        # shape checks are disabled here for dask
+        # TODO: figure out a strategy for doing this across sklearn
+        if axis is None and not is_lazy_array(a):
raise TypeError(
f"Axis must be specified when the shape of a {tuple(a.shape)} and "
f"weights {tuple(weights.shape)} differ."
)

-        if tuple(weights.shape) != (a.shape[axis],):
+        if tuple(weights.shape) != (a.shape[axis],) and not is_lazy_array(a):
raise ValueError(
f"Shape of weights weights.shape={tuple(weights.shape)} must be "
f"consistent with a.shape={tuple(a.shape)} and {axis=}."
@@ -910,6 +936,10 @@ def _in1d(ar1, ar2, xp, assume_unique=False, invert=False):
present in numpy:
https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758
"""
+    if is_dask_namespace(xp):
+        import dask.array as da
+
+        return da.isin(ar1, ar2, assume_unique, invert)
xp, _ = get_namespace(ar1, ar2, xp=xp)

# This code is run to make the code significantly faster
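A quick check of the dask shortcut above (a sketch, assuming dask is installed): `da.isin` matches NumPy's eager semantics for the 1-D case handled by `_in1d`:

import numpy as np
import dask.array as da

ar1_np = np.array([0, 1, 2, 5, 0])
ar2 = np.array([0, 2])

mask = da.isin(da.from_array(ar1_np, chunks=3), ar2)  # lazy boolean mask
print(mask.compute())        # [ True False  True False  True]
print(np.isin(ar1_np, ar2))  # same result, computed eagerly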