Skip to content

Make standard scaler compatible to Array API #27113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 60 commits into
base: main
Choose a base branch
from

Conversation

AlexanderFabisch
Copy link
Member

Here's my contribution from the EuroSciPy 2023 sprint. It's still work in progress and I won't have the time to continue the work before October. So if anyone else wants to take it from here, feel free to do so.

Reference Issues/PRs

See also #26024

What does this implement/fix? Explain your changes.

Make standard scaler compatible to Array API.

Any other comments?

Unfortunately, the current implementation breaks some unit tests of the standard scaler that are related to dtypes. That's because I wanted to make it work for torch.float16, but maybe that is not necessary and we should just support float32 and float64.

I'll also add some comments to the diff. See below.

@github-actions
Copy link

github-actions bot commented Aug 19, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 4b5e4bd. Link to the linter CI: here

result = op(x, *args, **kwargs, dtype=np.float64)
from ..utils._array_api import isdtype, get_namespace
xp, _ = get_namespace(x)
if isdtype(x.dtype, "real floating", xp=xp) and x.dtype in (xp.float32, xp.float64): # what about int, etc.?
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is actually an error, I guess the second condition should be x.dtype in (xp.float16, xp.float32).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since x.dtype.itemsize < 8 is not an Array API compatible way to assess the precision level of a dtype, we could instead do xp.finfo(x.dtype).bits < 64.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be solved

def test_standard_scaler_array_api_compliance(array_namespace, device, dtype):
xp, device, dtype = _array_api_for_tests(array_namespace, device, dtype)

from sklearn.datasets import make_classification
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should either be at the top of the file or we can just use something like random_state.randn(n_samples, n_features).

@AlexanderFabisch AlexanderFabisch changed the title [WIP] Make standard scaler compatible to Array API WIP: Make standard scaler compatible to Array API Aug 19, 2023
@AlexanderFabisch AlexanderFabisch marked this pull request as draft August 19, 2023 17:07
@EdAbati
Copy link
Contributor

EdAbati commented Aug 20, 2023

Hi @AlexanderFabisch, I'm happy to continue this if it cannot wait until October. Waiting to see what the maintainers think. :)

Here are a few things I learned while working on my PR that might be helpful if you decide to keep working on it:

  • update your branch with main to get some useful functions like _array_api.supported_float_dtypes
  • testing the Array API compliance could be done by using a function that looks like this
  • in other places, a scalar array is created using xp.asarray(0.0, device=device(...))

@AlexanderFabisch
Copy link
Member Author

I'm happy to continue this if it cannot wait until October.

Sure, I could also give you write access to my fork if needed. That way we could collaborate better.

@EdAbati
Copy link
Contributor

EdAbati commented Sep 16, 2023

Hi @AlexanderFabisch , thank you for sharing the fork :)
I continued a bit, and tried to resolve some comments based on what I saw in the other PRs.

There are still a couple of TODOs:

Another thing to bear in mind is that device='mps' does not support float64. #27232 introduces something we could use

@AlexanderFabisch
Copy link
Member Author

That looks a lot better @EdAbati . Thanks for continuing this PR.

@AlexanderFabisch
Copy link
Member Author

I started to rebase on the current main and will compile a list of todos. Unfortunately, this PR breaks a unit test that has been introduced recently. I hope to have a clear picture of what is left to do in the beginning of next week.

@AlexanderFabisch AlexanderFabisch force-pushed the feature/standard_scaler_array_api branch from dac8e3b to d965610 Compare February 15, 2025 16:13
@AlexanderFabisch AlexanderFabisch marked this pull request as ready for review February 15, 2025 16:53
@AlexanderFabisch
Copy link
Member Author

I rebased and cleaned up the PR a bit. I believe there is only one open discussion at the moment about using try/except vs. inspection of a function's signature in extmath._safe_accumulator_op. Everything else looks good. @charlesjhill What do you think?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: In Progress
Development

Successfully merging this pull request may close these issues.

8 participants