|
15 | 15 | from sklearn.decomposition._base import _BasePCA
|
16 | 16 | from sklearn.utils import check_random_state
|
17 | 17 | from sklearn.utils._arpack import _init_arpack_v0
|
18 |
| -from sklearn.utils._array_api import _convert_to_numpy, get_namespace |
| 18 | +from sklearn.utils._array_api import device, get_namespace |
19 | 19 | from sklearn.utils._param_validation import Interval, RealNotInt, StrOptions
|
20 |
| -from sklearn.utils.extmath import _randomized_svd, fast_logdet, stable_cumsum, svd_flip |
| 20 | +from sklearn.utils.extmath import _randomized_svd, fast_logdet, svd_flip |
21 | 21 | from sklearn.utils.sparsefuncs import _implicit_column_offset, mean_variance_axis
|
22 | 22 | from sklearn.utils.validation import check_is_fitted, validate_data
|
23 | 23 |
|
@@ -655,23 +655,15 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant):
|
655 | 655 | # side='right' ensures that number of features selected
|
656 | 656 | # their variance is always greater than n_components float
|
657 | 657 | # passed. More discussion in issue: #15669
|
658 |
| - if is_array_api_compliant: |
659 |
| - # Convert to numpy as xp.cumsum and xp.searchsorted are not |
660 |
| - # part of the Array API standard yet: |
661 |
| - # |
662 |
| - # https://github.com/data-apis/array-api/issues/597 |
663 |
| - # https://github.com/data-apis/array-api/issues/688 |
664 |
| - # |
665 |
| - # Furthermore, it's not always safe to call them for namespaces |
666 |
| - # that already implement them: for instance as |
667 |
| - # cupy.searchsorted does not accept a float as second argument. |
668 |
| - explained_variance_ratio_np = _convert_to_numpy( |
669 |
| - explained_variance_ratio_, xp=xp |
| 658 | + ratio_cumsum = xp.cumulative_sum(explained_variance_ratio_) |
| 659 | + n_components = ( |
| 660 | + xp.searchsorted( |
| 661 | + ratio_cumsum, |
| 662 | + xp.asarray(n_components, device=device(ratio_cumsum)), |
| 663 | + side="right", |
670 | 664 | )
|
671 |
| - else: |
672 |
| - explained_variance_ratio_np = explained_variance_ratio_ |
673 |
| - ratio_cumsum = stable_cumsum(explained_variance_ratio_np) |
674 |
| - n_components = np.searchsorted(ratio_cumsum, n_components, side="right") + 1 |
| 665 | + + 1 |
| 666 | + ) |
675 | 667 |
|
676 | 668 | # Compute noise covariance using Probabilistic PCA model
|
677 | 669 | # The sigma2 maximum likelihood (cf. eq. 12.46)
|
|
0 commit comments