From eceb63ebd040dd58a327aba046b9e5043039cc4a Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 21 Sep 2023 14:46:33 +0200 Subject: [PATCH 1/9] FIX array_api support for non-integer n_components in PCA --- sklearn/decomposition/_pca.py | 17 +++++++++-------- sklearn/decomposition/tests/test_pca.py | 2 +- sklearn/utils/extmath.py | 14 +++++++++++--- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 877baf4d4e81c..d84ec99b65aae 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -59,6 +59,7 @@ def _assess_dimension(spectrum, rank, n_samples): Automatic Choice of Dimensionality for PCA. NIPS 2000: 598-604 `_ """ + xp, _ = get_namespace(spectrum) n_features = spectrum.shape[0] if not 1 <= rank < n_features: @@ -72,29 +73,29 @@ def _assess_dimension(spectrum, rank, n_samples): # small and won't be the max anyway. Also, it can lead to numerical # issues below when computing pa, in particular in log((spectrum[i] - # spectrum[j]) because this will take the log of something very small. - return -np.inf + return -xp.inf pu = -rank * log(2.0) for i in range(1, rank + 1): pu += ( gammaln((n_features - i + 1) / 2.0) - - log(np.pi) * (n_features - i + 1) / 2.0 + - log(xp.pi) * (n_features - i + 1) / 2.0 ) - pl = np.sum(np.log(spectrum[:rank])) + pl = xp.sum(xp.log(spectrum[:rank])) pl = -pl * n_samples / 2.0 - v = max(eps, np.sum(spectrum[rank:]) / (n_features - rank)) - pv = -np.log(v) * n_samples * (n_features - rank) / 2.0 + v = max(eps, xp.sum(spectrum[rank:]) / (n_features - rank)) + pv = -log(v) * n_samples * (n_features - rank) / 2.0 m = n_features * rank - rank * (rank + 1.0) / 2.0 - pp = log(2.0 * np.pi) * (m + rank) / 2.0 + pp = log(2.0 * xp.pi) * (m + rank) / 2.0 pa = 0.0 - spectrum_ = spectrum.copy() + spectrum_ = xp.asarray(spectrum, copy=True) spectrum_[rank:n_features] = v for i in range(rank): - for j in range(i + 1, len(spectrum)): + for j in range(i + 1, spectrum.shape[0]): pa += log( (spectrum[i] - spectrum[j]) * (1.0 / spectrum_[j] - 1.0 / spectrum_[i]) ) + log(n_samples) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index d6e2516aef2f2..90c0b68dbe772 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -742,7 +742,7 @@ def check_array_api_get_precision(name, estimator, array_namepsace, device, dtyp "estimator", [ PCA(n_components=2, svd_solver="full"), - PCA(n_components=2, svd_solver="full", whiten=True), + PCA(n_components=0.1, svd_solver="full", whiten=True), PCA( n_components=2, svd_solver="randomized", diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index c2aa9d07e6635..e26688ce8710d 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -1210,11 +1210,19 @@ def stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08): """ xp, _ = get_namespace(arr) - out = xp.cumsum(arr, axis=axis, dtype=np.float64) - expected = xp.sum(arr, axis=axis, dtype=np.float64) + if axis is None: + arr = xp.reshape(arr, (-1,)) + axis = 0 + + out = xp.cumsum(arr, axis=axis, dtype=xp.float64) + expected = xp.sum(arr, axis=axis, dtype=xp.float64) if not xp.all( xp.isclose( - out.take(-1, axis=axis), expected, rtol=rtol, atol=atol, equal_nan=True + xp.take(out, xp.asarray([out.shape[0] - 1]), axis=axis), + expected, + rtol=rtol, + atol=atol, + equal_nan=True, ) ): warnings.warn( From cc4784b15f13dfdef13bd9913ba030f7fcfd63f2 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 21 Sep 2023 16:09:04 +0200 Subject: [PATCH 2/9] TST test (and fix) for n_components='mle' --- sklearn/decomposition/_pca.py | 2 +- sklearn/decomposition/tests/test_pca.py | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index d84ec99b65aae..71aa0305ef26b 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -116,7 +116,7 @@ def _infer_dimension(spectrum, n_samples): ll[0] = -xp.inf # we don't want to return n_components = 0 for rank in range(1, spectrum.shape[0]): ll[rank] = _assess_dimension(spectrum, rank, n_samples) - return ll.argmax() + return xp.argmax(ll) class PCA(_BasePCA): diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index 90c0b68dbe772..0ceb913e33587 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -19,6 +19,7 @@ from sklearn.utils._testing import _array_api_for_tests, assert_allclose from sklearn.utils.estimator_checks import ( _get_check_estimator_ids, + check_array_api_input, check_array_api_input_and_values, ) from sklearn.utils.fixes import CSR_CONTAINERS @@ -757,6 +758,28 @@ def test_pca_array_api_compliance(estimator, check, array_namepsace, device, dty check(name, estimator, array_namepsace, device=device, dtype=dtype) +@pytest.mark.parametrize( + "array_namepsace, device, dtype", yield_namespace_device_dtype_combinations() +) +@pytest.mark.parametrize( + "check", + [check_array_api_input, check_array_api_get_precision], + ids=_get_check_estimator_ids, +) +@pytest.mark.parametrize( + "estimator", + [ + # PCA with mle cannot use check_array_api_input_and_values becayse of + # rounding errors in the noisy (low variance) components. + PCA(n_components="mle", svd_solver="full"), + ], + ids=_get_check_estimator_ids, +) +def test_pca_mle_array_api_compliance(estimator, check, array_namepsace, device, dtype): + name = estimator.__class__.__name__ + check(name, estimator, array_namepsace, device=device, dtype=dtype) + + def test_array_api_error_and_warnings_on_unsupported_params(): pytest.importorskip("array_api_compat") xp = pytest.importorskip("numpy.array_api") From 0c419168814de1884be493e38ed9cb8d9f3ed8e5 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 21 Sep 2023 16:14:16 +0200 Subject: [PATCH 3/9] Typo --- sklearn/decomposition/tests/test_pca.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index 0ceb913e33587..0528c0a79af33 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -769,7 +769,7 @@ def test_pca_array_api_compliance(estimator, check, array_namepsace, device, dty @pytest.mark.parametrize( "estimator", [ - # PCA with mle cannot use check_array_api_input_and_values becayse of + # PCA with mle cannot use check_array_api_input_and_values because of # rounding errors in the noisy (low variance) components. PCA(n_components="mle", svd_solver="full"), ], From 3569edcbab7af5fbd5d0cf428ae0dbeb3b87bc15 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 27 Sep 2023 19:17:23 +0200 Subject: [PATCH 4/9] Do not make stable_cumsum pretend to work with array api as some backends do not support float64 --- sklearn/decomposition/_pca.py | 9 ++++++++- sklearn/decomposition/tests/test_pca.py | 21 ++++++++++++++++++--- sklearn/utils/extmath.py | 18 ++++++------------ 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index 71aa0305ef26b..abc0ca7b24509 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -574,7 +574,14 @@ def _fit_full(self, X, n_components): # side='right' ensures that number of features selected # their variance is always greater than n_components float # passed. More discussion in issue: #15669 - ratio_cumsum = stable_cumsum(explained_variance_ratio_) + if is_array_api_compliant: + ratio_cumsum = xp.cumsum(explained_variance_ratio_, axis=0) + else: + # Backward compat type for traditional numpy. Note that this + # stable_cumsum function is probably not necessary. A direct + # call to np.cumsum should be stable enough, without even + # casting to float64. + ratio_cumsum = stable_cumsum(explained_variance_ratio_) n_components = xp.searchsorted(ratio_cumsum, n_components, side="right") + 1 # Compute noise covariance using Probabilistic PCA model # The sigma2 maximum likelihood (cf. eq. 12.46) diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index 0528c0a79af33..a3c39ddb17926 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -732,7 +732,7 @@ def check_array_api_get_precision(name, estimator, array_namepsace, device, dtyp @pytest.mark.parametrize( - "array_namepsace, device, dtype", yield_namespace_device_dtype_combinations() + "array_namespace, device, dtype", yield_namespace_device_dtype_combinations() ) @pytest.mark.parametrize( "check", @@ -753,9 +753,24 @@ def check_array_api_get_precision(name, estimator, array_namepsace, device, dtyp ], ids=_get_check_estimator_ids, ) -def test_pca_array_api_compliance(estimator, check, array_namepsace, device, dtype): +def test_pca_array_api_compliance(estimator, check, array_namespace, device, dtype): name = estimator.__class__.__name__ - check(name, estimator, array_namepsace, device=device, dtype=dtype) + xp, device, dtype = _array_api_for_tests(array_namespace, device, dtype) + if not isinstance(estimator.n_components, int) and not hasattr(xp, "cumsum"): + # Our code anticipates the implementation of xp.cumsum that should be + # standardized at some point, see: + # https://github.com/data-apis/array-api/issues/597 + pytest.xfail( + f"Array API namespace {array_namespace} does not support cumsum yet." + ) + if not isinstance(estimator.n_components, int) and not hasattr(xp, "searchsorted"): + # Our code anticipates the implementation of xp.searchsorted that + # should be standardized at some point, see: + # https://github.com/data-apis/array-api/issues/688 + pytest.xfail( + f"Array API namespace {array_namespace} does not support searchsorted yet." + ) + check(name, estimator, array_namespace, device=device, dtype=dtype) @pytest.mark.parametrize( diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index e26688ce8710d..af0fbc03fd5e1 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -1208,21 +1208,15 @@ def stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08): out : ndarray Array with the cumulative sums along the chosen axis. """ - xp, _ = get_namespace(arr) - if axis is None: - arr = xp.reshape(arr, (-1,)) + arr = arr.ravel() axis = 0 - out = xp.cumsum(arr, axis=axis, dtype=xp.float64) - expected = xp.sum(arr, axis=axis, dtype=xp.float64) - if not xp.all( - xp.isclose( - xp.take(out, xp.asarray([out.shape[0] - 1]), axis=axis), - expected, - rtol=rtol, - atol=atol, - equal_nan=True, + out = np.cumsum(arr, axis=axis, dtype=np.float64) + expected = np.sum(arr, axis=axis, dtype=np.float64) + if not np.all( + np.isclose( + out.take(-1, axis=axis), expected, rtol=rtol, atol=atol, equal_nan=True ) ): warnings.warn( From 209fa47f90d87cd25dec94bcce894e94871ac680 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 27 Sep 2023 19:52:29 +0200 Subject: [PATCH 5/9] Revert to a simpler version of numpy-specific stable_cumsum --- sklearn/utils/extmath.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index af0fbc03fd5e1..bcbd6a61bf346 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -1208,10 +1208,6 @@ def stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08): out : ndarray Array with the cumulative sums along the chosen axis. """ - if axis is None: - arr = arr.ravel() - axis = 0 - out = np.cumsum(arr, axis=axis, dtype=np.float64) expected = np.sum(arr, axis=axis, dtype=np.float64) if not np.all( From 34ff427f4538179b4cf563632d5e66c12fc3deb4 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 27 Sep 2023 19:56:08 +0200 Subject: [PATCH 6/9] Even simpler version of numpy-specific stable_cumsum --- sklearn/utils/extmath.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index bcbd6a61bf346..deebf0f126eed 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -1210,10 +1210,8 @@ def stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08): """ out = np.cumsum(arr, axis=axis, dtype=np.float64) expected = np.sum(arr, axis=axis, dtype=np.float64) - if not np.all( - np.isclose( - out.take(-1, axis=axis), expected, rtol=rtol, atol=atol, equal_nan=True - ) + if not np.allclose( + out.take(-1, axis=axis), expected, rtol=rtol, atol=atol, equal_nan=True ): warnings.warn( ( From ff3a93cceb062a9021f07fade375609a7373aa1f Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 27 Sep 2023 19:57:20 +0200 Subject: [PATCH 7/9] Document change in the changelog --- doc/whats_new/v1.4.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 0778681308b6a..cb68931394f5c 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -164,7 +164,7 @@ Changelog - |Enhancement| :class:`decomposition.PCA` now supports the Array API for the `full` and `randomized` solvers (with QR power iterations). See :ref:`array_api` for more details. - :pr:`26315` and :pr:`27098` by :user:`Mateusz Sokół `, + :pr:`26315`, :pr:`27098` and :pr:`27431` by :user:`Mateusz Sokół `, :user:`Olivier Grisel ` and :user:`Edoardo Abati `. :mod:`sklearn.ensemble` From 723d19d356d69cbfc9bae80c108daee35e68916e Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 28 Sep 2023 09:28:48 +0200 Subject: [PATCH 8/9] Always perform cumsum and searchsorted using numpy for now --- sklearn/decomposition/_pca.py | 30 ++++++++++++++++++------- sklearn/decomposition/tests/test_pca.py | 15 ------------- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index abc0ca7b24509..bd9b5325e38c8 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -22,7 +22,7 @@ from ..base import _fit_context from ..utils import check_random_state from ..utils._arpack import _init_arpack_v0 -from ..utils._array_api import get_namespace +from ..utils._array_api import _convert_to_numpy, get_namespace from ..utils._param_validation import Interval, RealNotInt, StrOptions from ..utils.deprecation import deprecated from ..utils.extmath import fast_logdet, randomized_svd, stable_cumsum, svd_flip @@ -575,14 +575,28 @@ def _fit_full(self, X, n_components): # their variance is always greater than n_components float # passed. More discussion in issue: #15669 if is_array_api_compliant: - ratio_cumsum = xp.cumsum(explained_variance_ratio_, axis=0) + # Convert to numpy as xp.cumsum and xp.searchsorted are not + # part of the Array API standard yet: + # + # https://github.com/data-apis/array-api/issues/597 + # https://github.com/data-apis/array-api/issues/688 + # + # Furthermore, it's not always safe to call them for namespaces + # that already implement them: for instance as + # cupy.searchsorted does not accept a float as second argument. + explained_variance_ratio_np = _convert_to_numpy( + explained_variance_ratio_, xp=xp + ) else: - # Backward compat type for traditional numpy. Note that this - # stable_cumsum function is probably not necessary. A direct - # call to np.cumsum should be stable enough, without even - # casting to float64. - ratio_cumsum = stable_cumsum(explained_variance_ratio_) - n_components = xp.searchsorted(ratio_cumsum, n_components, side="right") + 1 + explained_variance_ratio_np = explained_variance_ratio_ + n_components = ( + np.searchsorted( + stable_cumsum(explained_variance_ratio_np), + n_components, + side="right", + ) + + 1 + ) # Compute noise covariance using Probabilistic PCA model # The sigma2 maximum likelihood (cf. eq. 12.46) if n_components < min(n_features, n_samples): diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index a3c39ddb17926..ce2d42d09e8ae 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -755,21 +755,6 @@ def check_array_api_get_precision(name, estimator, array_namepsace, device, dtyp ) def test_pca_array_api_compliance(estimator, check, array_namespace, device, dtype): name = estimator.__class__.__name__ - xp, device, dtype = _array_api_for_tests(array_namespace, device, dtype) - if not isinstance(estimator.n_components, int) and not hasattr(xp, "cumsum"): - # Our code anticipates the implementation of xp.cumsum that should be - # standardized at some point, see: - # https://github.com/data-apis/array-api/issues/597 - pytest.xfail( - f"Array API namespace {array_namespace} does not support cumsum yet." - ) - if not isinstance(estimator.n_components, int) and not hasattr(xp, "searchsorted"): - # Our code anticipates the implementation of xp.searchsorted that - # should be standardized at some point, see: - # https://github.com/data-apis/array-api/issues/688 - pytest.xfail( - f"Array API namespace {array_namespace} does not support searchsorted yet." - ) check(name, estimator, array_namespace, device=device, dtype=dtype) From 68d0b9cbbb522f0214eb27eaf61fda3270c2dedf Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 28 Sep 2023 09:42:09 +0200 Subject: [PATCH 9/9] Restore local variable to avoid complex multi-line formatting --- sklearn/decomposition/_pca.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index bd9b5325e38c8..4c2674c44c9c9 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -589,14 +589,9 @@ def _fit_full(self, X, n_components): ) else: explained_variance_ratio_np = explained_variance_ratio_ - n_components = ( - np.searchsorted( - stable_cumsum(explained_variance_ratio_np), - n_components, - side="right", - ) - + 1 - ) + ratio_cumsum = stable_cumsum(explained_variance_ratio_np) + n_components = np.searchsorted(ratio_cumsum, n_components, side="right") + 1 + # Compute noise covariance using Probabilistic PCA model # The sigma2 maximum likelihood (cf. eq. 12.46) if n_components < min(n_features, n_samples):