diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py
index bbd2fae10c0ec..d4d796a9bfcb5 100644
--- a/sklearn/preprocessing/data.py
+++ b/sklearn/preprocessing/data.py
@@ -26,7 +26,8 @@
 from ..utils.extmath import _incremental_mean_and_var
 from ..utils.fixes import _argmax
 from ..utils.sparsefuncs_fast import (inplace_csr_row_normalize_l1,
-                                      inplace_csr_row_normalize_l2)
+                                      inplace_csr_row_normalize_l2,
+                                      n_samples_count_csc, n_samples_count_csr)
 from ..utils.sparsefuncs import (inplace_column_scale,
                                  mean_variance_axis, incr_mean_variance_axis,
                                  min_max_axis)
@@ -619,7 +620,8 @@ def partial_fit(self, X, y=None):
             Ignored
         """
         X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy,
-                        warn_on_dtype=True, estimator=self, dtype=FLOAT_DTYPES)
+                        warn_on_dtype=True, estimator=self,
+                        force_all_finite='allow-nan', dtype=FLOAT_DTYPES)
 
         # Even in the case of `with_mean=False`, we update the mean anyway
         # This is needed for the incremental computation of the var
@@ -634,14 +636,23 @@ def partial_fit(self, X, y=None):
             # First pass
             if not hasattr(self, 'n_samples_seen_'):
                 self.mean_, self.var_ = mean_variance_axis(X, axis=0)
-                self.n_samples_seen_ = X.shape[0]
+                if isinstance(X, sparse.csc_matrix):
+                    self.n_samples_seen_ = \
+                        n_samples_count_csc(X.data, X.shape,
+                                            X.indices, X.indptr)
+                else:
+                    self.n_samples_seen_ = \
+                        n_samples_count_csr(X.data, X.shape, X.indices)
+
             # Next passes
             else:
                 self.mean_, self.var_, self.n_samples_seen_ = \
-                    incr_mean_variance_axis(X, axis=0,
-                                            last_mean=self.mean_,
-                                            last_var=self.var_,
-                                            last_n=self.n_samples_seen_)
+                    incr_mean_variance_axis(
+                        X, axis=0,
+                        last_mean=self.mean_,
+                        last_var=self.var_,
+                        last_n=0,
+                        last_n_feat=self.n_samples_seen_)
         else:
             self.mean_ = None
             self.var_ = None
@@ -649,7 +660,7 @@ def partial_fit(self, X, y=None):
             # First pass
             if not hasattr(self, 'n_samples_seen_'):
                 self.mean_ = .0
-                self.n_samples_seen_ = 0
+                self.n_samples_seen_ = np.zeros(X.shape[1])
                 if self.with_std:
                     self.var_ = .0
                 else:
@@ -688,7 +699,8 @@ def transform(self, X, y='deprecated', copy=None):
 
         copy = copy if copy is not None else self.copy
         X = check_array(X, accept_sparse='csr', copy=copy, warn_on_dtype=True,
-                        estimator=self, dtype=FLOAT_DTYPES)
+                        estimator=self, dtype=FLOAT_DTYPES,
+                        force_all_finite='allow-nan')
 
         if sparse.issparse(X):
             if self.with_mean:
diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py
index f4358e48fc0b8..7e91e6aabac2b 100644
--- a/sklearn/preprocessing/tests/test_data.py
+++ b/sklearn/preprocessing/tests/test_data.py
@@ -203,7 +203,7 @@ def test_standard_scaler_1d():
                                       np.zeros_like(n_features))
             assert_array_almost_equal(X_scaled.mean(axis=0), .0)
             assert_array_almost_equal(X_scaled.std(axis=0), 1.)
-        assert_equal(scaler.n_samples_seen_, X.shape[0])
+        assert_equal(scaler.n_samples_seen_[0], X.shape[0])
 
         # check inverse transform
         X_scaled_back = scaler.inverse_transform(X_scaled)
@@ -283,7 +283,7 @@ def test_scaler_2d_arrays():
     scaler = StandardScaler()
     X_scaled = scaler.fit(X).transform(X, copy=True)
     assert_false(np.any(np.isnan(X_scaled)))
-    assert_equal(scaler.n_samples_seen_, n_samples)
+    assert_equal(scaler.n_samples_seen_[0], n_samples)
 
     assert_array_almost_equal(X_scaled.mean(axis=0), n_features * [0.0])
     assert_array_almost_equal(X_scaled.std(axis=0), [0., 1., 1., 1., 1.])
@@ -399,7 +399,7 @@ def test_standard_scaler_partial_fit():
 
     assert_array_almost_equal(scaler_batch.mean_, scaler_incr.mean_)
     assert_equal(scaler_batch.var_, scaler_incr.var_)  # Nones
-    assert_equal(scaler_batch.n_samples_seen_, scaler_incr.n_samples_seen_)
+    assert_equal(scaler_batch.n_samples_seen_[0], scaler_incr.n_samples_seen_[0])
 
     # Test std after 1 step
     batch0 = slice(0, chunk_size)
@@ -423,10 +423,10 @@ def test_standard_scaler_partial_fit():
         assert_correct_incr(i, batch_start=batch.start,
                             batch_stop=batch.stop, n=n,
                             chunk_size=chunk_size,
-                            n_samples_seen=scaler_incr.n_samples_seen_)
+                            n_samples_seen=scaler_incr.n_samples_seen_[0])
 
     assert_array_almost_equal(scaler_batch.var_, scaler_incr.var_)
-    assert_equal(scaler_batch.n_samples_seen_, scaler_incr.n_samples_seen_)
+    assert_equal(scaler_batch.n_samples_seen_[0], scaler_incr.n_samples_seen_[0])
 
 
 def test_standard_scaler_partial_fit_numerical_stability():
@@ -515,7 +515,7 @@ def test_standard_scaler_trasform_with_partial_fit():
         assert_array_less(zero, scaler_incr.var_ + epsilon)  # as less or equal
        assert_array_less(zero, scaler_incr.scale_ + epsilon)
         # (i+1) because the Scaler has been already fitted
-        assert_equal((i + 1), scaler_incr.n_samples_seen_)
+        assert_equal((i + 1), scaler_incr.n_samples_seen_[0])
 
 
 def test_min_max_scaler_iris():
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 0658eca0a371f..15faae83911ab 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -70,6 +70,7 @@
     'OrthogonalMatchingPursuit', 'PLSCanonical', 'PLSRegression',
     'RANSACRegressor', 'RadiusNeighborsRegressor', 'RandomForestRegressor',
     'Ridge', 'RidgeCV']
+ALLOW_NAN = ['StandardScaler']
 
 
 def _yield_non_meta_checks(name, estimator):
@@ -1024,6 +1025,8 @@ def check_estimators_nan_inf(name, estimator_orig):
     error_string_transform = ("Estimator doesn't check for NaN and inf in"
                               " transform.")
     for X_train in [X_train_nan, X_train_inf]:
+        if np.any(np.isnan(X_train)) and name in ALLOW_NAN:
+            continue
         # catch deprecation warnings
         with ignore_warnings(category=(DeprecationWarning, FutureWarning)):
             estimator = clone(estimator_orig)
diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py
index e95ceb57497ae..56aff84586f36 100644
--- a/sklearn/utils/extmath.py
+++ b/sklearn/utils/extmath.py
@@ -643,7 +643,7 @@ def make_nonnegative(X, min_value=0):
 
 
 def _incremental_mean_and_var(X, last_mean=.0, last_variance=None,
-                              last_sample_count=0):
+                              last_sample_count=0, ignore_nan=True):
     """Calculate mean update and a Youngs and Cramer variance update.
 
     last_mean and last_variance are statistics computed at the last step by the
@@ -688,29 +688,55 @@ def _incremental_mean_and_var(X, last_mean=.0, last_variance=None,
     # old = stats until now
     # new = the current increment
     # updated = the aggregated stats
+    flag = 0  # if flag == 1 then last_sample_count was an array
+    n_features = X.shape[1]
+    if isinstance(last_sample_count, np.ndarray):
+        flag = 1
+    else:
+        last_sample_count *= np.ones(n_features)
+
     last_sum = last_mean * last_sample_count
-    new_sum = X.sum(axis=0)
+    sum_func = np.nansum if ignore_nan else np.sum
+    new_sum = sum_func(X, axis=0)
+    new_sum[np.isnan(new_sum)] = 0
 
-    new_sample_count = X.shape[0]
+    new_sample_count = np.count_nonzero(~np.isnan(X), axis=0)
+    if not isinstance(new_sample_count, np.ndarray):
+        new_sample_count *= np.ones(n_features)
     updated_sample_count = last_sample_count + new_sample_count
 
     updated_mean = (last_sum + new_sum) / updated_sample_count
 
+    updated_variance = np.zeros(n_features)
     if last_variance is None:
         updated_variance = None
     else:
-        new_unnormalized_variance = X.var(axis=0) * new_sample_count
-        if last_sample_count == 0:  # Avoid division by 0
-            updated_unnormalized_variance = new_unnormalized_variance
-        else:
-            last_over_new_count = last_sample_count / new_sample_count
-            last_unnormalized_variance = last_variance * last_sample_count
-            updated_unnormalized_variance = (
-                last_unnormalized_variance +
-                new_unnormalized_variance +
-                last_over_new_count / updated_sample_count *
-                (last_sum / last_over_new_count - new_sum) ** 2)
-        updated_variance = updated_unnormalized_variance / updated_sample_count
+        var_func = np.nanvar if ignore_nan else np.var
+        new_unnormalized_variance = var_func(X, axis=0)
+        new_unnormalized_variance[np.isnan(new_unnormalized_variance)] = 0
+        new_unnormalized_variance = (new_unnormalized_variance *
+                                     new_sample_count)
+        for i in range(n_features):
+            if updated_sample_count[i] == 0:  # Avoid division by 0
+                continue
+            # Avoid division by 0
+            elif last_sample_count[i] == 0 or new_sample_count[i] == 0:
+                updated_unnormalized_variance = new_unnormalized_variance[i]
+            else:
+                last_over_new_count = (last_sample_count[i] /
+                                       new_sample_count[i])
+                last_unnormalized_variance = (last_variance[i] *
+                                              last_sample_count[i])
+                updated_unnormalized_variance = (
+                    last_unnormalized_variance +
+                    new_unnormalized_variance[i] +
+                    last_over_new_count / updated_sample_count[i] *
+                    (last_sum[i] / last_over_new_count - new_sum[i]) ** 2)
+            updated_variance[i] = (updated_unnormalized_variance /
+                                   updated_sample_count[i])
+
+    if flag == 0:  # If last_sample_count was not an array
+        updated_sample_count = updated_sample_count[0]
 
     return updated_mean, updated_variance, updated_sample_count
diff --git a/sklearn/utils/sparsefuncs.py b/sklearn/utils/sparsefuncs.py
index 38b8b0a6eff16..ee24f1c25b09c 100644
--- a/sklearn/utils/sparsefuncs.py
+++ b/sklearn/utils/sparsefuncs.py
@@ -99,7 +99,8 @@ def mean_variance_axis(X, axis):
         _raise_typeerror(X)
 
 
-def incr_mean_variance_axis(X, axis, last_mean, last_var, last_n):
+def incr_mean_variance_axis(X, axis, last_mean, last_var, last_n,
+                            last_n_feat=np.array([0], dtype=np.uint32)):
     """Compute incremental mean and variance along an axix on a CSR or CSC
     matrix.
@@ -143,17 +144,21 @@ def incr_mean_variance_axis(X, axis, last_mean, last_var, last_n):
     if isinstance(X, sp.csr_matrix):
         if axis == 0:
             return _incr_mean_var_axis0(X, last_mean=last_mean,
-                                        last_var=last_var, last_n=last_n)
+                                        last_var=last_var, last_n=last_n,
+                                        last_n_feat=last_n_feat)
         else:
             return _incr_mean_var_axis0(X.T, last_mean=last_mean,
-                                        last_var=last_var, last_n=last_n)
+                                        last_var=last_var, last_n=last_n,
+                                        last_n_feat=last_n_feat)
     elif isinstance(X, sp.csc_matrix):
         if axis == 0:
             return _incr_mean_var_axis0(X, last_mean=last_mean,
-                                        last_var=last_var, last_n=last_n)
+                                        last_var=last_var, last_n=last_n,
+                                        last_n_feat=last_n_feat)
         else:
             return _incr_mean_var_axis0(X.T, last_mean=last_mean,
-                                        last_var=last_var, last_n=last_n)
+                                        last_var=last_var, last_n=last_n,
+                                        last_n_feat=last_n_feat)
     else:
         _raise_typeerror(X)
diff --git a/sklearn/utils/sparsefuncs_fast.pyx b/sklearn/utils/sparsefuncs_fast.pyx
index 52c12ce5d5953..ffaf1f5af770a 100644
--- a/sklearn/utils/sparsefuncs_fast.pyx
+++ b/sklearn/utils/sparsefuncs_fast.pyx
@@ -16,6 +16,8 @@ import scipy.sparse as sp
 cimport cython
 from cython cimport floating
 
+from numpy.math cimport isnan
+
 np.import_array()
 
 ctypedef fused integral:
@@ -79,7 +81,8 @@ def csr_mean_variance_axis0(X):
 
 def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
                              shape,
-                             np.ndarray[int, ndim=1] X_indices):
+                             np.ndarray[int, ndim=1] X_indices,
+                             ignore_nan=True):
     # Implement the function here since variables using fused types
     # cannot be declared directly and can only be passed as function arguments
     cdef unsigned int n_samples = shape[0]
@@ -94,6 +97,8 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
     cdef np.ndarray[floating, ndim=1] means
     # variances[j] contains the variance of feature j
     cdef np.ndarray[floating, ndim=1] variances
+    # n_samples_feat[j] contains the number of non-NaN values of feature j
+    cdef np.ndarray[floating, ndim=1] n_samples_feat
 
     if floating is float:
         dtype = np.float32
@@ -102,6 +107,7 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
 
     means = np.zeros(n_features, dtype=dtype)
     variances = np.zeros_like(means, dtype=dtype)
+    n_samples_feat = np.ones_like(means, dtype=dtype) * n_samples
 
     # counts[j] contains the number of samples where feature j is non-zero
     cdef np.ndarray[int, ndim=1] counts = np.zeros(n_features,
@@ -109,19 +115,30 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
 
     for i in xrange(non_zero):
         col_ind = X_indices[i]
-        means[col_ind] += X_data[i]
+        x_i = X_data[i]
+        if isnan(x_i) and ignore_nan:
+            n_samples_feat[col_ind] -= 1
+            continue
+        means[col_ind] += x_i
 
-    means /= n_samples
+    for i in xrange(n_features):
+        # Avoid division by zero in cases when all column elements are NaN
+        if n_samples_feat[i]:
+            means[i] /= n_samples_feat[i]
 
     for i in xrange(non_zero):
         col_ind = X_indices[i]
-        diff = X_data[i] - means[col_ind]
+        x_i = X_data[i]
+        if isnan(x_i) and ignore_nan:
+            continue
+        diff = x_i - means[col_ind]
         variances[col_ind] += diff * diff
         counts[col_ind] += 1
 
     for i in xrange(n_features):
-        variances[i] += (n_samples - counts[i]) * means[i] ** 2
-        variances[i] /= n_samples
+        if n_samples_feat[i]:
+            variances[i] += (n_samples_feat[i] - counts[i]) * means[i] ** 2
+            variances[i] /= n_samples_feat[i]
 
     return means, variances
 
@@ -152,7 +169,8 @@ def csc_mean_variance_axis0(X):
 
 def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
                              shape,
                              np.ndarray[int, ndim=1] X_indices,
-                             np.ndarray[int, ndim=1] X_indptr):
+                             np.ndarray[int, ndim=1] X_indptr,
+                             ignore_nan=True):
     # Implement the function here since variables using fused types
     # cannot be declared directly and can only be passed as function arguments
     cdef unsigned int n_samples = shape[0]
@@ -163,6 +181,7 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
     cdef unsigned int counts
     cdef unsigned int startptr
     cdef unsigned int endptr
+    cdef unsigned int n_samples_feat
     cdef floating diff
 
     # means[j] contains the mean of feature j
@@ -182,22 +201,87 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
         startptr = X_indptr[i]
         endptr = X_indptr[i + 1]
         counts = endptr - startptr
+        n_samples_feat = n_samples
 
         for j in xrange(startptr, endptr):
-            means[i] += X_data[j]
-        means[i] /= n_samples
+            if isnan(X_data[j]) and ignore_nan:
+                n_samples_feat -= 1
+                counts -= 1
+
+        if n_samples_feat == 0:  # Avoid division by zero
+            continue
 
         for j in xrange(startptr, endptr):
-            diff = X_data[j] - means[i]
+            x_i = X_data[j]
+            if isnan(x_i) and ignore_nan:
+                continue
+            means[i] += x_i
+
+        means[i] /= n_samples_feat
+
+        for j in xrange(startptr, endptr):
+            x_i = X_data[j]
+            if isnan(x_i) and ignore_nan:
+                continue
+            diff = x_i - means[i]
             variances[i] += diff * diff
 
-        variances[i] += (n_samples - counts) * means[i] * means[i]
-        variances[i] /= n_samples
+        variances[i] += (n_samples_feat - counts) * means[i] * means[i]
+        variances[i] /= n_samples_feat
 
     return means, variances
 
 
-def incr_mean_variance_axis0(X, last_mean, last_var, unsigned long last_n):
+def n_samples_count_csc(np.ndarray[floating, ndim=1] X_data,
+                        shape,
+                        np.ndarray[int, ndim=1] X_indices,
+                        np.ndarray[int, ndim=1] X_indptr):
+    cdef unsigned int n_samples = shape[0]
+    cdef unsigned int n_features = shape[1]
+    cdef unsigned int startptr
+    cdef unsigned int endptr
+    cdef unsigned int i
+    cdef unsigned int j
+
+    cdef np.ndarray[unsigned int, ndim=1] n_samples_feat
+
+    n_samples_feat = np.ones(n_features, dtype=np.uint32) * n_samples
+
+    for i in xrange(n_features):
+        startptr = X_indptr[i]
+        endptr = X_indptr[i + 1]
+
+        for j in xrange(startptr, endptr):
+            if isnan(X_data[j]):
+                n_samples_feat[i] -= 1
+
+    return n_samples_feat
+
+
+def n_samples_count_csr(np.ndarray[floating, ndim=1, mode="c"] X_data,
+                        shape,
+                        np.ndarray[int, ndim=1] X_indices):
+    cdef unsigned int n_samples = shape[0]
+    cdef unsigned int n_features = shape[1]
+
+    cdef unsigned int i
+    cdef unsigned int non_zero = X_indices.shape[0]
+    cdef unsigned int col_ind
+
+    cdef np.ndarray[unsigned int, ndim=1] n_samples_feat
+
+    n_samples_feat = np.ones(n_features, dtype=np.uint32) * n_samples
+
+    for i in xrange(non_zero):
+        col_ind = X_indices[i]
+        x_i = X_data[i]
+        if isnan(x_i):
+            n_samples_feat[col_ind] -= 1
+
+    return n_samples_feat
+
+
+def incr_mean_variance_axis0(X, last_mean, last_var, unsigned long last_n,
+                             last_n_feat=np.array([0], dtype=np.uint32)):
     """Compute mean and variance along axis 0 on a CSR or CSC matrix.
 
     last_mean, last_var are the statistics computed at the last step by this
@@ -244,7 +328,8 @@ def incr_mean_variance_axis0(X, last_mean, last_var, unsigned long last_n):
     if X.dtype != np.float32:
         X = X.astype(np.float64)
     return _incr_mean_variance_axis0(X.data, X.shape, X.indices, X.indptr,
-                                     X.format, last_mean, last_var, last_n)
+                                     X.format, last_mean, last_var, last_n,
+                                     last_n_feat)
 
 
 def _incr_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
@@ -254,7 +339,8 @@ def _incr_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
                               X_format,
                               last_mean,
                               last_var,
-                              unsigned long last_n):
+                              unsigned long last_n,
+                              np.ndarray[unsigned int, ndim=1] last_n_feat):
     # Implement the function here since variables using fused types
     # cannot be declared directly and can only be passed as function arguments
     cdef unsigned long n_samples = shape[0]
@@ -280,7 +366,9 @@ def _incr_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
     updated_var = np.zeros_like(new_mean, dtype=dtype)
 
     cdef unsigned long new_n
+    cdef np.ndarray[unsigned int, ndim=1] new_n_feat
     cdef unsigned long updated_n
+    cdef np.ndarray[unsigned int, ndim=1] updated_n_feat
     cdef floating last_over_new_n
 
     # Obtain new stats first
@@ -288,16 +376,52 @@ def _incr_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
 
     if X_format == 'csr':
         # X is a CSR matrix
+        new_n_feat = n_samples_count_csr(X_data, shape, X_indices)
         new_mean, new_var = _csr_mean_variance_axis0(X_data, shape, X_indices)
     else:
         # X is a CSC matrix
+        new_n_feat = n_samples_count_csc(X_data, shape, X_indices, X_indptr)
         new_mean, new_var = _csc_mean_variance_axis0(X_data, shape, X_indices,
                                                      X_indptr)
+    new_n = new_n_feat[0]
 
     # First pass
-    if last_n == 0:
+    if last_n == 0 and (last_n_feat == 0).all():
         return new_mean, new_var, new_n
     # Next passes
+
+    # Where each feature has a different sample count and updated_n_feat
+    # is a vector
+    elif last_n == 0 and (last_n_feat != 0).any():
+        updated_n_feat = last_n_feat + new_n_feat
+
+        for i in xrange(n_features):
+            if updated_n_feat[i] == 0:
+                continue
+            if new_n_feat[i] == 0:
+                updated_mean[i] = last_mean[i]
+                updated_var[i] = last_var[i]
+                continue
+            last_over_new_n = last_n_feat[i] * 1.0 / new_n_feat[i]
+
+            # Unnormalized old stats
+            last_mean[i] *= last_n_feat[i]
+            last_var[i] *= last_n_feat[i]
+
+            # Unnormalized new stats
+            new_mean[i] *= new_n_feat[i]
+            new_var[i] *= new_n_feat[i]
+
+            # Update stats
+            updated_var[i] = (last_var[i] + new_var[i] +
+                              last_over_new_n / updated_n_feat[i] *
+                              (last_mean[i] / last_over_new_n -
+                               new_mean[i]) ** 2)
+
+            updated_mean[i] = (last_mean[i] + new_mean[i]) / updated_n_feat[i]
+            updated_var[i] = updated_var[i] / updated_n_feat[i]
+
+        return updated_mean, updated_var, updated_n_feat
+
+    # Where updated_n is a scalar
     else:
         updated_n = last_n + new_n
         last_over_new_n = last_n / new_n
diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py
index f53b814c70084..fe6a1aba7361e 100644
--- a/sklearn/utils/tests/test_extmath.py
+++ b/sklearn/utils/tests/test_extmath.py
@@ -5,6 +5,7 @@
 # License: BSD 3 clause
 
 import numpy as np
+from numpy.testing import assert_allclose
 from scipy import sparse
 from scipy import linalg
 from scipy import stats
@@ -467,6 +468,28 @@ def naive_log_logistic(x):
     assert_array_almost_equal(log_logistic(extreme_x), [-100, 0])
 
 
+def test_incremental_mean_and_var_nan():
+    # Test mean and variance when an array has floating NaN values
+    A = np.array([[600, 470, 170, 430, np.nan],
+                  [600, np.nan, 170, 430, 300],
+                  [np.nan, np.nan, np.nan, np.nan, np.nan],
+                  [np.nan, np.nan, np.nan, np.nan, np.nan]])
+    X1 = A[:2, :]
+    X2 = A[2:, :]
+    X_means = np.nanmean(X1, axis=0)
+    X_variances = np.nanvar(X1, axis=0)
+    X_count = np.count_nonzero(~np.isnan(X1), axis=0)
+    A_means = np.nanmean(A, axis=0)
+    A_variances = np.nanvar(A, axis=0)
+    A_count = np.count_nonzero(~np.isnan(A), axis=0)
+
+    final_means, final_variances, final_count = \
+        _incremental_mean_and_var(X2, X_means, X_variances, X_count)
+    assert_allclose(A_means, final_means, equal_nan=True)
+    assert_allclose(A_variances, final_variances, equal_nan=True)
+    assert_allclose(A_count, final_count, equal_nan=True)
+
+
 def test_incremental_variance_update_formulas():
     # Test Youngs and Cramer incremental variance formulas.
     # Doggie data from http://www.mathsisfun.com/data/standard-deviation.html
diff --git a/sklearn/utils/tests/test_sparsefuncs.py b/sklearn/utils/tests/test_sparsefuncs.py
index f2b35e7459833..51a8fca5425bd 100644
--- a/sklearn/utils/tests/test_sparsefuncs.py
+++ b/sklearn/utils/tests/test_sparsefuncs.py
@@ -4,7 +4,7 @@
 from scipy import linalg
 
 from numpy.testing import (assert_array_almost_equal,
-                           assert_array_equal, assert_equal)
+                           assert_array_equal, assert_equal, assert_allclose)
 from numpy.random import RandomState
 
 from sklearn.datasets import make_classification
@@ -21,6 +21,29 @@
 from sklearn.utils.testing import assert_raises
 
 
+def test_mean_variance_axis0_nan():
+    A = np.zeros(10).reshape(5, 2) * np.nan
+    A[0, 0] = 1
+    A_means = np.nanmean(A, axis=0)
+    A_means[np.isnan(A_means)] = 0
+    A_vars = np.nanvar(A, axis=0)
+    A_vars[np.isnan(A_vars)] = 0
+    X_csr = sp.csr_matrix(A)
+    X_csc = sp.csc_matrix(A)
+
+    expected_dtypes = [(np.float32, np.float32),
+                       (np.float64, np.float64)]
+
+    for input_dtype, output_dtype in expected_dtypes:
+        for X_sparse in (X_csr, X_csc):
+            X_sparse = X_sparse.astype(input_dtype)
+            X_means, X_vars = mean_variance_axis(X_sparse, axis=0)
+            assert_equal(X_means.dtype, output_dtype)
+            assert_equal(X_vars.dtype, output_dtype)
+            assert_allclose(X_means, A_means)
+            assert_allclose(X_vars, A_vars)
+
+
 def test_mean_variance_axis0():
     X, _ = make_classification(5, 4, random_state=0)
     # Sparsify the array a little bit
@@ -83,6 +107,68 @@ def test_mean_variance_axis1():
     assert_array_almost_equal(X_vars, np.var(X_test, axis=0))
 
 
+def test_incr_mean_variance_axis_nan():
+    for axis in [0, 1]:
+        n_features = 50
+        data_chunk = np.random.randint(0, 2, size=500).reshape(50, 10)
+        data_chunk = np.array(data_chunk, dtype=np.float)
+        data_chunk[0, 0] = np.nan
+
+        # default params for incr_mean_variance
+        last_mean = np.zeros(n_features)
+        last_var = np.zeros_like(last_mean)
+        last_n = np.zeros_like(last_mean, dtype=np.uint32)
+
+        # Test errors
+        X = data_chunk[0]
+        X = np.atleast_2d(X)
+        X_lil = sp.lil_matrix(X)
+        X_csr = sp.csr_matrix(X_lil)
+        assert_raises(TypeError, incr_mean_variance_axis, axis,
+                      last_mean, last_var, last_n)
+        assert_raises(TypeError, incr_mean_variance_axis, axis,
+                      last_mean, last_var, last_n)
+        assert_raises(TypeError, incr_mean_variance_axis, X_lil, axis,
+                      last_mean, last_var, last_n)
+
+        # Test _incr_mean_and_var with a 1 row input
+        X_means, X_vars = mean_variance_axis(X_csr, axis)
+        X_means_incr, X_vars_incr, n_incr = \
+            incr_mean_variance_axis(X_csr, axis, last_mean, last_var,
+                                    last_n=0, last_n_feat=last_n)
+        assert_allclose(X_means, X_means_incr)
+        assert_allclose(X_vars, X_vars_incr)
+
+        X_csc = sp.csc_matrix(X_lil)
+        X_means, X_vars = mean_variance_axis(X_csc, axis)
+        assert_allclose(X_means, X_means_incr)
+        assert_allclose(X_vars, X_vars_incr)
+
+        # Test _incremental_mean_and_var with whole data
+        X = data_chunk
+        X_lil = sp.lil_matrix(X)
+        X_csr = sp.csr_matrix(X_lil)
+        X_csc = sp.csc_matrix(X_lil)
+
+        expected_dtypes = [(np.float32, np.float32),
+                           (np.float64, np.float64),
+                           (np.int32, np.float64),
+                           (np.int64, np.float64)]
+
+        for input_dtype, output_dtype in expected_dtypes:
+            for X_sparse in (X_csr, X_csc):
+                X_sparse = X_sparse.astype(input_dtype)
+                X_means, X_vars = mean_variance_axis(X_sparse, axis)
+                X_means_incr, X_vars_incr, n_incr = \
+                    incr_mean_variance_axis(X_sparse, axis, last_mean,
+                                            last_var, last_n=0,
+                                            last_n_feat=last_n)
+                assert_equal(X_means_incr.dtype, output_dtype)
+                assert_equal(X_vars_incr.dtype, output_dtype)
+                assert_allclose(X_means, X_means_incr)
+                assert_allclose(X_vars, X_vars_incr)
+
+
 def test_incr_mean_variance_axis():
     for axis in [0, 1]:
         rng = np.random.RandomState(0)
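
Note on the update rule: the per-feature merge that the patch implements in
`_incremental_mean_and_var` and `_incr_mean_variance_axis0` is the
Youngs-Cramer/Chan combination of two chunks' statistics, applied column by
column because NaNs leave each feature with its own sample count. The
following standalone NumPy sketch mirrors that logic; it is an illustration
only, not part of the patch, and the helper name `nan_incremental_update` is
invented:

    import numpy as np

    def nan_incremental_update(X, last_mean, last_var, last_n):
        # Per-feature counts of observed (non-NaN) values in this chunk.
        new_n = np.sum(~np.isnan(X), axis=0)
        # nansum over an all-NaN column can be NaN on old NumPy; zero it,
        # just as the patch does with `new_sum[np.isnan(new_sum)] = 0`.
        new_sum = np.nan_to_num(np.nansum(X, axis=0))
        updated_n = last_n + new_n
        safe_n = np.maximum(updated_n, 1)   # avoid 0/0 on all-NaN columns

        last_sum = last_mean * last_n
        updated_mean = (last_sum + new_sum) / safe_n

        # Sums of squared deviations ("unnormalized" variances).
        # nanvar returns NaN (with a warning) on all-NaN columns; zero it.
        new_ssd = np.nan_to_num(np.nanvar(X, axis=0)) * new_n
        last_ssd = last_var * last_n

        # Chan cross term, only where both sides contributed samples.
        correction = np.zeros_like(updated_mean, dtype=float)
        both = (last_n > 0) & (new_n > 0)
        ratio = last_n[both] / new_n[both]
        correction[both] = (ratio / updated_n[both] *
                            (last_sum[both] / ratio - new_sum[both]) ** 2)

        updated_var = (last_ssd + new_ssd + correction) / safe_n
        return updated_mean, updated_var, updated_n

A two-chunk check against NumPy's own NaN-aware reductions:

    rng = np.random.RandomState(0)
    A = rng.rand(6, 3)
    A[::2, 0] = np.nan              # feature 0 loses half of its samples
    zero = np.zeros(3)
    m, v, n = nan_incremental_update(A[:3], zero, zero, np.zeros(3, int))
    m, v, n = nan_incremental_update(A[3:], m, v, n)
    assert np.allclose(m, np.nanmean(A, axis=0))
    assert np.allclose(v, np.nanvar(A, axis=0))
    assert np.array_equal(n, [3, 6, 6])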
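Design note: `n_samples_seen_` becomes a per-feature vector rather than a
scalar, because each column can retain a different number of non-NaN samples;
that is why the tests above now compare `scaler.n_samples_seen_[0]`. With this
branch applied (stock scikit-learn of this era raises ValueError on NaN
input), the intended behavior is roughly:

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    X = np.array([[1.0, np.nan],
                  [2.0, 4.0],
                  [3.0, 8.0]])
    scaler = StandardScaler().fit(X)   # no "Input contains NaN" error
    print(scaler.mean_)                # approx [2., 6.], NaNs skipped
    print(scaler.n_samples_seen_)      # approx [3., 2.], one count per column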