Skip to content

Commit 450cb20

Browse files
authored
ENH use xp.cumulative_sum and xp.searchsorted directly instead of stable_cumsum (#31994)
1 parent f19ff9c commit 450cb20

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

sklearn/decomposition/_pca.py

Lines changed: 10 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,9 @@
1515
from sklearn.decomposition._base import _BasePCA
1616
from sklearn.utils import check_random_state
1717
from sklearn.utils._arpack import _init_arpack_v0
18-
from sklearn.utils._array_api import _convert_to_numpy, get_namespace
18+
from sklearn.utils._array_api import device, get_namespace
1919
from sklearn.utils._param_validation import Interval, RealNotInt, StrOptions
20-
from sklearn.utils.extmath import _randomized_svd, fast_logdet, stable_cumsum, svd_flip
20+
from sklearn.utils.extmath import _randomized_svd, fast_logdet, svd_flip
2121
from sklearn.utils.sparsefuncs import _implicit_column_offset, mean_variance_axis
2222
from sklearn.utils.validation import check_is_fitted, validate_data
2323

@@ -655,23 +655,15 @@ def _fit_full(self, X, n_components, xp, is_array_api_compliant):
655655
# side='right' ensures that the cumulative explained variance ratio of
656656
# the selected components is always strictly greater than the
657657
# n_components float passed. More discussion in issue: #15669
658-
if is_array_api_compliant:
659-
# Convert to numpy as xp.cumsum and xp.searchsorted are not
660-
# part of the Array API standard yet:
661-
#
662-
# https://github.com/data-apis/array-api/issues/597
663-
# https://github.com/data-apis/array-api/issues/688
664-
#
665-
# Furthermore, it's not always safe to call them for namespaces
666-
# that already implement them: for instance as
667-
# cupy.searchsorted does not accept a float as second argument.
668-
explained_variance_ratio_np = _convert_to_numpy(
669-
explained_variance_ratio_, xp=xp
658+
ratio_cumsum = xp.cumulative_sum(explained_variance_ratio_)
659+
n_components = (
660+
xp.searchsorted(
661+
ratio_cumsum,
662+
xp.asarray(n_components, device=device(ratio_cumsum)),
663+
side="right",
670664
)
671-
else:
672-
explained_variance_ratio_np = explained_variance_ratio_
673-
ratio_cumsum = stable_cumsum(explained_variance_ratio_np)
674-
n_components = np.searchsorted(ratio_cumsum, n_components, side="right") + 1
665+
+ 1
666+
)
675667

676668
# Compute noise covariance using Probabilistic PCA model
677669
# The sigma2 maximum likelihood (cf. eq. 12.46)

0 commit comments

Comments
 (0)