lesteve
Member

@lesteve lesteve commented Feb 6, 2025

Working on it with @StefanieSenger.

Link to TODO

@lesteve lesteve marked this pull request as draft February 6, 2025 14:26

github-actions bot commented Feb 6, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: d46840b. Link to the linter CI: here

@StefanieSenger StefanieSenger self-requested a review February 14, 2025 09:28
Contributor

@OmarManzoor OmarManzoor left a comment

Overall looks good. I left a few comments



# TODO What is the expected behavior when weights init
# and X are not in the same namespace/device?
Contributor

I think this is not resolved yet. Can we remove the commented-out code?

@OmarManzoor
Contributor

@lesteve Just one test is failing, and it has to do with array-api-strict on device with float32. Maybe we need to increase the tolerance further for this specific scenario.

@lesteve
Member Author

lesteve commented Jun 19, 2025

My honest impression is that these tests are fragile on float32 data, but I don't really know if there is much we can do to improve the situation ...

Even for array-api-strict the results are different because of the difference between scipy.linalg.cholesky and numpy.linalg.cholesky, and between scipy.linalg.solve_triangular and numpy.linalg.solve.

On a GPU VM I also saw some test failures (a few more than in the CI actually) and raised the atol and rtol a bit to get them to pass locally. I triggered another run of the CUDA CI, let's see what happens 🤞.
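
The float32 sensitivity described above can be illustrated with a small sketch. This is not the test from the PR: it assumes a synthetic, well-conditioned covariance matrix and uses only NumPy (numpy.linalg.cholesky followed by numpy.linalg.solve) to compare a float64 run against a float32 run of the same Cholesky-then-solve computation. The helper name precision_cholesky and the matrix sizes are made up for illustration.

```python
import numpy as np

# Synthetic, well-conditioned symmetric positive definite matrix (assumption:
# this stands in for a covariance matrix, it is not data from the PR's tests).
rng = np.random.default_rng(0)
A = rng.standard_normal((50, 5))
cov = A.T @ A / 50 + 1e-3 * np.eye(5)


def precision_cholesky(cov):
    """Hypothetical helper: Cholesky-factorize cov, then invert the factor.

    Uses a general solve rather than a dedicated triangular solve, which is
    one of the differences mentioned in the discussion.
    """
    chol = np.linalg.cholesky(cov)
    identity = np.eye(cov.shape[0], dtype=cov.dtype)
    return np.linalg.solve(chol, identity).T


ref = precision_cholesky(cov)                     # float64 reference
low = precision_cholesky(cov.astype(np.float32))  # same computation in float32

# The gap between the two runs gives a rough lower bound on the tolerance
# that a float32 comparison against a float64 reference needs.
max_abs_diff = float(np.max(np.abs(ref - low)))
print(f"max abs difference float64 vs float32: {max_abs_diff:.2e}")
```

On a well-conditioned matrix like this one the gap stays small, but it grows with the condition number, which is one plausible reason why the float32 tests need looser atol and rtol than the float64 ones.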

@OmarManzoor
Contributor

I don't think we can do much to improve the array-api-strict tests for float32, especially with respect to accuracy. As long as array-api-strict works in general, I think that should be sufficient.

Contributor

@OmarManzoor OmarManzoor left a comment

LGTM. Thank you for the work done in this PR @lesteve and @StefanieSenger

@OmarManzoor OmarManzoor merged commit cc526ee into scikit-learn:main Jun 19, 2025
40 checks passed
@github-project-automation github-project-automation bot moved this from In Progress to Done in Array API Jun 19, 2025
@lesteve lesteve deleted the gmm-array-api branch June 19, 2025 12:18
@lesteve
Member Author

lesteve commented Jun 20, 2025

Thanks for the reviews @OmarManzoor and @ogrisel!

One of the remaining questions in the old and long TODO list: should we implement __sklearn_tags__ to tell that GaussianMixture has array_api_support?

PCA's __sklearn_tags__ does this currently and always sets array_api_support = True, although array API support is implemented only for some values of the parameters. I am not sure whether this is expected or not:

def __sklearn_tags__(self):
    tags = super().__sklearn_tags__()
    tags.transformer_tags.preserves_dtype = ["float64", "float32"]
    tags.array_api_support = True
    tags.input_tags.sparse = self.svd_solver in (
        "auto",
        "arpack",
        "covariance_eigh",
    )
    return tags

I am guessing the array_api_support tag is only used for the common tests right now, right?

@ogrisel
Member

ogrisel commented Jun 20, 2025

Good questions:

  • indeed, we could make PCA set tags.array_api_support = True only when the solver supports array API inputs.
  • similarly for GaussianMixture (depending on the choice of the init).
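
The two suggestions above amount to deriving the tag from a constructor parameter instead of hard-coding it. A minimal sketch of that pattern follows. It deliberately does not import scikit-learn: TagsSketch, BaseEstimatorSketch and PCASketch are hypothetical stand-ins, and the solver names in _SOLVERS_WITH_ARRAY_API are placeholders, not a statement of which PCA solvers actually accept array API inputs.

```python
from dataclasses import dataclass


@dataclass
class TagsSketch:
    """Hypothetical stand-in for the object returned by __sklearn_tags__."""

    array_api_support: bool = False


class BaseEstimatorSketch:
    """Hypothetical base class that declares no array API support by default."""

    def __sklearn_tags__(self):
        return TagsSketch()


class PCASketch(BaseEstimatorSketch):
    # Assumption for illustration only: placeholder solver names.
    _SOLVERS_WITH_ARRAY_API = ("full", "randomized")

    def __init__(self, svd_solver="auto"):
        self.svd_solver = svd_solver

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        # Derive the tag from the parameter instead of always setting True.
        tags.array_api_support = self.svd_solver in self._SOLVERS_WITH_ARRAY_API
        return tags


print(PCASketch(svd_solver="full").__sklearn_tags__())
print(PCASketch(svd_solver="arpack").__sklearn_tags__())
```

A GaussianMixture version would presumably follow the same shape, keyed on the choice of the init rather than on the solver.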

Projects
Status: Done
4 participants