increment_mean_and_var can now handle NaN values #10618

Closed
wants to merge 26 commits into from
26 commits
8cf9dbd
increment_mean_and_var can now handle NaN values
pinakinathc Feb 10, 2018
f7d7381
fixed errors
pinakinathc Feb 10, 2018
e3421e9
removed print statement which were added mistakenly
pinakinathc Feb 11, 2018
f924401
removed unwanted print statements
pinakinathc Feb 11, 2018
c812be9
trying to fix the errors
pinakinathc Feb 11, 2018
4758bc9
added test cases where there is a chance to get a 1D matrix
pinakinathc Feb 12, 2018
df954ac
check if there is a green tick when cases involving NaN is commented out
pinakinathc Feb 12, 2018
75d221b
removing np.count_nonzero() from the code
pinakinathc Feb 12, 2018
a3bb041
removed np.count_non_zero() in test cases
pinakinathc Feb 12, 2018
8db4311
resolved some errors
pinakinathc Feb 12, 2018
1f115f1
remove errors
pinakinathc Feb 12, 2018
566ad15
resolving errors
pinakinathc Feb 12, 2018
499133d
removing errors
pinakinathc Feb 12, 2018
032b058
removed errors
pinakinathc Feb 12, 2018
e6c0521
removed errors and modified test cases
pinakinathc Feb 13, 2018
7810d6e
changes in the code and removed cases for 1D matrix
pinakinathc Feb 15, 2018
be05d16
removed pep8 errors
pinakinathc Feb 15, 2018
87667ab
removed if condition at line +722 of `extmath.py` and changed some ot…
pinakinathc Feb 15, 2018
8e41081
made `last_samples_seen` and `updated_sample_seen` array
pinakinathc Feb 15, 2018
adced1d
modified csr_matrix and csc_matrix to be able to handle NaN values
pinakinathc Feb 18, 2018
c380bc7
Merge branch 'incr-mean-and-var' of https://github.com/pinakinathc/sc…
pinakinathc Feb 25, 2018
a0e25e9
made changes to optimize the code
pinakinathc Mar 11, 2018
e524631
remove pep8 errors
pinakinathc Mar 11, 2018
d8c99c2
Merge branch 'incr-mean-and-var' of https://github.com/pinakinathc/sc…
pinakinathc Mar 11, 2018
dc151a6
corrected csr_mean_variance_axis0 and csc_mean_variance_axis0
pinakinathc Mar 11, 2018
512ca7b
Merge pull request #1 from pinakinathc/sparseMatrix-test
pinakinathc Mar 11, 2018
7 changes: 4 additions & 3 deletions sklearn/decomposition/incremental_pca.py
@@ -243,9 +243,10 @@ def partial_fit(self, X, y=None, check_input=True):

# Update stats - they are 0 if this is the first step
col_mean, col_var, n_total_samples = \
_incremental_mean_and_var(X, last_mean=self.mean_,
last_variance=self.var_,
last_sample_count=self.n_samples_seen_)
_incremental_mean_and_var(
X, last_mean=self.mean_, last_variance=self.var_,
last_sample_count=self.n_samples_seen_ * np.ones(n_features))
Member:
We should not be supporting NaNs here. I think maybe we should allow it still to pass in a scalar n_samples_seen_ and _incremental_mean_and_var can broadcast it to n_features wide if appropriate.
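
A minimal sketch of that broadcasting idea (the helper name `_broadcast_sample_count` is hypothetical, not part of this diff):

```python
import numpy as np

def _broadcast_sample_count(last_sample_count, n_features):
    # Hypothetical helper: accept the historical scalar form of
    # last_sample_count and broadcast it to one entry per feature,
    # so callers like IncrementalPCA can keep passing a scalar.
    last_sample_count = np.asarray(last_sample_count, dtype=np.float64)
    if last_sample_count.ndim == 0:  # scalar passed in
        last_sample_count = np.full(n_features, last_sample_count)
    return last_sample_count
```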

n_total_samples = n_total_samples[0]

# Whitening
if self.n_samples_seen_ == 0:
16 changes: 13 additions & 3 deletions sklearn/preprocessing/data.py
@@ -619,13 +619,19 @@ def partial_fit(self, X, y=None):
Ignored
"""
X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy,
warn_on_dtype=True, estimator=self, dtype=FLOAT_DTYPES)
warn_on_dtype=True, estimator=self,
force_all_finite='allow-nan', dtype=FLOAT_DTYPES)

# Even in the case of `with_mean=False`, we update the mean anyway
# This is needed for the incremental computation of the var
# See incr_mean_variance_axis and _incremental_mean_variance_axis

if sparse.issparse(X):
# FIXME: remove this check statement
X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy,
warn_on_dtype=True, estimator=self,
dtype=FLOAT_DTYPES)

if self.with_mean:
raise ValueError(
"Cannot center sparse matrices: pass `with_mean=False` "
@@ -646,6 +652,9 @@ def partial_fit(self, X, y=None):
self.mean_ = None
self.var_ = None
else:
X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy,
warn_on_dtype=True, estimator=self,
dtype=FLOAT_DTYPES)
# First pass
if not hasattr(self, 'n_samples_seen_'):
self.mean_ = .0
@@ -656,8 +665,9 @@ def partial_fit(self, X, y=None):
self.var_ = None

self.mean_, self.var_, self.n_samples_seen_ = \
_incremental_mean_and_var(X, self.mean_, self.var_,
self.n_samples_seen_)
_incremental_mean_and_var(
X, self.mean_, self.var_,
self.n_samples_seen_ * np.ones(X.shape[1]))
Member (@jnothman, Feb 28, 2018):
You need to keep n_samples_seen_ for each feature from iteration to iteration. I don't see how this could work atm. And yet, for backwards compatibility, we need to report only a scalar in cases that are not affected by this PR (i.e. where there are no NaNs, or perhaps where n_samples_seen_ is constant even if there were NaNs).

For example, you might compress the updated count to a scalar if not np.any(np.diff(n_samples_seen_))
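
A minimal sketch of that compression rule (hypothetical helper, building on the `np.diff` test above):

```python
import numpy as np

def _compress_sample_count(updated_sample_count):
    # Hypothetical post-processing step: collapse the per-feature counts
    # back to a scalar when they are all equal (the no-NaN case), so
    # n_samples_seen_ keeps its historical scalar type there.
    if not np.any(np.diff(updated_sample_count)):
        return int(updated_sample_count[0])
    return updated_sample_count
```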

Member:
I don't get it. From the changes, I thought that self.n_samples_seen_ should always be an array.

> And yet, for backwards compatibility, we need to report only a scalar in cases that are not affected by this PR (i.e. where there are no NaNs, or perhaps where n_samples_seen_ is constant even if there were NaNs).

I thought that it would be easier to deal only with arrays from now on. Only incremental_pca is affected apart from the StandardScaler, and those functions are private, so the end user should not care.

Member:
n_samples_seen_ is not private IMO

Contributor Author:
@jnothman @glemaitre Sorry for being inactive for the past week. Shall I keep self.n_samples_seen_ a vector or a scalar?
PS: As of now, self.n_samples_seen_ is a vector both in StandardScaler and incremental_pca

Member:
IMO it would be best for backwards compatibility to keep a scalar in the case when there are no NaNs or, for simplicity, in the case when all n_samples_seen are equal

Member:
Fair enough

Contributor Author:
@jnothman @glemaitre I'll make them generalised, i.e. if all n_samples_seen are equal, it will return a scalar instead of a vector.


if self.with_std:
self.scale_ = _handle_zeros_in_scale(np.sqrt(self.var_))
14 changes: 8 additions & 6 deletions sklearn/preprocessing/tests/test_data.py
@@ -203,7 +203,7 @@ def test_standard_scaler_1d():
np.zeros_like(n_features))
assert_array_almost_equal(X_scaled.mean(axis=0), .0)
assert_array_almost_equal(X_scaled.std(axis=0), 1.)
assert_equal(scaler.n_samples_seen_, X.shape[0])
assert_array_equal(scaler.n_samples_seen_, X.shape[0])

# check inverse transform
X_scaled_back = scaler.inverse_transform(X_scaled)
@@ -283,7 +283,7 @@ def test_scaler_2d_arrays():
scaler = StandardScaler()
X_scaled = scaler.fit(X).transform(X, copy=True)
assert_false(np.any(np.isnan(X_scaled)))
assert_equal(scaler.n_samples_seen_, n_samples)
assert_array_equal(scaler.n_samples_seen_, n_samples)

assert_array_almost_equal(X_scaled.mean(axis=0), n_features * [0.0])
assert_array_almost_equal(X_scaled.std(axis=0), [0., 1., 1., 1., 1.])
@@ -399,7 +399,8 @@ def test_standard_scaler_partial_fit():

assert_array_almost_equal(scaler_batch.mean_, scaler_incr.mean_)
assert_equal(scaler_batch.var_, scaler_incr.var_) # Nones
assert_equal(scaler_batch.n_samples_seen_, scaler_incr.n_samples_seen_)
assert_array_almost_equal(
scaler_batch.n_samples_seen_, scaler_incr.n_samples_seen_)

# Test std after 1 step
batch0 = slice(0, chunk_size)
@@ -423,10 +424,11 @@
assert_correct_incr(i, batch_start=batch.start,
batch_stop=batch.stop, n=n,
chunk_size=chunk_size,
n_samples_seen=scaler_incr.n_samples_seen_)
n_samples_seen=scaler_incr.n_samples_seen_[0])

assert_array_almost_equal(scaler_batch.var_, scaler_incr.var_)
assert_equal(scaler_batch.n_samples_seen_, scaler_incr.n_samples_seen_)
assert_array_almost_equal(scaler_batch.n_samples_seen_,
scaler_incr.n_samples_seen_)


def test_standard_scaler_partial_fit_numerical_stability():
@@ -515,7 +517,7 @@ def test_standard_scaler_trasform_with_partial_fit():
assert_array_less(zero, scaler_incr.var_ + epsilon) # as less or equal
assert_array_less(zero, scaler_incr.scale_ + epsilon)
# (i+1) because the Scaler has been already fitted
assert_equal((i + 1), scaler_incr.n_samples_seen_)
assert_array_equal((i + 1), scaler_incr.n_samples_seen_)


def test_min_max_scaler_iris():
3 changes: 3 additions & 0 deletions sklearn/utils/estimator_checks.py
@@ -70,6 +70,7 @@
'OrthogonalMatchingPursuit', 'PLSCanonical', 'PLSRegression',
'RANSACRegressor', 'RadiusNeighborsRegressor',
'RandomForestRegressor', 'Ridge', 'RidgeCV']
ALLOW_NAN = ['StandardScaler']


def _yield_non_meta_checks(name, estimator):
@@ -1024,6 +1025,8 @@ def check_estimators_nan_inf(name, estimator_orig):
error_string_transform = ("Estimator doesn't check for NaN and inf in"
" transform.")
for X_train in [X_train_nan, X_train_inf]:
if np.any(np.isnan(X_train)) and name in ALLOW_NAN:
Member:

Same comment as: #10437 (comment)

continue
# catch deprecation warnings
with ignore_warnings(category=(DeprecationWarning, FutureWarning)):
estimator = clone(estimator_orig)
41 changes: 27 additions & 14 deletions sklearn/utils/extmath.py
@@ -643,7 +643,7 @@ def make_nonnegative(X, min_value=0):


def _incremental_mean_and_var(X, last_mean=.0, last_variance=None,
last_sample_count=0):
last_sample_count=0, ignore_nan=False):
"""Calculate mean update and a Youngs and Cramer variance update.

last_mean and last_variance are statistics computed at the last step by the
@@ -664,7 +664,7 @@ def _incremental_mean_and_var(X, last_mean=.0, last_variance=None,

last_variance : array-like, shape: (n_features,)

last_sample_count : int
last_sample_count : array-like, shape: (n_features,)

Returns
-------
@@ -673,7 +673,7 @@
updated_variance : array, shape (n_features,)
If None, only mean is computed

updated_sample_count : int
updated_sample_count : array, shape (n_features,)

References
----------
@@ -689,28 +689,41 @@
# new = the current increment
# updated = the aggregated stats
last_sum = last_mean * last_sample_count
new_sum = X.sum(axis=0)
sum_func = np.nansum if ignore_nan else np.sum
Member:
I wonder if we should start with if not isnan(X.sum()): ignore_nan = False, and then use fast paths that don't involve triplicating the memory like ~isnan(new_sum) does.
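
A minimal sketch of that fast-path guard (hypothetical helper; assumes a dense X):

```python
import numpy as np

def _maybe_disable_nan_path(X, ignore_nan):
    # A single full sum is NaN if and only if X contains a NaN, so a
    # finite total proves the cheap no-mask code path is safe without
    # allocating boolean masks the size of X.
    if ignore_nan and not np.isnan(np.sum(X)):
        ignore_nan = False  # no NaNs present; skip the masking overhead
    return ignore_nan
```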

Contributor Author:
Most of the lines in this function need to be modified to be able to ignore NaN values. Now, if we want to avoid the expensive isnan executions, there are 2 options (a sketch of the second follows below):

  • create a separate code path that computes without ignoring NaN, and another that computes ignoring NaN values
  • for each line of the code that includes an isnan call, check whether ignore_nan is true or false and write the code accordingly. Like: sum_func = np.nansum if ignore_nan else np.sum
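
A minimal sketch of the second option as a one-time dispatch (the helper name `_reducers` is hypothetical; the `sum_func`/`var_func` pattern matches the diff below):

```python
import numpy as np

def _reducers(ignore_nan):
    # Pick the NaN-aware reductions once, up front, instead of
    # branching at every use site in the function body.
    sum_func = np.nansum if ignore_nan else np.sum
    var_func = np.nanvar if ignore_nan else np.var
    return sum_func, var_func
```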

new_sum = sum_func(X, axis=0)
new_sum[np.isnan(new_sum)] = 0

new_sample_count = X.shape[0]
new_sample_count = np.sum(~np.isnan(X), axis=0)
updated_sample_count = last_sample_count + new_sample_count

updated_mean = (last_sum + new_sum) / updated_sample_count
with np.errstate(divide="ignore"): # as division by 0 might happen
updated_mean = (last_sum + new_sum) / updated_sample_count
updated_mean[np.logical_not(updated_sample_count)] = 0

if last_variance is None:
updated_variance = None
else:
new_unnormalized_variance = X.var(axis=0) * new_sample_count
if last_sample_count == 0: # Avoid division by 0
updated_unnormalized_variance = new_unnormalized_variance
else:
var_func = np.nanvar if ignore_nan else np.var
new_unnormalized_variance = var_func(X, axis=0)
new_unnormalized_variance[~np.isfinite(new_unnormalized_variance)] = 0
new_unnormalized_variance *= new_sample_count
last_unnormalized_variance = last_variance * last_sample_count
with np.errstate(divide="ignore"):
last_over_new_count = last_sample_count / new_sample_count
last_unnormalized_variance = last_variance * last_sample_count
updated_unnormalized_variance = (
last_unnormalized_variance +
new_unnormalized_variance +
last_over_new_count / updated_sample_count *
(last_sum / last_over_new_count - new_sum) ** 2)
updated_variance = updated_unnormalized_variance / updated_sample_count
# updated_unnormalized_variance can be NaN or Inf
updated_unnormalized_variance[~np.isfinite(
updated_unnormalized_variance)] = 0
updated_unnormalized_variance += (last_unnormalized_variance +
new_unnormalized_variance)

with np.errstate(divide="ignore"):
updated_variance = (updated_unnormalized_variance /
updated_sample_count)
# as division by zero might happen
updated_variance[np.logical_not(updated_sample_count)] = 0

return updated_mean, updated_variance, updated_sample_count

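For illustration, a sketch of how the updated private helper is meant to behave per this diff (signature and `ignore_nan` keyword taken from the hunk above; not a released API):

```python
import numpy as np
from sklearn.utils.extmath import _incremental_mean_and_var

# Two half-batches should reproduce np.nanmean / np.nanvar over the
# full array, with one sample count per feature.
X = np.array([[1., np.nan],
              [2., 4.],
              [3., 5.],
              [4., np.nan]])

mean, var, count = _incremental_mean_and_var(
    X[:2], last_mean=0., last_variance=np.zeros(2),
    last_sample_count=np.zeros(2), ignore_nan=True)
mean, var, count = _incremental_mean_and_var(
    X[2:], last_mean=mean, last_variance=var,
    last_sample_count=count, ignore_nan=True)

print(count)  # expected [4. 2.]: NaNs do not count as seen samples
print(np.allclose(mean, np.nanmean(X, axis=0)))  # expected True
print(np.allclose(var, np.nanvar(X, axis=0)))    # expected True
```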
67 changes: 51 additions & 16 deletions sklearn/utils/sparsefuncs_fast.pyx
@@ -15,6 +15,8 @@ import numpy as np
import scipy.sparse as sp
cimport cython
from cython cimport floating
from numpy.math cimport isnan
import warnings

np.import_array()

@@ -54,7 +56,7 @@ def _csr_row_norms(np.ndarray[floating, ndim=1, mode="c"] X_data,
return norms


def csr_mean_variance_axis0(X):
def csr_mean_variance_axis0(X, ignore_nan=True):
"""Compute mean and variance along axis 0 on a CSR matrix

Parameters
@@ -74,12 +76,13 @@
"""
if X.dtype != np.float32:
X = X.astype(np.float64)
return _csr_mean_variance_axis0(X.data, X.shape, X.indices)
return _csr_mean_variance_axis0(X.data, X.shape, X.indices, ignore_nan)


def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
shape,
np.ndarray[int, ndim=1] X_indices):
np.ndarray[int, ndim=1] X_indices,
ignore_nan=True):
# Implement the function here since variables using fused types
# cannot be declared directly and can only be passed as function arguments
cdef unsigned int n_samples = shape[0]
@@ -94,6 +97,8 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
cdef np.ndarray[floating, ndim=1] means
# variances[j] contains the variance of feature j
cdef np.ndarray[floating, ndim=1] variances
# n_samples_feat[j] contains the n_samples of feature j
cdef np.ndarray[floating, ndim=1] n_samples_feat

if floating is float:
dtype = np.float32
@@ -102,31 +107,44 @@

means = np.zeros(n_features, dtype=dtype)
variances = np.zeros_like(means, dtype=dtype)
n_samples_feat = np.ones(n_features, dtype=dtype) * n_samples

# counts[j] contains the number of samples where feature j is non-zero
cdef np.ndarray[int, ndim=1] counts = np.zeros(n_features,
dtype=np.int32)

for i in xrange(non_zero):
col_ind = X_indices[i]
means[col_ind] += X_data[i]
x_i = X_data[i]
if ignore_nan and isnan(x_i):
n_samples_feat[col_ind] -= 1
continue
means[col_ind] += x_i

means /= n_samples
with np.errstate(divide="ignore"):
# as division by 0 might happen
means /= n_samples_feat
means[np.logical_not(n_samples_feat)] = 0

for i in xrange(non_zero):
col_ind = X_indices[i]
diff = X_data[i] - means[col_ind]
x_i = X_data[i]
if ignore_nan and isnan(x_i):
continue
diff = x_i - means[col_ind]
variances[col_ind] += diff * diff
counts[col_ind] += 1

for i in xrange(n_features):
variances[i] += (n_samples - counts[i]) * means[i] ** 2
variances[i] /= n_samples
variances += (n_samples_feat - counts) * means ** 2
with np.errstate(divide="ignore"):
# as division by 0 might happen
variances /= n_samples_feat
variances[np.logical_not(n_samples_feat)] = 0

return means, variances


def csc_mean_variance_axis0(X):
def csc_mean_variance_axis0(X, ignore_nan=True):
"""Compute mean and variance along axis 0 on a CSC matrix

Parameters
@@ -146,13 +164,14 @@ def csc_mean_variance_axis0(X):
"""
if X.dtype != np.float32:
X = X.astype(np.float64)
return _csc_mean_variance_axis0(X.data, X.shape, X.indices, X.indptr)
return _csc_mean_variance_axis0(X.data, X.shape, X.indices, X.indptr,
ignore_nan)


def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
shape,
np.ndarray[int, ndim=1] X_indices,
np.ndarray[int, ndim=1] X_indptr):
np.ndarray[int, ndim=1] X_indptr, ignore_nan=True):
# Implement the function here since variables using fused types
# cannot be declared directly and can only be passed as function arguments
cdef unsigned int n_samples = shape[0]
@@ -164,6 +183,7 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
cdef unsigned int startptr
cdef unsigned int endptr
cdef floating diff
cdef floating n_samples_feat

# means[j] contains the mean of feature j
cdef np.ndarray[floating, ndim=1] means
@@ -182,17 +202,32 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
startptr = X_indptr[i]
endptr = X_indptr[i + 1]
counts = endptr - startptr
n_samples_feat = n_samples

for j in xrange(startptr, endptr):
means[i] += X_data[j]
means[i] /= n_samples
x_j = X_data[j]
if ignore_nan and isnan(x_j):
n_samples_feat -= 1
continue
means[i] += x_j

if n_samples_feat != 0:
means[i] /= n_samples_feat
else:
means[i] = 0

for j in xrange(startptr, endptr):
diff = X_data[j] - means[i]
x_j = X_data[j]
if ignore_nan and isnan(x_j):
continue
diff = x_j - means[i]
variances[i] += diff * diff

variances[i] += (n_samples - counts) * means[i] * means[i]
variances[i] /= n_samples
if n_samples_feat != 0:
variances[i] /= n_samples_feat
else:
variances[i] = 0

return means, variances

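For illustration, a sketch of the NaN-aware sparse path per this diff (the `ignore_nan` keyword is taken from the hunk above; not a released API):

```python
import numpy as np
import scipy.sparse as sp
from sklearn.utils.sparsefuncs_fast import csr_mean_variance_axis0

# Explicit NaN entries are skipped per column; implicit zeros still
# count as samples, as in the existing sparse computation.
X = sp.csr_matrix(np.array([[1., np.nan],
                            [2., 4.],
                            [0., 6.]]))

means, variances = csr_mean_variance_axis0(X, ignore_nan=True)
print(means)      # expected [1. 5.]: column 1 averages over [4, 6]
print(variances)  # expected [0.667 1.]: population variance per column
```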