ENH: Add Dask Array API support #28588
Conversation
sklearn/discriminant_analysis.py
Outdated
```diff
@@ -113,6 +113,14 @@ def _class_means(X, y):
     """
     xp, is_array_api_compliant = get_namespace(X)
     classes, y = xp.unique_inverse(y)
+    # Force lazy array api backends to call compute
+    if hasattr(classes, "persist"):
```
I'm planning on extracting the calls here to a `compute` function to realize lazy arrays. Not sure if it will live in scikit-learn, though. I think I'd want it to live in array-api-compat, but I haven't discussed this with the folks there yet.
Based on usage here, the new `compute` function could have an option to e.g. compute only the shape if that's what's needed, or compute the full array.
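For illustration, a minimal sketch of what such a `compute` helper might look like (hypothetical; it duck-types on the dask-style `compute` method, mirroring the `persist` check in the diff above):

```python
def compute(x):
    # Duck-type on the dask-style ``compute`` method instead of importing
    # dask, so arrays from eager libraries pass through unchanged.
    return x.compute() if hasattr(x, "compute") else x
```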
I think this kind of code needs to go somewhere else. Having the "laziness" leak through is a bit silly. I'm not sure where it should be, though.
I don't know how to resolve this issue. In the past I've suggested that accessing something like `.shape` should just trigger the computation for you, or that everything needs to grow a `.compute`, even eager libs. Both would allow array consumers like scikit-learn to write code that does not care whether the implementation is lazy or not. However, there doesn't seem to be much support for this within the Array API community. The alternative is having to place code that checks "is this lazy? if yes, trigger compute" in consumers like scikit-learn, which I think is not great.
So this is maybe not a task for you alone to solve @lithomas1, but we need to find some kind of solution.
Thinking about this some more, would scikit-learn be happy with a `shape` helper (as opposed to my earlier suggestion of a `compute` helper) from array-api-compat?
This would handle the laziness of the various array libraries out there and call `compute` (or whatever the equivalent is) on the array if the shape needs to be materialized, before returning the shape of the array. (In the event that array-api-compat is not installed, we can just define this to return e.g. `x.shape`, like scikit-learn already does.)
The advantage of this approach would be that scikit-learn doesn't have to think about laziness anymore - that work would be outsourced to array-api-compat.
The only downsides of this approach are that:
- We indiscriminately materialize (for lazy arrays) even when it's strictly not necessary (e.g. when we just access the shape to pass it to an array constructor like `np.zeros`). I don't think we'll lose too much (if any) performance here, though.
- scikit-learn devs need to remember that accessing `.shape` directly is banned, and existing usages have to be migrated.
  - This can be mitigated with a pre-commit hook to automatically detect `.shape` accesses. I think with something like ruff, one can write a custom rule to automatically rewrite `.shape` accesses to `shape(x)`.
  - For existing usages, this is something that can be done as part of the Array API migration process, so it shouldn't cause too much churn on its own.

How does this sound to you?
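For concreteness, a minimal sketch of what the proposed helper could look like (hypothetical: it assumes lazy arrays expose a dask-style `compute()` method and report unknown dimensions as NaN, as dask does):

```python
import math


def shape(x):
    """Return the fully materialized shape of ``x``.

    Eager arrays return ``x.shape`` directly; lazy arrays whose shape
    contains unknown (NaN) dimensions are computed first.
    """
    if any(isinstance(dim, float) and math.isnan(dim) for dim in x.shape):
        # Materializing the array resolves the unknown dimensions.
        x = x.compute()
    return x.shape
```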
@betatim any thoughts on the above?
Replying here and linking @ogrisel's comment #28588 (comment).
I think having a helper like `shape()` is the only(?) way forward for now. I'd not add it to array-api-compat but instead add it to scikit-learn in `utils/_array_api.py` - we already have a few helpers there for things that feel "scikit-learn specific".
More adventurous: I wonder if we can even wrap the dask namespace (and via that its arrays) to make it so that `.shape` access triggers the computation. That way people who edit scikit-learn's code base don't need to know anything about this issue.
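For what it's worth, a rough sketch of what such a wrapper could look like (purely hypothetical; a real implementation would also have to forward operators, `__array_namespace__`, and so on):

```python
import math


class ShapeComputingArray:
    """Hypothetical proxy around a dask array whose ``.shape`` access
    resolves unknown chunk sizes by triggering computation."""

    def __init__(self, array):
        self._array = array

    @property
    def shape(self):
        if any(isinstance(d, float) and math.isnan(d) for d in self._array.shape):
            # compute_chunk_sizes() materializes the unknown chunk sizes.
            self._array = self._array.compute_chunk_sizes()
        return self._array.shape

    def __getattr__(self, name):
        # Delegate everything else to the wrapped dask array.
        return getattr(self._array, name)
```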
Note that a `compute` helper could also help deal with the boolean-masked assignment problem in `r2_score` described in more detail in this comment: #28588 (comment).
This is also a lazy evaluation problem, but one not related to shape values.
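To illustrate the pattern being discussed (this is not the actual `r2_score` code): in-place boolean-mask assignment such as `scores[mask] = 0.0` is unavailable on lazy arrays, but `where` can express the same result as a new lazy expression:

```python
# Standalone demo of the general pattern, using numpy through array-api-compat.
import numpy as np
from array_api_compat import array_namespace

numerator = np.array([0.0, 2.0, 3.0])
denominator = np.array([0.0, 4.0, 6.0])
xp = array_namespace(numerator)

nonzero = denominator != 0.0
# Avoid 0/0 in the branch that will be masked out anyway.
safe_denominator = xp.where(nonzero, denominator, xp.ones_like(denominator))
scores = xp.where(nonzero, 1.0 - numerator / safe_denominator, xp.zeros_like(numerator))
```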
> I think having a helper like shape() is the only(?) way forward for now. I'd not add it to array-api-compat but instead add it to scikit-learn in utils/_array_api.py - we already have a few helpers there for things that feel "scikit-learn specific".

We could also do both: have a public helper in array-api-compat, and a private scikit-learn specific helper in scikit-learn that does nothing for libraries that are not accessed via array-api-compat (e.g. array-api-strict), as long as the spec does not provide a standard way to deal with this.
> More adventurous: I wonder if we can even wrap the dask namespace (and via that its arrays) to make it so that .shape access triggers the computation. That way people who edit scikit-learn's code base don't need to know anything about this issue.

Not sure how feasible this is, and whether it's desirable or not to trigger computation implicitly when using lazy libraries.
sklearn/discriminant_analysis.py
Outdated
```diff
+    # compute right now
+    # Probably a dask bug
+    # (the error is also kinda flaky)
+    y = y.compute()
```
To shed more light on what this is: I think this bug happens when scikit-learn calls `unique_inverse`. Something seems to be going wrong somewhere in dask, where the results of intermediate operations get corrupted.
Exactly when the error occurs during computation is somewhat flaky, but it happens more often than not without the `compute` here.
Did you investigate the cause of the corruption? It might be worth reporting a minimal reproducer upstream.
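Something along these lines might be a starting point for such a reproducer (hypothetical: the exact input that triggers the corruption is not captured in this thread):

```python
# Hypothetical reproducer sketch: run unique_inverse on a chunked dask
# array repeatedly; flaky corruption would show up as differing results.
import numpy as np
import dask.array as da
from array_api_compat import array_namespace

y = da.from_array(np.array([0, 1, 1, 2, 0, 2]), chunks=2)
xp = array_namespace(y)  # the array-api-compat dask namespace

for _ in range(10):
    classes, y_inv = xp.unique_inverse(y)
    assert np.array_equal(y_inv.compute(), [0, 1, 1, 2, 0, 2])
```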
While LDA was a bit hard to port over to dask, PCA worked perfectly out of the box! The other preprocessing/metrics/tools also worked out of the box, I think (at least judging by the tests).
Re: performance - testing out of core for dask, with data generated by dask-ml, on a Gitpod machine with 2 cores and 8 GB RAM, I get 14m51s for a 100,000,000 by 20 LDA (14.90 GB of data), which I think is pretty decent scaling. For distributed computation, I am still measuring.
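The generation parameters weren't preserved above, but a setup along these lines would match the quoted sizes (100,000,000 float64 samples with 20 features is roughly 14.90 GiB); the chunking is an assumption:

```python
# Assumed reconstruction of the benchmark setup; the actual parameters
# and chunking were not preserved in the comment above.
import sklearn
from dask_ml.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

sklearn.set_config(array_api_dispatch=True)

X, y = make_classification(
    n_samples=100_000_000,  # ~14.90 GiB of float64 features in total
    n_features=20,
    chunks=1_000_000,  # assumed: ~160 MB per block, processed out of core
)
LinearDiscriminantAnalysis().fit(X, y)
```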
FYI, array-api-compat fixes are ongoing here: data-apis/array-api-compat#110
Since array-api-compat 1.5.1 came out and CI is green here, I'm going to mark this PR as ready for review. The only other change I'm planning right now is splitting out the LDA changes, since that requires a patch to dask itself. The correct way to handle laziness is also something that might be good to think about.
The score method is implicitly used by other tools. It might also be the case that the score method of the estimator itself can convert arrays to numpy if the namespace does not provide the necessary operations.
Co-authored-by: Samir Nasibli <samir.nasibli@intel.com>
sklearn/linear_model/_ridge.py
Outdated
```diff
@@ -288,11 +288,14 @@ def _solve_cholesky_kernel(K, y, alpha, sample_weight=None, copy=False):
 def _solve_svd(X, y, alpha, xp=None):
     xp, _ = get_namespace(X, xp=xp)
     U, s, Vt = xp.linalg.svd(X, full_matrices=False)
-    idx = s > 1e-15  # same default value as scipy.linalg.pinv
-    s_nnz = s[idx][:, None]
+    idx = s > 1e-15[:, None]  # same default value as scipy.linalg.pinv
```
`>` does not have precedence over `[]`, right? Trying this pattern locally yields:

`TypeError: 'float' object is not subscriptable`

I think you meant the following instead:

```diff
-idx = s > 1e-15[:, None]  # same default value as scipy.linalg.pinv
+# scipy.linalg.pinv also thresholds at 1e-15 by default.
+idx = (s > 1e-15)[:, None]
```
But then the following `s = s[:, None]` seems redundant, and similarly for `idx[:, None]` in the call to `where`.
Also, we should probably rename `idx` to something more correct and explicit, such as `strictly_positive_mask`.
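Putting the two comments together, the relevant lines might end up looking something like this (a sketch of the suggested direction, not the actual diff; `strictly_positive_mask` and `s_col` are names introduced here for illustration):

```python
# Sketch only: combines the precedence fix, the rename, and a single
# broadcast, while keeping the lazy-friendly ``where`` instead of
# boolean-mask assignment.
U, s, Vt = xp.linalg.svd(X, full_matrices=False)
strictly_positive_mask = (s > 1e-15)[:, None]  # same threshold as scipy.linalg.pinv
s_col = s[:, None]
d = xp.where(strictly_positive_mask, s_col / (s_col**2 + alpha), xp.zeros_like(s_col))
```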
Thanks! I think I had this working locally and passing tests, but I must have messed something up in the merge.
I'm planning on circling back to this towards the end of summer, so feel free to take over in the meantime if you're interested.
Not holding my breath, but also hoping that dask.array support improves in the meantime as well.
(The recent Array API updates suggest that sort is the most pressing thing missing in dask. Linalg-wise, the eig family of methods is probably the next biggest missing feature in dask; we just haven't seen it come up yet since not a lot of estimators have been ported.)
Let's convert to draft for the time being then.
I merged `main` to retrigger a round of CI with the current state of scikit-learn's array API support, so that we have a better understanding of how much of a problem value-dependent shapes / assignments are, and can maybe compare the test results to the failures observed when attempting to run the same tests with jax: #29647.
I am changing my review state to "request changes" to avoid implying that this draft PR is ready to merge as is. We first need to agree on whether we accept partial dask support, or whether we want to iron out all the problems induced by the lazy evaluation semantics before considering a merge of this PR.
Thanks for updating. I'm back to working on this now that my summer is over. Last I remember, I think I had ridge regression working locally (after wrangling some issues with data-dependent output shapes, which I think I worked around with `where`). I'm going to try to fix some more issues upstream in dask, like the lack of sorting (which seems to account for a majority of the failures) and also linalg. Re: dynamic shapes:
No, I have not tried.
For the record, #30340 was recently merged.
Thanks, I'll rebase this this week and let you know how it goes.
Made a bit of progress.
I haven't touched this PR in a while (besides rebasing). Hopefully I can make the PRs into array-api-extra and/or finish debugging LDA next week.
There was a PR in progress for this at data-apis/array-api-extra#124; feel free to comment over there if that is blocking you.
Reference Issues/PRs
#26724
What does this implement/fix? Explain your changes.
Any other comments?
This depends on unmerged/unreleased changes in array-api-compat.