Make standard scaler compatible to Array API #27113
base: main
Conversation
sklearn/utils/extmath.py (outdated)

```python
result = op(x, *args, **kwargs, dtype=np.float64)
from ..utils._array_api import isdtype, get_namespace
xp, _ = get_namespace(x)
if isdtype(x.dtype, "real floating", xp=xp) and x.dtype in (xp.float32, xp.float64):  # what about int, etc.?
```
There is actually an error here: the second condition should probably be `x.dtype in (xp.float16, xp.float32)`, since this branch is meant to upcast lower-precision floats to float64 before accumulating.
Since `x.dtype.itemsize < 8` is not an Array API compatible way to assess the precision level of a dtype, we could instead use `xp.finfo(x.dtype).bits < 64`.
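A minimal sketch of that check, using NumPy as the namespace (any Array API namespace exposing `finfo` would behave the same; the helper name is illustrative):

```python
import numpy as np

def needs_float64_upcast(x, xp):
    # finfo(dtype).bits is defined by the Array API standard, unlike the
    # NumPy-specific x.dtype.itemsize, so this works across namespaces.
    return xp.finfo(x.dtype).bits < 64

print(needs_float64_upcast(np.ones(3, dtype=np.float32), np))  # True (32 bits)
print(needs_float64_upcast(np.ones(3, dtype=np.float64), np))  # False (already 64)
```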
This should be solved now.
```python
def test_standard_scaler_array_api_compliance(array_namespace, device, dtype):
    xp, device, dtype = _array_api_for_tests(array_namespace, device, dtype)

    from sklearn.datasets import make_classification
```
This import should either be at the top of the file, or we can just use something like `random_state.randn(n_samples, n_features)`.
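A sketch of that suggestion (the seed and the shape values are illustrative):

```python
import numpy as np

# Generate the test data with a seeded RandomState instead of importing
# make_classification inside the test body.
random_state = np.random.RandomState(42)
n_samples, n_features = 100, 5
X = random_state.randn(n_samples, n_features)
```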
Hi @AlexanderFabisch, I'm happy to continue this if it cannot wait until October. Waiting to see what the maintainers think. :) Here are a few things I learned while working on my PR that might be helpful if you decide to keep working on it.
Sure, I could also give you write access to my fork if needed. That way we could collaborate better.
Force-pushed from 3d9293a to fe6409c
Hi @AlexanderFabisch, thank you for sharing the fork :) There are still a couple of TODOs.
Another thing to bear in mind is that …
That looks a lot better, @EdAbati. Thanks for continuing this PR.
It used to be that `_check_sample_weight` would coerce sample weights to the maximum-precision float dtype for non-floating-point data; it has since been changed to use the default float type instead.
Also, fix up formatting.
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
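A small illustration of the coercion being discussed (assuming current scikit-learn behavior; `_check_sample_weight` is a private helper in `sklearn.utils.validation`):

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

# For integer input data and no explicit sample_weight, the helper returns
# float sample weights; whether that float is the maximum-precision float64
# or the namespace's default float type is exactly the change noted above.
X = np.asarray([[1], [2], [3]], dtype=np.int64)
sample_weight = _check_sample_weight(None, X)
print(sample_weight.dtype)  # float64 with NumPy inputs
```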
I started to rebase on the current main and will compile a list of TODOs. Unfortunately, this PR breaks a unit test that was introduced recently. I hope to have a clear picture of what is left to do at the beginning of next week.
Force-pushed from dac8e3b to d965610
I rebased and cleaned up the PR a bit. I believe there is only one open discussion at the moment, about using try/except vs. inspection of a function's signature.
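For context, a hypothetical sketch of the two options being weighed; `op` stands in for a namespace function such as `xp.sum` that may or may not accept a `dtype` keyword:

```python
import inspect

def call_with_dtype_try(op, x, dtype):
    # Option 1: just try the keyword and fall back on TypeError.
    try:
        return op(x, dtype=dtype)
    except TypeError:
        return op(x)

def call_with_dtype_inspect(op, x, dtype):
    # Option 2: inspect the signature first; this avoids masking unrelated
    # TypeErrors raised inside op itself, but inspect.signature can fail
    # for some builtins.
    if "dtype" in inspect.signature(op).parameters:
        return op(x, dtype=dtype)
    return op(x)
```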
Here's my contribution from the EuroSciPy 2023 sprint. It's still a work in progress, and I won't have time to continue before October, so if anyone else wants to take it from here, feel free to do so.
Reference Issues/PRs
See also #26024
What does this implement/fix? Explain your changes.
Make StandardScaler compatible with the Array API.
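Sketched below is how this would be used, assuming it follows the existing array API opt-in from #26024 (PyTorch and scikit-learn's array API dispatch requirements must be installed for this to run):

```python
import torch
import sklearn
from sklearn.preprocessing import StandardScaler

# Enable scikit-learn's array API dispatch, then fit on a torch tensor;
# fitted attributes and transform outputs stay in the input namespace.
sklearn.set_config(array_api_dispatch=True)

X = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)
scaler = StandardScaler().fit(X)
print(type(scaler.transform(X)))  # <class 'torch.Tensor'>
```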
Any other comments?
Unfortunately, the current implementation breaks some unit tests of the standard scaler that are related to dtypes. That's because I wanted to make it work for torch.float16, but maybe that is not necessary and we should just support float32 and float64.
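A minimal sketch of that simpler alternative, restricting support to float32/float64 and upcasting everything else (the helper name is hypothetical; `xp.astype` is the Array API spelling of dtype conversion):

```python
def as_supported_float(x, xp):
    # Keep float32/float64 untouched; upcast anything else (e.g. float16)
    # to float64 instead of special-casing low-precision dtypes.
    if x.dtype in (xp.float32, xp.float64):
        return x
    return xp.astype(x, xp.float64)
```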
I'll also add some comments to the diff. See below.