
ENH: Add Dask Array API support #28588


Draft · wants to merge 45 commits into main

Conversation

lithomas1
Contributor

@lithomas1 lithomas1 commented Mar 7, 2024

Reference Issues/PRs

#26724

What does this implement/fix? Explain your changes.

Any other comments?

This depends on unmerged/unreleased changes in array-api-compat


github-actions bot commented Mar 7, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit bd5c6e4.

@@ -113,6 +113,14 @@ def _class_means(X, y):
    """
    xp, is_array_api_compliant = get_namespace(X)
    classes, y = xp.unique_inverse(y)
    # Force lazy array API backends to call compute
    if hasattr(classes, "persist"):
Contributor Author

I'm planning on extracting the calls here to a compute function to realize lazy arrays.

Not sure if it will live in scikit-learn, though. I think I'd want it to live in array-api-compat, but I haven't discussed this with the folks there yet.

Based on usage here, the new compute function could have an option to e.g. compute shape only if that's what's needed, or compute the full array.
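
A minimal sketch of what such a compute helper might look like (hypothetical; it exists in neither scikit-learn nor array-api-compat at this point, and the shape-only option is omitted):

def compute(*arrays):
    """Materialize lazy arrays; pass eager arrays through unchanged."""
    out = []
    for a in arrays:
        if hasattr(a, "compute"):  # dask-style laziness; an assumption for other lazy libs
            a = a.compute()
        out.append(a)
    return out[0] if len(out) == 1 else tuple(out)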

Member

I think this kind of code needs to go somewhere else. Having the laziness leak through like this is a bit silly. I'm not sure where it should live, though.

I don't know how to resolve this issue. In the past I've suggested that accessing something like .shape should just trigger the computation for you, or that everything should grow a .compute, even eager libraries. Either would let array consumers like scikit-learn write code that does not care whether the implementation is lazy. However, there doesn't seem to be much support for this within the Array API community. The alternative, placing "is this lazy? if yes, trigger compute" checks in consumers like scikit-learn, is not great either.

So this is maybe not a task for you alone to solve, @lithomas1, but we need to find some kind of solution.

Contributor Author
@lithomas1 lithomas1 Mar 24, 2024

Thinking about this some more, would scikit-learn be happy with a shape helper (as opposed to my earlier suggestion of a compute helper) from array-api-compat? See the sketch after the list below.

This would handle the laziness of the various array libraries out there and call compute (or whatever the equivalent is) on the array if the shape needs to be materialized, before returning the shape of the array.

(In the event that array-api-compat is not installed, we can just define this to return x.shape, like scikit-learn already does.)

The advantage of this approach would be that scikit-learn doesn't have to think about laziness anymore - that work would be outsourced to array-api-compat.

The only downsides of this approach are that

  • We indiscriminately materialize lazy arrays even when it's not strictly necessary (e.g. when we only access the shape to pass it to an array constructor like np.zeros). I don't think we'll lose much (if any) performance here, though.
  • scikit-learn devs need to remember that accessing .shape directly is banned, and existing usages have to be migrated.
    • This can be mitigated with a pre-commit hook to automatically detect .shape accesses.
      With something like ruff, one could even write a custom rule to automatically rewrite .shape accesses to shape(x).
    • Existing usages can be handled as part of the Array API migration process, so this shouldn't cause too much churn on its own.
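
A minimal sketch of such a shape helper (hypothetical; the name and the NaN check are assumptions, relying on dask reporting unknown dimensions as NaN):

import math

def shape(x):
    """Return a fully materialized shape, triggering compute for lazy arrays."""
    s = x.shape
    if any(d is None or (isinstance(d, float) and math.isnan(d)) for d in s):
        x = x.compute()  # dask-style materialization
        s = x.shape
    return tuple(s)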

How does this sound to you?

Contributor Author

@betatim any thoughts on the above?

Member

Replying here and linking @ogrisel's comment #28588 (comment)

I think having a helper like shape() is the only(?) way forward for now. I'd not add it to array-api-compat but instead add it to scikit-learn in utils/_array_api.py - we already have a few helpers there for things that feel "scikit-learn specific".

More adventurous: I wonder if we could even wrap the dask namespace (and via that its arrays) to make it so that .shape access triggers the computation. That way people who edit scikit-learn's code base wouldn't need to know anything about this issue.
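
A rough sketch of that wrapping idea (purely illustrative; _EagerShapeProxy is a made-up name, and a real wrapper would need far more delegation):

import math

class _EagerShapeProxy:
    """Wrap a dask array so that .shape access materializes unknown dimensions."""

    def __init__(self, array):
        self._array = array

    @property
    def shape(self):
        if any(isinstance(d, float) and math.isnan(d) for d in self._array.shape):
            self._array = self._array.compute()  # implicit computation
        return self._array.shape

    def __getattr__(self, name):
        # Delegate everything else to the wrapped array.
        return getattr(self._array, name)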

Member
@ogrisel ogrisel Apr 10, 2024

Note that a compute helper could also help deal with the boolean-masked assignment problem in r2_score, described in more detail in this comment: #28588 (comment)

This is also a lazy-evaluation problem, but one not related to shape values.

Member

> I think having a helper like shape() is the only(?) way forward for now. I'd not add it to array-api-compat but instead add it to scikit-learn in utils/_array_api.py - we already have a few helpers there for things that feel "scikit-learn specific".

We could also do both: a public helper in array-api-compat and a private scikit-learn specific helper in scikit-learn that does nothing for libraries that are not accessed via array-api-compat (e.g. array-api-strict), as long as the spec does not provide a standard way to deal with this.

Member

> More adventurous: I wonder if we could even wrap the dask namespace (and via that its arrays) to make it so that .shape access triggers the computation. That way people who edit scikit-learn's code base wouldn't need to know anything about this issue.

Not sure how feasible this is, or whether it's desirable to trigger computation implicitly when using lazy libraries.

# Compute right now.
# Probably a dask bug (the error is also kinda flaky).
y = y.compute()
Contributor Author

To shed more light on what this is: I think this bug happens when scikit-learn calls unique_inverse.

Something seems to be going wrong inside dask where the results of intermediate operations get corrupted.

Exactly when the error occurs is somewhat flaky, but without the compute here it happens more often than not.

Member

Did you investigate the cause of the corruption? It might be worth reporting a minimal reproducer upstream.
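
A skeleton such a minimal reproducer might take (untested; the exact trigger is unknown, so this only sketches the shape a report could have):

import numpy as np
import dask.array as da
import array_api_compat.dask.array as xp

rng = np.random.default_rng(0)
y = da.from_array(rng.integers(0, 3, size=1_000_000), chunks=100_000)
classes, y_inverse = xp.unique_inverse(y)
print(classes.compute())  # intermittently wrong, per the report above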

@lithomas1
Contributor Author

While LDA was a bit hard to port over to dask, PCA worked perfectly out of the box!
(which I think is a pretty big win for the array API)

The other preprocessing/metrics/tools also worked out of the box, I think (at least judging by the tests).

@lithomas1
Contributor Author

lithomas1 commented Mar 7, 2024

Re: performance

Testing out of core

for dask arrays generated by dask-ml with parameters

n_samples=100_000
n_classes=2
n_informative=5

on a Gitpod machine with 2 cores and 8 GB RAM, I get

47 s for a 10,000,000 × 20 LDA (1.49 GB)
6 min 3.5 s for a 50,000,000 × 20 LDA (7.45 GB)
14 min 51 s for a 100,000,000 × 20 LDA (14.90 GB)

which I think is pretty decent scaling.

Distributed computation

For

n_samples=20_000_000
n_classes=2
n_informative=5

and chunksize=100_000

I am measuring
45 s runtime for 4 workers (2 CPU, 8 GB RAM), and
4 min 29 s runtime for a single worker.

  • Note: there was a bit of spilling on this one.

@lithomas1 lithomas1 changed the title Add Dask Array API support to LDA/PCA ENH: Add Dask Array API support to LDA/PCA Mar 7, 2024
@adrinjalali
Member

cc @betatim @ogrisel

@lithomas1
Contributor Author

FYI, array-api-compat fixes are ongoing at data-apis/array-api-compat#110.

@lithomas1 lithomas1 marked this pull request as ready for review March 24, 2024 21:41
@lithomas1
Contributor Author

Since array-api-compat 1.5.1 is out and CI is green here, I'm marking this PR as ready for review.

The only other change I'm planning right now is splitting out the LDA changes, since those require a patch to dask itself.

The correct way to handle laziness is also something that might be good to think about.
(It might be good to loop in more scikit-learn devs on this.)

@lithomas1 lithomas1 changed the title ENH: Add Dask Array API support to LDA/PCA ENH: Add Dask Array API support Mar 26, 2024
@ogrisel
Member

ogrisel commented Apr 23, 2024

> I'm not sure how important score is though. There seem to be some usages of it, but curiously there seem to be no examples in scikit-learn itself on using score.

The score method is implicitly used by tools such as cross_val_score or GridSearchCV. However, it is true that it is very rarely used for PCA alone in practice. It's mostly used for supervised learning pipelines.

I agree that it's not ideal that dask can't use score, but this is something I think is reasonable to have users work around for now, e.g. by using _estimator_with_converted_arrays to convert the estimator from dask to numpy arrays (similar to how we transfer arrays from GPU to CPU with cupy).

It might also be the case that the estimator's score method itself can convert arrays to numpy when the namespace does not provide the necessary xp.linalg.slogdet. For truncated PCA, we can expect the fit method and maybe get_precision to be expensive and to deserve running on an accelerated namespace, while the final xp.linalg.slogdet call should not be a performance-critical operation (and the result is a scalar).
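
For illustration, a hedged sketch of that workaround using the private helper mentioned above (pca_dask and X_numpy are placeholder names):

from sklearn.utils._array_api import _estimator_with_converted_arrays

def to_numpy(a):
    # Materialize dask arrays; leave eager arrays untouched.
    return a.compute() if hasattr(a, "compute") else a

pca_numpy = _estimator_with_converted_arrays(pca_dask, converter=to_numpy)
print(pca_numpy.score(X_numpy))  # xp.linalg.slogdet now runs on numpy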

Co-authored-by: Samir Nasibli <samir.nasibli@intel.com>
@@ -288,11 +288,14 @@ def _solve_cholesky_kernel(K, y, alpha, sample_weight=None, copy=False):
def _solve_svd(X, y, alpha, xp=None):
xp, _ = get_namespace(X, xp=xp)
U, s, Vt = xp.linalg.svd(X, full_matrices=False)
idx = s > 1e-15 # same default value as scipy.linalg.pinv
s_nnz = s[idx][:, None]
idx = s > 1e-15[:, None] # same default value as scipy.linalg.pinv
Member

The > operator does not have precedence over [], right? Trying this pattern locally yields:

TypeError: 'float' object is not subscriptable

I think you meant the following instead:

Suggested change
idx = s > 1e-15[:, None] # same default value as scipy.linalg.pinv
# scipy.linalg.pinv also thresholds at 1e-15 by default.
idx = (s > 1e-15)[:, None]
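
A standalone illustration of the precedence issue (made-up values):

import numpy as np

s = np.array([0.0, 1e-20, 0.5, 2.0])
# s > 1e-15[:, None] raises TypeError ('float' object is not subscriptable)
# because subscription binds tighter than comparison; parenthesize first.
mask = (s > 1e-15)[:, None]  # boolean column, shape (4, 1)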

Member

But then the following s = s[:, None] seems redundant, and similarly for idx[:, None] in the call to where.

Also we should probably rename idx to something more correct and explicit such as strictly_positive_mask.

Contributor Author

Thanks! I think I had this working locally and passing tests, but I must have messed something up in the merge.

I'm planning on circling back to this towards the end of summer, so feel free to take this over if you're interested in the meantime.

Not holding my breath, but also hoping that dask.array support improves in the meantime.

(The recent Array API updates suggest that sort is the most pressing thing missing in dask. Linalg-wise, the eig family of methods is probably the next biggest missing feature; we just haven't seen it come up yet since not many estimators have been ported.)

@ogrisel ogrisel marked this pull request as draft May 24, 2024 13:59
@ogrisel
Member

ogrisel commented May 24, 2024

Let's convert to draft for the time being then.

Member
@ogrisel ogrisel left a comment

I merged main to retrigger a round of CI with the current state of scikit-learn's Array API support, so that we have a better understanding of how much of a problem value-dependent shapes / assignments are, and can maybe compare the test results to the failures observed when attempting to run the same tests with jax: #29647.

I am changing my review state to "request changes" to avoid implying that this draft PR is ready to merge as is. We first need to agree whether we accept partial dask support, or whether we want to iron out all the problems induced by lazy-evaluation semantics before considering a merge of this PR.

@lithomas1
Contributor Author

Thanks for updating.

I'm back to working on this now that my summer is over.

Last I remember, I had ridge regression working locally (after wrangling some issues with data-dependent output shapes, which I think I worked around with where).

I'm going to try to fix some more issues upstream in dask, like the lack of sorting (which seems to account for a majority of the failures) and also linalg.

Re: dynamic shapes:
Have you tried running sklearn functions through jax with JIT on?
I believe dynamic shapes aren't an issue in jax without JIT.
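
For what it's worth, a small illustration of that point, as I understand jax's behavior (boolean masking produces a data-dependent shape):

import jax.numpy as jnp
from jax import jit

x = jnp.array([1.0, -2.0, 3.0])
print(x[x > 0])  # fine eagerly: the result's shape depends on the data

@jit
def positives(x):
    return x[x > 0]  # fails under jit, where traced shapes must be static

# positives(x)  # raises a NonConcreteBooleanIndexError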

@ogrisel
Member

ogrisel commented Sep 11, 2024

> Have you tried running sklearn functions through jax with JIT on?

No I have not tried.

@ogrisel
Member

ogrisel commented Mar 25, 2025

For the record, #30340 was recently merged to main so this might unblock this PR because we can now leverage dask compat features implemented in array-api-extra.

@lithomas1
Contributor Author

Thanks, I'll rebase this this week and let you know how it goes.

@lithomas1
Contributor Author

lithomas1 commented Mar 30, 2025

Made a bit of progress.
Some notes I made along the way:

TODOs

  • Upstream fill_diagonal to array-api-extra (5 uses in sklearn)
  • Upstream isin/in1d to array-api-extra (used in encoders)

Changes

  • had to disable the shape check in _average for lazy backends

Blockers

  • setdiff1d in array-api-extra only works on eager arrays
    • Probably the next major thing to work on (this blocks LabelEncoder from working)
      • The remaining 13 metrics tests also depend on this

Progress

  • Fixed PCA (was broken due to fill_diagonal)
  • Fixed the input validators (e.g. is_multilabel)
    • nunique in array-api-extra is very helpful so far
      • Able to do int(nunique(...)) to get n_classes
      • This forces compute (but that shouldn't be so bad if we're only doing it on y)
    • Going to try to apply this to LDA later on

Thoughts

  • nunique works, but it'd be better if it were part of a unique function (see the sketch after this list)
    • Right now, we have to run the unique computation twice: once to find the unique values and once for nunique
  • If we want to support lazy backends, there needs to be some thought about how error checking should work
    • Right now, I disable the checks, since for dask we don't know the shape until we compute. The alternative would be to defer them so they raise at compute time.
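
The sketch referenced in the list above (assuming array-api-extra's nunique; the combined unique helper remains hypothetical):

import numpy as np
import array_api_compat
import array_api_extra as xpx

y = np.array([0, 1, 1, 2, 0])
xp = array_api_compat.array_namespace(y)

classes = xp.unique_values(y)    # first pass over the data
n_classes = int(xpx.nunique(y))  # a second, redundant unique computation
# A combined helper returning (values, count) in one pass would avoid this.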

@lithomas1
Contributor Author

I haven't touched this PR in a while (besides rebasing).
I've opened an issue at array-api-extra, data-apis/array-api-extra#268, for the missing functions and I think they're in favor of adding them.

Hopefully I can make the PRs into array-api-extra and/or finish debugging LDA next week.

@lucascolley
Contributor

> setdiff1d in array-api-extra only works on eager arrays

There was a PR in progress for this at data-apis/array-api-extra#124; feel free to comment over there if it is blocking you.
