Merged

26 commits
7d5c4a8
compute weight sum manually to reduce rounding errors
jeremiedbb Mar 25, 2021
2057369
use weights shape
jeremiedbb Mar 26, 2021
4d6ba86
avoid catastrophic cancellation
jeremiedbb Mar 29, 2021
3947456
cln
jeremiedbb Mar 29, 2021
eb52ff5
add what's new
jeremiedbb Mar 29, 2021
5c06cc0
Merge branch 'master' into improve-sparse-variance-precision
jeremiedbb Mar 29, 2021
fce4257
improve readability, remove 1 unecessary array
jeremiedbb Mar 30, 2021
7b43f52
add comments
jeremiedbb Mar 30, 2021
89ef3f4
improve constant feature detection for standard scaler
jeremiedbb Mar 30, 2021
7834615
cln
jeremiedbb Mar 30, 2021
cc40a0c
add ref comment on the correction term
jeremiedbb Mar 30, 2021
88def6e
Merge branch 'master' into fix-constant-feature-detection
jeremiedbb Mar 31, 2021
fcb4b57
float64 accumulator for sparse; fix dense float32 upcast
jeremiedbb Mar 31, 2021
d47f643
add test
jeremiedbb Mar 31, 2021
063e105
corrected 2 pass alg for dense
jeremiedbb Mar 31, 2021
9c50e0d
what's new
jeremiedbb Mar 31, 2021
a944b5f
cln
jeremiedbb Apr 1, 2021
0726fb8
fallback to old less precise code when numpy < 1.6
jeremiedbb Apr 2, 2021
b77ed18
Test for different values of n_samples
ogrisel Apr 6, 2021
ca7ab15
Even stronger test
ogrisel Apr 6, 2021
bb1c728
No need to test constant features many times + better mask name
ogrisel Apr 6, 2021
2ae7b23
adress comments + parametrize mean
jeremiedbb Apr 7, 2021
34e22e5
cln
jeremiedbb Apr 7, 2021
d47ab9c
check var is small when detected as cste
jeremiedbb Apr 7, 2021
10a8ac6
fix representable mask + avoid inf var
jeremiedbb Apr 7, 2021
b0f590c
address comments
jeremiedbb Apr 14, 2021
3 changes: 2 additions & 1 deletion doc/whats_new/v1.0.rst
@@ -266,7 +266,8 @@ Changelog
very large values. This problem happens in particular when using a scaler on
sparse data with a constant column with sample weights, in which case
centering is typically disabled. :pr:`19527` by :user:`Olivier Grisel
<ogrisel>` and :user:`Maria Telenczuk <maikia>`.
<ogrisel>` and :user:`Maria Telenczuk <maikia>` and :pr:`19788` by
:user:`Jérémie du Boisberranger <jeremiedbb>`.

- |Fix| :meth:`preprocessing.StandardScaler.inverse_transform` now
correctly handles integer dtypes. :pr:`19356` by :user:`makoeppel`.
6 changes: 3 additions & 3 deletions sklearn/linear_model/_base.py
@@ -28,6 +28,7 @@

from ..base import (BaseEstimator, ClassifierMixin, RegressorMixin,
MultiOutputMixin)
from ..preprocessing._data import _is_constant_feature
from ..utils import check_array
from ..utils.validation import FLOAT_DTYPES
from ..utils.validation import _deprecate_positional_args
@@ -39,7 +40,6 @@
from ..utils._seq_dataset import ArrayDataset32, CSRDataset32
from ..utils._seq_dataset import ArrayDataset64, CSRDataset64
from ..utils.validation import check_is_fitted, _check_sample_weight

from ..utils.fixes import delayed

# TODO: bayesian_ridge_regression and bayesian_regression_ard
@@ -250,8 +250,8 @@ def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,
X_var = X_var.astype(X.dtype, copy=False)
# Detect constant features on the computed variance, before taking
# the np.sqrt. Otherwise constant features cannot be detected with
# sample_weights.
constant_mask = X_var < 10 * np.finfo(X.dtype).eps
# sample weights.
constant_mask = _is_constant_feature(X_var, X_offset, X.shape[0])
X_var *= X.shape[0]
X_scale = np.sqrt(X_var, out=X_var)
X_scale[constant_mask] = 1.
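
A minimal NumPy sketch of the scaling step that this hunk changes (illustrative only: the function name is made up, and only the bound and the constant_mask handling mirror the real _preprocess_data). The point is that the mask is computed on the raw variance, before np.sqrt, which is what makes constant features detectable when sample weights are in play:

import numpy as np

def scale_from_variance(X_var, X_offset, n_samples):
    # illustrative: detect near-constant features on the variance itself,
    # using the same bound as _is_constant_feature below
    eps = np.finfo(np.float64).eps
    upper_bound = n_samples * eps * X_var + (n_samples * X_offset * eps) ** 2
    constant_mask = X_var <= upper_bound
    X_scale = np.sqrt(X_var * n_samples)
    # leave near-constant features unscaled rather than dividing by ~0
    X_scale[constant_mask] = 1.0
    return X_scale, constant_mask
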
19 changes: 18 additions & 1 deletion sklearn/preprocessing/_data.py
@@ -57,6 +57,22 @@
]


def _is_constant_feature(var, mean, n_samples):
"""Detect if a feature is indistinguishable from a constant feature.

The detection is based on its computed variance and on the theoretical
error bounds of the '2 pass algorithm' for variance computation.

See "Algorithms for computing the sample variance: analysis and
recommendations", by Chan, Golub, and LeVeque.
"""
# In scikit-learn, variance is always computed using float64 accumulators.
eps = np.finfo(np.float64).eps

upper_bound = n_samples * eps * var + (n_samples * mean * eps)**2
return var <= upper_bound


def _handle_zeros_in_scale(scale, copy=True, constant_mask=None):
"""Set scales of near constant features to 1.

@@ -863,7 +879,8 @@ def partial_fit(self, X, y=None, sample_weight=None):
if self.with_std:
# Extract the list of near constant features on the raw variances,
# before taking the square root.
constant_mask = self.var_ < 10 * np.finfo(X.dtype).eps
constant_mask = _is_constant_feature(
self.var_, self.mean_, self.n_samples_seen_)
self.scale_ = _handle_zeros_in_scale(
np.sqrt(self.var_), copy=False, constant_mask=constant_mask)
else:
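
To see what the new bound buys over the previous var_ < 10 * np.finfo(X.dtype).eps check on float64 data, a small made-up numerical example (the values are illustrative, not taken from the PR): with a large mean, rounding alone can leave a truly constant column with a computed variance many orders of magnitude above 10 * eps, yet still far below the Chan/Golub/LeVeque bound.

import numpy as np

def is_constant_feature(var, mean, n_samples):
    # same bound as the _is_constant_feature helper added above
    eps = np.finfo(np.float64).eps
    return var <= n_samples * eps * var + (n_samples * mean * eps) ** 2

n_samples = 1_000
mean = np.array([1e10, 1e10])
# feature 0 genuinely varies (std = 1); feature 1 is constant but its computed
# variance came out around 1e-8 because of rounding around the huge mean
var = np.array([1.0, 1e-8])

print(is_constant_feature(var, mean, n_samples))   # [False  True]
print(var < 10 * np.finfo(np.float64).eps)         # old float64 check: [False False]
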
63 changes: 56 additions & 7 deletions sklearn/preprocessing/tests/test_data.py
@@ -224,13 +224,6 @@ def test_standard_scaler_dtype(add_sample_weight, sparse_constructor):
@pytest.mark.parametrize("constant", [0, 1., 100.])
def test_standard_scaler_constant_features(
scaler, add_sample_weight, sparse_constructor, dtype, constant):
if (isinstance(scaler, StandardScaler)
and constant > 1
and sparse_constructor is not np.asarray
and add_sample_weight):
# https://github.com/scikit-learn/scikit-learn/issues/19546
pytest.xfail("Computation of weighted variance is numerically unstable"
" for sparse data. See: #19546.")

if isinstance(scaler, RobustScaler) and add_sample_weight:
pytest.skip(f"{scaler.__class__.__name__} does not yet support"
@@ -269,6 +262,62 @@ def test_standard_scaler_constant_features(
assert_allclose(X_scaled_2, X_scaled_2)


@pytest.mark.parametrize("n_samples", [10, 100, 10_000])
@pytest.mark.parametrize("average", [1e-10, 1, 1e10])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("array_constructor",
[np.asarray, sparse.csc_matrix, sparse.csr_matrix])
def test_standard_scaler_near_constant_features(n_samples, array_constructor,
average, dtype):
# Check that when the variance is too small (var << mean**2) the feature
# is considered constant and not scaled.

scale_min, scale_max = -30, 19
scales = np.array([10**i for i in range(scale_min, scale_max + 1)],
dtype=dtype)

n_features = scales.shape[0]
X = np.empty((n_samples, n_features), dtype=dtype)
# Make a dataset of known var = scales**2 and mean = average
X[:n_samples//2, :] = average + scales
X[n_samples//2:, :] = average - scales
X_array = array_constructor(X)

scaler = StandardScaler(with_mean=False).fit(X_array)

# StandardScaler uses float64 accumulators even if the data has a float32
# dtype.
eps = np.finfo(np.float64).eps

# if var < bound = N.eps.var + N².eps².mean², the feature is considered
# constant and the scale_ attribute is set to 1.
bounds = n_samples * eps * scales**2 + n_samples**2 * eps**2 * average**2
within_bounds = scales**2 <= bounds

# Check that scale_min is small enough to have some scales below the
# bound and therefore detected as constant:
assert np.any(within_bounds)

# Check that such features are actually treated as constant by the scaler:
assert all(scaler.var_[within_bounds] <= bounds[within_bounds])
assert_allclose(scaler.scale_[within_bounds], 1.)

# Depending on the dtype of X, some features might not actually be
# representable as non constant for small scales (even if above the
# precision bound of the float64 variance estimate). Such features should
# be correctly detected as constants with 0 variance by StandardScaler.
representable_diff = X[0, :] - X[-1, :] != 0
assert_allclose(scaler.var_[np.logical_not(representable_diff)], 0)
assert_allclose(scaler.scale_[np.logical_not(representable_diff)], 1)

# The other features are scaled and scale_ is equal to sqrt(var_) assuming
# that scales are large enough for average + scale and average - scale to
# be distinct in X (depending on X's dtype).
common_mask = np.logical_and(scales**2 > bounds, representable_diff)
assert_allclose(scaler.scale_[common_mask],
np.sqrt(scaler.var_)[common_mask])


def test_scale_1d():
# 1-d inputs
X_list = [1., 3., 5., 0.]
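
One detail of this test worth spelling out is the representable_diff mask. A tiny standalone check (values chosen for illustration) shows the situation it guards against: for a small enough scale relative to the average, average + scale and average - scale round to the same float32, so the column is exactly constant in X and the scaler reports zero variance for it.

import numpy as np

average, scale = np.float32(1.0), np.float32(1e-10)
hi = average + scale   # rounds to 1.0 in float32
lo = average - scale   # also rounds to 1.0
print(hi - lo)    # 0.0 -> the two halves of the column are identical
print(hi == lo)   # True -> an exactly constant feature, var_ == 0
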
42 changes: 35 additions & 7 deletions sklearn/utils/extmath.py
@@ -18,6 +18,7 @@

from . import check_random_state
from ._logistic_sigmoid import _log_logistic_sigmoid
from .fixes import np_version, parse_version
from .sparsefuncs_fast import csr_row_norms
from .validation import check_array
from .validation import _deprecate_positional_args
@@ -767,10 +768,17 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count,
# updated = the aggregated stats
last_sum = last_mean * last_sample_count
if sample_weight is not None:
new_sum = _safe_accumulator_op(np.nansum, X * sample_weight[:, None],
axis=0)
new_sample_count = np.sum(sample_weight[:, None] * (~np.isnan(X)),
axis=0)
if np_version >= parse_version("1.16.6"):
# equivalent to np.nansum(X * sample_weight, axis=0)
# safer because np.float64(X*W) != np.float64(X)*np.float64(W)
# dtype arg of np.matmul only exists since version 1.16
new_sum = _safe_accumulator_op(
np.matmul, sample_weight, np.where(np.isnan(X), 0, X))
else:
new_sum = _safe_accumulator_op(
np.nansum, X * sample_weight[:, None], axis=0)
new_sample_count = _safe_accumulator_op(
np.sum, sample_weight[:, None] * (~np.isnan(X)), axis=0)
else:
new_sum = _safe_accumulator_op(np.nansum, X, axis=0)
new_sample_count = np.sum(~np.isnan(X), axis=0)
@@ -784,10 +792,30 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count,
else:
T = new_sum / new_sample_count
if sample_weight is not None:
new_unnormalized_variance = np.nansum(sample_weight[:, None] *
(X - T)**2, axis=0)
if np_version >= parse_version("1.16.6"):
# equivalent to np.nansum((X-T)**2 * sample_weight, axis=0)
# safer because np.float64(X*W) != np.float64(X)*np.float64(W)
# dtype arg of np.matmul only exists since version 1.16
new_unnormalized_variance = _safe_accumulator_op(
np.matmul, sample_weight,
np.where(np.isnan(X), 0, (X - T)**2))
correction = _safe_accumulator_op(
np.matmul, sample_weight, np.where(np.isnan(X), 0, X - T))
else:
new_unnormalized_variance = _safe_accumulator_op(
np.nansum, (X - T)**2 * sample_weight[:, None], axis=0)
correction = _safe_accumulator_op(
np.nansum, (X - T) * sample_weight[:, None], axis=0)
else:
new_unnormalized_variance = np.nansum((X - T)**2, axis=0)
new_unnormalized_variance = _safe_accumulator_op(
np.nansum, (X - T)**2, axis=0)
correction = _safe_accumulator_op(np.nansum, X - T, axis=0)

# correction term of the corrected 2 pass algorithm.
# See "Algorithms for computing the sample variance: analysis
# and recommendations", by Chan, Golub, and LeVeque.
new_unnormalized_variance -= correction**2 / new_sample_count

last_unnormalized_variance = last_variance * last_sample_count

with np.errstate(divide='ignore', invalid='ignore'):
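
For reference, a plain NumPy sketch of what the weighted branch above computes, with NaN handling and the merge with previously seen statistics stripped out (so this is not _incremental_mean_and_var itself, just the corrected two-pass part):

import numpy as np

def weighted_mean_var_corrected_two_pass(X, sample_weight):
    # accumulate in float64, mirroring the matmul branch of the diff
    X64 = np.asarray(X, dtype=np.float64)
    w = np.asarray(sample_weight, dtype=np.float64)
    sum_w = w.sum()
    mean = (w @ X64) / sum_w            # weighted column means
    diff = X64 - mean
    unnormalized_var = w @ diff ** 2    # first term of the 2 pass algorithm
    # correction term (Chan, Golub, LeVeque): w @ diff is exactly zero in
    # exact arithmetic, so its observed value estimates the rounding error
    # introduced by using a slightly inexact mean
    correction = w @ diff
    unnormalized_var -= correction ** 2 / sum_w
    return mean, unnormalized_var / sum_w

Subtracting correction ** 2 / sum_w is the whole "corrected 2 pass" idea: it cancels most of the error that creeps in when the mean used in the second pass is itself slightly off.
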
90 changes: 52 additions & 38 deletions sklearn/utils/sparsefuncs_fast.pyx
@@ -57,6 +57,8 @@ def _csr_row_norms(np.ndarray[floating, ndim=1, mode="c"] X_data,
def csr_mean_variance_axis0(X, weights=None, return_sum_weights=False):
"""Compute mean and variance along axis 0 on a CSR matrix

Uses a np.float64 accumulator.

Parameters
----------
X : CSR sparse matrix, shape (n_samples, n_features)
@@ -109,25 +111,18 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
np.npy_intp i
unsigned long long row_ind
integral col_ind
floating diff
np.float64_t diff
# means[j] contains the mean of feature j
np.ndarray[floating, ndim=1] means
np.ndarray[np.float64_t, ndim=1] means = np.zeros(n_features)
# variances[j] contains the variance of feature j
np.ndarray[floating, ndim=1] variances

if floating is float:
dtype = np.float32
else:
dtype = np.float64
np.ndarray[np.float64_t, ndim=1] variances = np.zeros(n_features)

means = np.zeros(n_features, dtype=dtype)
variances = np.zeros_like(means, dtype=dtype)

cdef:
np.ndarray[floating, ndim=1] sum_weights = np.full(
fill_value=np.sum(weights), shape=n_features, dtype=dtype)
np.ndarray[floating, ndim=1] sum_weights_nz = np.zeros(
shape=n_features, dtype=dtype)
np.ndarray[np.float64_t, ndim=1] sum_weights = np.full(
fill_value=np.sum(weights, dtype=np.float64), shape=n_features)
np.ndarray[np.float64_t, ndim=1] sum_weights_nz = np.zeros(
shape=n_features)
np.ndarray[np.float64_t, ndim=1] correction = np.zeros(
shape=n_features)

np.ndarray[np.uint64_t, ndim=1] counts = np.full(
fill_value=weights.shape[0], shape=n_features, dtype=np.uint64)
@@ -138,7 +133,7 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
for i in range(X_indptr[row_ind], X_indptr[row_ind + 1]):
col_ind = X_indices[i]
if not isnan(X_data[i]):
means[col_ind] += (X_data[i] * weights[row_ind])
means[col_ind] += <np.float64_t>(X_data[i]) * weights[row_ind]
# sum of weights where X[:, col_ind] is non-zero
sum_weights_nz[col_ind] += weights[row_ind]
# number of non-zero elements of X[:, col_ind]
@@ -157,21 +152,35 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
col_ind = X_indices[i]
if not isnan(X_data[i]):
diff = X_data[i] - means[col_ind]
# correction term of the corrected 2 pass algorithm.
# See "Algorithms for computing the sample variance: analysis
# and recommendations", by Chan, Golub, and LeVeque.
correction[col_ind] += diff * weights[row_ind]
variances[col_ind] += diff * diff * weights[row_ind]

for i in range(n_features):
if counts[i] != counts_nz[i]:
correction[i] -= (sum_weights[i] - sum_weights_nz[i]) * means[i]
correction[i] = correction[i]**2 / sum_weights[i]
if counts[i] != counts_nz[i]:
# only compute it when it's guaranteed to be non-zero to avoid
# catastrophic cancellation.
variances[i] += (sum_weights[i] - sum_weights_nz[i]) * means[i]**2
variances[i] /= sum_weights[i]
variances[i] = (variances[i] - correction[i]) / sum_weights[i]

return means, variances, sum_weights
if floating is float:
return (np.array(means, dtype=np.float32),
np.array(variances, dtype=np.float32),
np.array(sum_weights, dtype=np.float32))
else:
return means, variances, sum_weights


def csc_mean_variance_axis0(X, weights=None, return_sum_weights=False):
"""Compute mean and variance along axis 0 on a CSC matrix

Uses a np.float64 accumulator.

Parameters
----------
X : CSC sparse matrix, shape (n_samples, n_features)
@@ -224,25 +233,18 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
np.npy_intp i
unsigned long long col_ind
integral row_ind
floating diff
np.float64_t diff
# means[j] contains the mean of feature j
np.ndarray[floating, ndim=1] means
np.ndarray[np.float64_t, ndim=1] means = np.zeros(n_features)
# variances[j] contains the variance of feature j
np.ndarray[floating, ndim=1] variances

if floating is float:
dtype = np.float32
else:
dtype = np.float64
np.ndarray[np.float64_t, ndim=1] variances = np.zeros(n_features)

means = np.zeros(n_features, dtype=dtype)
variances = np.zeros_like(means, dtype=dtype)

cdef:
np.ndarray[floating, ndim=1] sum_weights = np.full(
fill_value=np.sum(weights), shape=n_features, dtype=dtype)
np.ndarray[floating, ndim=1] sum_weights_nz = np.zeros(
shape=n_features, dtype=dtype)
np.ndarray[np.float64_t, ndim=1] sum_weights = np.full(
fill_value=np.sum(weights, dtype=np.float64), shape=n_features)
np.ndarray[np.float64_t, ndim=1] sum_weights_nz = np.zeros(
shape=n_features)
np.ndarray[np.float64_t, ndim=1] correction = np.zeros(
shape=n_features)

np.ndarray[np.uint64_t, ndim=1] counts = np.full(
fill_value=weights.shape[0], shape=n_features, dtype=np.uint64)
@@ -253,7 +255,7 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
for i in range(X_indptr[col_ind], X_indptr[col_ind + 1]):
row_ind = X_indices[i]
if not isnan(X_data[i]):
means[col_ind] += (X_data[i] * weights[row_ind])
means[col_ind] += <np.float64_t>(X_data[i]) * weights[row_ind]
# sum of weights where X[:, col_ind] is non-zero
sum_weights_nz[col_ind] += weights[row_ind]
# number of non-zero elements of X[:, col_ind]
@@ -272,16 +274,28 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
row_ind = X_indices[i]
if not isnan(X_data[i]):
diff = X_data[i] - means[col_ind]
# correction term of the corrected 2 pass algorithm.
# See "Algorithms for computing the sample variance: analysis
# and recommendations", by Chan, Golub, and LeVeque.
correction[col_ind] += diff * weights[row_ind]
variances[col_ind] += diff * diff * weights[row_ind]

for i in range(n_features):
if counts[i] != counts_nz[i]:
correction[i] -= (sum_weights[i] - sum_weights_nz[i]) * means[i]
correction[i] = correction[i]**2 / sum_weights[i]
if counts[i] != counts_nz[i]:
# only compute it when it's guaranteed to be non-zero to avoid
# catastrophic cancellation.
variances[i] += (sum_weights[i] - sum_weights_nz[i]) * means[i]**2
variances[i] /= sum_weights[i]
variances[i] = (variances[i] - correction[i]) / sum_weights[i]

return means, variances, sum_weights
if floating is float:
return (np.array(means, dtype=np.float32),
np.array(variances, dtype=np.float32),
np.array(sum_weights, dtype=np.float32))
else:
return means, variances, sum_weights


def incr_mean_variance_axis0(X, last_mean, last_var, last_n, weights=None):
Expand Down