diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst
index 2b108d2f0e197..859b284b012f3 100644
--- a/doc/whats_new/v1.0.rst
+++ b/doc/whats_new/v1.0.rst
@@ -266,7 +266,8 @@ Changelog
   very large values. This problem happens in particular when using a scaler
   on sparse data with a constant column with sample weights, in which case
   centering is typically disabled. :pr:`19527` by :user:`Oliver Grisel
-  ` and :user:`Maria Telenczuk `.
+  ` and :user:`Maria Telenczuk ` and :pr:`19788` by
+  :user:`Jérémie du Boisberranger `.
 
 - |Fix| :meth:`preprocessing.StandardScaler.inverse_transform` now
   correctly handles integer dtypes. :pr:`19356` by :user:`makoeppel`.
diff --git a/sklearn/linear_model/_base.py b/sklearn/linear_model/_base.py
index 1842620dfa105..d898648cef083 100644
--- a/sklearn/linear_model/_base.py
+++ b/sklearn/linear_model/_base.py
@@ -28,6 +28,7 @@
 
 from ..base import (BaseEstimator, ClassifierMixin, RegressorMixin,
                     MultiOutputMixin)
+from ..preprocessing._data import _is_constant_feature
 from ..utils import check_array
 from ..utils.validation import FLOAT_DTYPES
 from ..utils.validation import _deprecate_positional_args
@@ -39,7 +40,6 @@
 from ..utils._seq_dataset import ArrayDataset32, CSRDataset32
 from ..utils._seq_dataset import ArrayDataset64, CSRDataset64
 from ..utils.validation import check_is_fitted, _check_sample_weight
-
 from ..utils.fixes import delayed
 
 # TODO: bayesian_ridge_regression and bayesian_regression_ard
@@ -250,8 +250,8 @@ def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,
             X_var = X_var.astype(X.dtype, copy=False)
             # Detect constant features on the computed variance, before taking
             # the np.sqrt. Otherwise constant features cannot be detected with
-            # sample_weights.
-            constant_mask = X_var < 10 * np.finfo(X.dtype).eps
+            # sample weights.
+            constant_mask = _is_constant_feature(X_var, X_offset, X.shape[0])
             X_var *= X.shape[0]
             X_scale = np.sqrt(X_var, out=X_var)
             X_scale[constant_mask] = 1.
diff --git a/sklearn/preprocessing/_data.py b/sklearn/preprocessing/_data.py
index 6191fb2fd8bcd..80cb132174328 100644
--- a/sklearn/preprocessing/_data.py
+++ b/sklearn/preprocessing/_data.py
@@ -57,6 +57,22 @@
 ]
 
 
+def _is_constant_feature(var, mean, n_samples):
+    """Detect if a feature is indistinguishable from a constant feature.
+
+    The detection is based on its computed variance and on the theoretical
+    error bounds of the '2 pass algorithm' for variance computation.
+
+    See "Algorithms for computing the sample variance: analysis and
+    recommendations", by Chan, Golub, and LeVeque.
+    """
+    # In scikit-learn, variance is always computed using float64 accumulators.
+    eps = np.finfo(np.float64).eps
+
+    upper_bound = n_samples * eps * var + (n_samples * mean * eps)**2
+    return var <= upper_bound
+
+
 def _handle_zeros_in_scale(scale, copy=True, constant_mask=None):
     """Set scales of near constant features to 1.
 
@@ -863,7 +879,8 @@ def partial_fit(self, X, y=None, sample_weight=None):
         if self.with_std:
             # Extract the list of near constant features on the raw variances,
             # before taking the square root.
-            constant_mask = self.var_ < 10 * np.finfo(X.dtype).eps
+            constant_mask = _is_constant_feature(
+                self.var_, self.mean_, self.n_samples_seen_)
             self.scale_ = _handle_zeros_in_scale(
                 np.sqrt(self.var_), copy=False, constant_mask=constant_mask)
         else:
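
The rationale for ``_is_constant_feature`` can be checked with a small standalone sketch (plain NumPy, not part of the patch): the old fixed threshold of ``10 * eps`` ignores the magnitude of the feature, while the new bound grows with ``n_samples`` and ``mean**2``, which is where the rounding error of the two-pass variance computation accumulates::

    import numpy as np

    eps = np.finfo(np.float64).eps

    def constant_feature_bound(var, mean, n_samples):
        # error bound of the two-pass variance algorithm (float64 accumulators)
        return n_samples * eps * var + (n_samples * mean * eps) ** 2

    for mean in (1.0, 1e5, 1e10):
        bound = constant_feature_bound(var=0.0, mean=mean, n_samples=10_000)
        print(f"mean={mean:.0e}  old threshold={10 * eps:.1e}  new bound={bound:.1e}")

For a large mean the achievable rounding error is many orders of magnitude above the old fixed threshold, so a truly constant weighted column could previously end up with a spurious non-unit ``scale_``.
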
diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py
index 5557562283850..45d967d5f39a2 100644
--- a/sklearn/preprocessing/tests/test_data.py
+++ b/sklearn/preprocessing/tests/test_data.py
@@ -224,13 +224,6 @@ def test_standard_scaler_dtype(add_sample_weight, sparse_constructor):
 @pytest.mark.parametrize("constant", [0, 1., 100.])
 def test_standard_scaler_constant_features(
         scaler, add_sample_weight, sparse_constructor, dtype, constant):
-    if (isinstance(scaler, StandardScaler)
-            and constant > 1
-            and sparse_constructor is not np.asarray
-            and add_sample_weight):
-        # https://github.com/scikit-learn/scikit-learn/issues/19546
-        pytest.xfail("Computation of weighted variance is numerically unstable"
-                     " for sparse data. See: #19546.")
 
     if isinstance(scaler, RobustScaler) and add_sample_weight:
         pytest.skip(f"{scaler.__class__.__name__} does not yet support"
@@ -269,6 +262,62 @@ def test_standard_scaler_constant_features(
     assert_allclose(X_scaled_2, X_scaled_2)
 
 
+@pytest.mark.parametrize("n_samples", [10, 100, 10_000])
+@pytest.mark.parametrize("average", [1e-10, 1, 1e10])
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("array_constructor",
+                         [np.asarray, sparse.csc_matrix, sparse.csr_matrix])
+def test_standard_scaler_near_constant_features(n_samples, array_constructor,
+                                                average, dtype):
+    # Check that when the variance is too small (var << mean**2) the feature
+    # is considered constant and not scaled.
+
+    scale_min, scale_max = -30, 19
+    scales = np.array([10**i for i in range(scale_min, scale_max + 1)],
+                      dtype=dtype)
+
+    n_features = scales.shape[0]
+    X = np.empty((n_samples, n_features), dtype=dtype)
+    # Make a dataset of known var = scales**2 and mean = average
+    X[:n_samples//2, :] = average + scales
+    X[n_samples//2:, :] = average - scales
+    X_array = array_constructor(X)
+
+    scaler = StandardScaler(with_mean=False).fit(X_array)
+
+    # StandardScaler uses float64 accumulators even if the data has a float32
+    # dtype.
+    eps = np.finfo(np.float64).eps
+
+    # if var < bound = N.eps.var + N².eps².mean², the feature is considered
+    # constant and the scale_ attribute is set to 1.
+    bounds = n_samples * eps * scales**2 + n_samples**2 * eps**2 * average**2
+    within_bounds = scales**2 <= bounds
+
+    # Check that scale_min is small enough to have some scales below the
+    # bound and therefore detected as constant:
+    assert np.any(within_bounds)
+
+    # Check that such features are actually treated as constant by the scaler:
+    assert all(scaler.var_[within_bounds] <= bounds[within_bounds])
+    assert_allclose(scaler.scale_[within_bounds], 1.)
+
+    # Depending the on the dtype of X, some features might not actually be
+    # representable as non constant for small scales (even if above the
+    # precision bound of the float64 variance estimate). Such feature should
+    # be correctly detected as constants with 0 variance by StandardScaler.
+    representable_diff = X[0, :] - X[-1, :] != 0
+    assert_allclose(scaler.var_[np.logical_not(representable_diff)], 0)
+    assert_allclose(scaler.scale_[np.logical_not(representable_diff)], 1)
+
+    # The other features are scaled and scale_ is equal to sqrt(var_) assuming
+    # that scales are large enough for average + scale and average - scale to
+    # be distinct in X (depending on X's dtype).
+    common_mask = np.logical_and(scales**2 > bounds, representable_diff)
+    assert_allclose(scaler.scale_[common_mask],
+                    np.sqrt(scaler.var_)[common_mask])
+
+
 def test_scale_1d():
     # 1-d inputs
     X_list = [1., 3., 5., 0.]
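
The ``representable_diff`` check in the new test guards against a dtype effect rather than an accumulation effect: for float32 data with a large average, ``average + scale`` and ``average - scale`` round to the same stored value, so the column really is constant and its true variance is exactly 0. A minimal illustration (plain NumPy, not part of the patch)::

    import numpy as np

    average, scale = np.float32(1e10), np.float32(1e-2)
    # The offset is far below one float32 ulp of the average, so both halves
    # of the column collapse to the same stored value.
    print(average + scale == average - scale)  # True
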
diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py
index 42a014dcd8ade..add8c5883a751 100644
--- a/sklearn/utils/extmath.py
+++ b/sklearn/utils/extmath.py
@@ -18,6 +18,7 @@
 
 from . import check_random_state
 from ._logistic_sigmoid import _log_logistic_sigmoid
+from .fixes import np_version, parse_version
 from .sparsefuncs_fast import csr_row_norms
 from .validation import check_array
 from .validation import _deprecate_positional_args
@@ -767,10 +768,17 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count,
     # updated = the aggregated stats
     last_sum = last_mean * last_sample_count
     if sample_weight is not None:
-        new_sum = _safe_accumulator_op(np.nansum, X * sample_weight[:, None],
-                                       axis=0)
-        new_sample_count = np.sum(sample_weight[:, None] * (~np.isnan(X)),
-                                  axis=0)
+        if np_version >= parse_version("1.16.6"):
+            # equivalent to np.nansum(X * sample_weight, axis=0)
+            # safer because np.float64(X*W) != np.float64(X)*np.float64(W)
+            # dtype arg of np.matmul only exists since version 1.16
+            new_sum = _safe_accumulator_op(
+                np.matmul, sample_weight, np.where(np.isnan(X), 0, X))
+        else:
+            new_sum = _safe_accumulator_op(
+                np.nansum, X * sample_weight[:, None], axis=0)
+        new_sample_count = _safe_accumulator_op(
+            np.sum, sample_weight[:, None] * (~np.isnan(X)), axis=0)
     else:
         new_sum = _safe_accumulator_op(np.nansum, X, axis=0)
         new_sample_count = np.sum(~np.isnan(X), axis=0)
@@ -784,10 +792,30 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count,
     else:
         T = new_sum / new_sample_count
         if sample_weight is not None:
-            new_unnormalized_variance = np.nansum(sample_weight[:, None] *
-                                                  (X - T)**2, axis=0)
+            if np_version >= parse_version("1.16.6"):
+                # equivalent to np.nansum((X-T)**2 * sample_weight, axis=0)
+                # safer because np.float64(X*W) != np.float64(X)*np.float64(W)
+                # dtype arg of np.matmul only exists since version 1.16
+                new_unnormalized_variance = _safe_accumulator_op(
+                    np.matmul, sample_weight,
+                    np.where(np.isnan(X), 0, (X - T)**2))
+                correction = _safe_accumulator_op(
+                    np.matmul, sample_weight, np.where(np.isnan(X), 0, X - T))
+            else:
+                new_unnormalized_variance = _safe_accumulator_op(
+                    np.nansum, (X - T)**2 * sample_weight[:, None], axis=0)
+                correction = _safe_accumulator_op(
+                    np.nansum, (X - T) * sample_weight[:, None], axis=0)
         else:
-            new_unnormalized_variance = np.nansum((X - T)**2, axis=0)
+            new_unnormalized_variance = _safe_accumulator_op(
+                np.nansum, (X - T)**2, axis=0)
+            correction = _safe_accumulator_op(np.nansum, X - T, axis=0)
+
+        # correction term of the corrected 2 pass algorithm.
+        # See "Algorithms for computing the sample variance: analysis
+        # and recommendations", by Chan, Golub, and LeVeque.
+        new_unnormalized_variance -= correction**2 / new_sample_count
+
         last_unnormalized_variance = last_variance * last_sample_count
 
         with np.errstate(divide='ignore', invalid='ignore'):
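
The correction term can be illustrated outside of ``_incremental_mean_and_var`` with a plain NumPy sketch (illustration only, not the scikit-learn implementation): following Chan, Golub and LeVeque, the weighted sum of deviations from the computed mean would be exactly zero if that mean carried no rounding error, so subtracting its square divided by the total weight removes the spurious variance the error would otherwise leave behind::

    import numpy as np

    def weighted_variance(x, w, corrected=True):
        # toy re-implementation of the weighted two-pass scheme
        sum_w = np.sum(w)
        mean = np.matmul(w, x) / sum_w          # first pass
        diff = x - mean
        sum_sq = np.matmul(w, diff ** 2)        # second pass
        if corrected:
            correction = np.matmul(w, diff)     # ~0, up to rounding of `mean`
            sum_sq = sum_sq - correction ** 2 / sum_w
        return sum_sq / sum_w

    rng = np.random.default_rng(0)
    w = rng.uniform(1, 2, size=10_000)
    x = np.full(10_000, 1e8)                    # constant column with a large mean
    # The uncorrected estimate may report a tiny spurious variance; the
    # corrected one stays at (or extremely close to) zero.
    print(weighted_variance(x, w, corrected=False), weighted_variance(x, w))
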
diff --git a/sklearn/utils/sparsefuncs_fast.pyx b/sklearn/utils/sparsefuncs_fast.pyx
index 4a84c03eff86b..09677600cbbe4 100644
--- a/sklearn/utils/sparsefuncs_fast.pyx
+++ b/sklearn/utils/sparsefuncs_fast.pyx
@@ -57,6 +57,8 @@ def _csr_row_norms(np.ndarray[floating, ndim=1, mode="c"] X_data,
 def csr_mean_variance_axis0(X, weights=None, return_sum_weights=False):
     """Compute mean and variance along axis 0 on a CSR matrix
 
+    Uses a np.float64 accumulator.
+
     Parameters
     ----------
     X : CSR sparse matrix, shape (n_samples, n_features)
@@ -109,25 +111,18 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
         np.npy_intp i
         unsigned long long row_ind
         integral col_ind
-        floating diff
+        np.float64_t diff
         # means[j] contains the mean of feature j
-        np.ndarray[floating, ndim=1] means
+        np.ndarray[np.float64_t, ndim=1] means = np.zeros(n_features)
         # variances[j] contains the variance of feature j
-        np.ndarray[floating, ndim=1] variances
-
-    if floating is float:
-        dtype = np.float32
-    else:
-        dtype = np.float64
+        np.ndarray[np.float64_t, ndim=1] variances = np.zeros(n_features)
 
-    means = np.zeros(n_features, dtype=dtype)
-    variances = np.zeros_like(means, dtype=dtype)
-
-    cdef:
-        np.ndarray[floating, ndim=1] sum_weights = np.full(
-            fill_value=np.sum(weights), shape=n_features, dtype=dtype)
-        np.ndarray[floating, ndim=1] sum_weights_nz = np.zeros(
-            shape=n_features, dtype=dtype)
+        np.ndarray[np.float64_t, ndim=1] sum_weights = np.full(
+            fill_value=np.sum(weights, dtype=np.float64), shape=n_features)
+        np.ndarray[np.float64_t, ndim=1] sum_weights_nz = np.zeros(
+            shape=n_features)
+        np.ndarray[np.float64_t, ndim=1] correction = np.zeros(
+            shape=n_features)
 
         np.ndarray[np.uint64_t, ndim=1] counts = np.full(
             fill_value=weights.shape[0], shape=n_features, dtype=np.uint64)
@@ -138,7 +133,7 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
         for i in range(X_indptr[row_ind], X_indptr[row_ind + 1]):
             col_ind = X_indices[i]
             if not isnan(X_data[i]):
-                means[col_ind] += (X_data[i] * weights[row_ind])
+                means[col_ind] += (X_data[i]) * weights[row_ind]
                 # sum of weights where X[:, col_ind] is non-zero
                 sum_weights_nz[col_ind] += weights[row_ind]
                 # number of non-zero elements of X[:, col_ind]
@@ -157,21 +152,35 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
             col_ind = X_indices[i]
             if not isnan(X_data[i]):
                 diff = X_data[i] - means[col_ind]
+                # correction term of the corrected 2 pass algorithm.
+                # See "Algorithms for computing the sample variance: analysis
+                # and recommendations", by Chan, Golub, and LeVeque.
+                correction[col_ind] += diff * weights[row_ind]
                 variances[col_ind] += diff * diff * weights[row_ind]
 
     for i in range(n_features):
+        if counts[i] != counts_nz[i]:
+            correction[i] -= (sum_weights[i] - sum_weights_nz[i]) * means[i]
+        correction[i] = correction[i]**2 / sum_weights[i]
         if counts[i] != counts_nz[i]:
             # only compute it when it's guaranteed to be non-zero to avoid
             # catastrophic cancellation.
             variances[i] += (sum_weights[i] - sum_weights_nz[i]) * means[i]**2
-        variances[i] /= sum_weights[i]
+        variances[i] = (variances[i] - correction[i]) / sum_weights[i]
 
-    return means, variances, sum_weights
+    if floating is float:
+        return (np.array(means, dtype=np.float32),
+                np.array(variances, dtype=np.float32),
+                np.array(sum_weights, dtype=np.float32))
+    else:
+        return means, variances, sum_weights
 
 
 def csc_mean_variance_axis0(X, weights=None, return_sum_weights=False):
     """Compute mean and variance along axis 0 on a CSC matrix
 
+    Uses a np.float64 accumulator.
+
     Parameters
     ----------
     X : CSC sparse matrix, shape (n_samples, n_features)
@@ -224,25 +233,18 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
         np.npy_intp i
         unsigned long long col_ind
         integral row_ind
-        floating diff
+        np.float64_t diff
         # means[j] contains the mean of feature j
-        np.ndarray[floating, ndim=1] means
+        np.ndarray[np.float64_t, ndim=1] means = np.zeros(n_features)
         # variances[j] contains the variance of feature j
-        np.ndarray[floating, ndim=1] variances
-
-    if floating is float:
-        dtype = np.float32
-    else:
-        dtype = np.float64
+        np.ndarray[np.float64_t, ndim=1] variances = np.zeros(n_features)
 
-    means = np.zeros(n_features, dtype=dtype)
-    variances = np.zeros_like(means, dtype=dtype)
-
-    cdef:
-        np.ndarray[floating, ndim=1] sum_weights = np.full(
-            fill_value=np.sum(weights), shape=n_features, dtype=dtype)
-        np.ndarray[floating, ndim=1] sum_weights_nz = np.zeros(
-            shape=n_features, dtype=dtype)
+        np.ndarray[np.float64_t, ndim=1] sum_weights = np.full(
+            fill_value=np.sum(weights, dtype=np.float64), shape=n_features)
+        np.ndarray[np.float64_t, ndim=1] sum_weights_nz = np.zeros(
+            shape=n_features)
+        np.ndarray[np.float64_t, ndim=1] correction = np.zeros(
+            shape=n_features)
 
         np.ndarray[np.uint64_t, ndim=1] counts = np.full(
             fill_value=weights.shape[0], shape=n_features, dtype=np.uint64)
@@ -253,7 +255,7 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
         for i in range(X_indptr[col_ind], X_indptr[col_ind + 1]):
             row_ind = X_indices[i]
             if not isnan(X_data[i]):
-                means[col_ind] += (X_data[i] * weights[row_ind])
+                means[col_ind] += (X_data[i]) * weights[row_ind]
                 # sum of weights where X[:, col_ind] is non-zero
                 sum_weights_nz[col_ind] += weights[row_ind]
                 # number of non-zero elements of X[:, col_ind]
@@ -272,16 +274,28 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
             row_ind = X_indices[i]
             if not isnan(X_data[i]):
                 diff = X_data[i] - means[col_ind]
+                # correction term of the corrected 2 pass algorithm.
+                # See "Algorithms for computing the sample variance: analysis
+                # and recommendations", by Chan, Golub, and LeVeque.
+                correction[col_ind] += diff * weights[row_ind]
                 variances[col_ind] += diff * diff * weights[row_ind]
 
     for i in range(n_features):
+        if counts[i] != counts_nz[i]:
+            correction[i] -= (sum_weights[i] - sum_weights_nz[i]) * means[i]
+        correction[i] = correction[i]**2 / sum_weights[i]
         if counts[i] != counts_nz[i]:
             # only compute it when it's guaranteed to be non-zero to avoid
             # catastrophic cancellation.
             variances[i] += (sum_weights[i] - sum_weights_nz[i]) * means[i]**2
-        variances[i] /= sum_weights[i]
+        variances[i] = (variances[i] - correction[i]) / sum_weights[i]
 
-    return means, variances, sum_weights
+    if floating is float:
+        return (np.array(means, dtype=np.float32),
+                np.array(variances, dtype=np.float32),
+                np.array(sum_weights, dtype=np.float32))
+    else:
+        return means, variances, sum_weights
 
 
 def incr_mean_variance_axis0(X, last_mean, last_var, last_n, weights=None):
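
For the sparse helpers, the implicit zeros of each column have to enter both the sum of squared deviations and the correction term, which is what the two ``counts[i] != counts_nz[i]`` branches above do. A dense NumPy sketch of the same per-column arithmetic (illustration only, not the Cython code; it applies the zero-entry adjustments unconditionally instead of guarding against catastrophic cancellation)::

    import numpy as np

    def sparse_column_mean_var(data, w_nz, total_weight):
        # `data`/`w_nz`: values and weights of the non-zero entries of one column;
        # `total_weight`: sum of the weights of all rows, zeros included.
        sum_w_nz = np.sum(w_nz)
        mean = np.matmul(w_nz, data) / total_weight   # zeros add nothing to the sum
        diff = data - mean
        # implicit zeros contribute (0 - mean) to the correction term and
        # (0 - mean)**2 to the sum of squares, weighted by their total weight
        correction = np.matmul(w_nz, diff) - (total_weight - sum_w_nz) * mean
        sum_sq = np.matmul(w_nz, diff ** 2) + (total_weight - sum_w_nz) * mean ** 2
        return mean, (sum_sq - correction ** 2 / total_weight) / total_weight

    rng = np.random.default_rng(0)
    weights = rng.uniform(1, 2, size=1_000)
    column = np.zeros(1_000)
    column[:10] = rng.uniform(100, 200, size=10)      # mostly zeros, a few non-zeros

    nz = column != 0
    mean, var = sparse_column_mean_var(column[nz], weights[nz], np.sum(weights))
    # compare against the dense weighted mean/variance of the same column
    ref_mean = np.average(column, weights=weights)
    ref_var = np.average((column - ref_mean) ** 2, weights=weights)
    print(np.isclose(mean, ref_mean), np.isclose(var, ref_var))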