From b1c5a216e0e7a58bfe4012a1a83b804c805455c1 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 18 Jun 2018 11:58:34 +0200 Subject: [PATCH 01/13] EHN accept to fit sparse matrices --- sklearn/preprocessing/data.py | 66 +++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 72828ff92ba2e..62c1f1f5fea29 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -1046,18 +1046,6 @@ def __init__(self, with_centering=True, with_scaling=True, self.quantile_range = quantile_range self.copy = copy - def _check_array(self, X, copy): - """Makes sure centering is not enabled for sparse matrices.""" - X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy, - estimator=self, dtype=FLOAT_DTYPES) - - if sparse.issparse(X): - if self.with_centering: - raise ValueError( - "Cannot center sparse matrices: use `with_centering=False`" - " instead. See docstring for motivation and alternatives.") - return X - def fit(self, X, y=None): """Compute the median and quantiles to be used for scaling. @@ -1067,21 +1055,43 @@ def fit(self, X, y=None): The data used to compute the median and quantiles used for later scaling along the features axis. """ + X = check_array(X, accept_sparse='csc', copy=self.copy,estimator=self, + dtype=FLOAT_DTYPES) + + q_min, q_max = self.quantile_range + if not 0 <= q_min <= q_max <= 100: + raise ValueError("Invalid quantile range: %s" % + str(self.quantile_range)) + if sparse.issparse(X): - raise TypeError("RobustScaler cannot be fitted on sparse inputs") - X = self._check_array(X, self.copy) - if self.with_centering: - self.center_ = np.median(X, axis=0) + if self.with_centering: + raise ValueError( + "Cannot center sparse matrices: use `with_centering=False`" + " instead. See docstring for motivation and alternatives.") - if self.with_scaling: - q_min, q_max = self.quantile_range - if not 0 <= q_min <= q_max <= 100: - raise ValueError("Invalid quantile range: %s" % - str(self.quantile_range)) + if self.with_scaling: + for feature_idx in range(X.shape[1]): + column_nnz = X.data[X.indptr[feature_idx]: + X.indptr[feature_idx + 1]] + quantiles = ([0, 0] if not column_nnz.size + else nanpercentile(column_nnz, + self.quantile_range)) + quantiles = np.transpose(quantiles) + else: + if self.with_centering: + self.center_ = np.nanmedian(X, axis=0) + else: + self.center_ = None - q = np.percentile(X, self.quantile_range, axis=0) - self.scale_ = (q[1] - q[0]) + if self.with_scaling: + quantiles = np.percentile(X, self.quantile_range, axis=0) + + if self.with_scaling: + self.scale_ = (quantiles[1] - quantiles[0]) self.scale_ = _handle_zeros_in_scale(self.scale_, copy=False) + else: + self.scale_ = None + return self def transform(self, X): @@ -1095,10 +1105,7 @@ def transform(self, X): X : {array-like, sparse matrix} The data used to scale along the specified axis. """ - if self.with_centering: - check_is_fitted(self, 'center_') - if self.with_scaling: - check_is_fitted(self, 'scale_') + check_is_fitted(self, 'center_', 'scale_') X = self._check_array(X, self.copy) if sparse.issparse(X): @@ -1119,10 +1126,7 @@ def inverse_transform(self, X): X : array-like The data used to scale along the specified axis. """ - if self.with_centering: - check_is_fitted(self, 'center_') - if self.with_scaling: - check_is_fitted(self, 'scale_') + check_is_fitted(self, 'center_', 'scale_') X = self._check_array(X, self.copy) if sparse.issparse(X): From 196892eef497575f427751fd1435ef9dec81767b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 18 Jun 2018 12:03:35 +0200 Subject: [PATCH 02/13] fix --- sklearn/preprocessing/data.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 62c1f1f5fea29..5778640bbc591 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -1055,7 +1055,9 @@ def fit(self, X, y=None): The data used to compute the median and quantiles used for later scaling along the features axis. """ - X = check_array(X, accept_sparse='csc', copy=self.copy,estimator=self, + # convert sparse matrices to csc for optimized computation of the + # quantiles. + X = check_array(X, accept_sparse='csc', copy=self.copy, estimator=self, dtype=FLOAT_DTYPES) q_min, q_max = self.quantile_range @@ -1106,7 +1108,8 @@ def transform(self, X): The data used to scale along the specified axis. """ check_is_fitted(self, 'center_', 'scale_') - X = self._check_array(X, self.copy) + X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy, + estimator=self, dtype=FLOAT_DTYPES) if sparse.issparse(X): if self.with_scaling: @@ -1127,7 +1130,8 @@ def inverse_transform(self, X): The data used to scale along the specified axis. """ check_is_fitted(self, 'center_', 'scale_') - X = self._check_array(X, self.copy) + X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy, + estimator=self, dtype=FLOAT_DTYPES) if sparse.issparse(X): if self.with_scaling: From 36273e0fdafc934004b320ba72c65ceb5d7f425a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 18 Jun 2018 13:18:54 +0200 Subject: [PATCH 03/13] iter --- sklearn/preprocessing/data.py | 47 ++++++++++++---------- sklearn/preprocessing/tests/test_common.py | 6 ++- 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 5778640bbc591..a5d9f091e819a 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -1058,37 +1058,37 @@ def fit(self, X, y=None): # convert sparse matrices to csc for optimized computation of the # quantiles. X = check_array(X, accept_sparse='csc', copy=self.copy, estimator=self, - dtype=FLOAT_DTYPES) + dtype=FLOAT_DTYPES, force_all_finite='allow-nan') q_min, q_max = self.quantile_range if not 0 <= q_min <= q_max <= 100: raise ValueError("Invalid quantile range: %s" % str(self.quantile_range)) - if sparse.issparse(X): - if self.with_centering: + if self.with_centering: + if sparse.issparse(X): raise ValueError( "Cannot center sparse matrices: use `with_centering=False`" " instead. See docstring for motivation and alternatives.") - - if self.with_scaling: - for feature_idx in range(X.shape[1]): - column_nnz = X.data[X.indptr[feature_idx]: - X.indptr[feature_idx + 1]] - quantiles = ([0, 0] if not column_nnz.size - else nanpercentile(column_nnz, - self.quantile_range)) - quantiles = np.transpose(quantiles) + self.center_ = np.nanmedian(X, axis=0) else: - if self.with_centering: - self.center_ = np.nanmedian(X, axis=0) - else: - self.center_ = None - - if self.with_scaling: - quantiles = np.percentile(X, self.quantile_range, axis=0) + self.center_ = None if self.with_scaling: + quantiles = [] + for feature_idx in range(X.shape[1]): + if sparse.issparse(X): + column_data = X.data[X.indptr[feature_idx]: + X.indptr[feature_idx + 1]] + else: + column_data = X[:, feature_idx] + + quantiles.append( + nanpercentile(column_data, self.quantile_range) + if column_data.size else [0, 0]) + + quantiles = np.transpose(quantiles) + self.scale_ = (quantiles[1] - quantiles[0]) self.scale_ = _handle_zeros_in_scale(self.scale_, copy=False) else: @@ -1109,7 +1109,8 @@ def transform(self, X): """ check_is_fitted(self, 'center_', 'scale_') X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy, - estimator=self, dtype=FLOAT_DTYPES) + estimator=self, dtype=FLOAT_DTYPES, + force_all_finite='allow-nan') if sparse.issparse(X): if self.with_scaling: @@ -1131,7 +1132,8 @@ def inverse_transform(self, X): """ check_is_fitted(self, 'center_', 'scale_') X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy, - estimator=self, dtype=FLOAT_DTYPES) + estimator=self, dtype=FLOAT_DTYPES, + force_all_finite='allow-nan') if sparse.issparse(X): if self.with_scaling: @@ -1204,7 +1206,8 @@ def robust_scale(X, axis=0, with_centering=True, with_scaling=True, (e.g. as part of a preprocessing :class:`sklearn.pipeline.Pipeline`). """ X = check_array(X, accept_sparse=('csr', 'csc'), copy=False, - ensure_2d=False, dtype=FLOAT_DTYPES) + ensure_2d=False, dtype=FLOAT_DTYPES, + force_all_finite='allow-nan') original_ndim = X.ndim if original_ndim == 1: diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index d04119db8b0ba..5a6b6db4d7790 100644 --- a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -10,9 +10,11 @@ from sklearn.preprocessing import minmax_scale from sklearn.preprocessing import quantile_transform +from sklearn.preprocessing import robust_scale from sklearn.preprocessing import MinMaxScaler from sklearn.preprocessing import QuantileTransformer +from sklearn.preprocessing import RobustScaler from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_allclose @@ -28,7 +30,9 @@ def _get_valid_samples_by_column(X, col): @pytest.mark.parametrize( "est, func, support_sparse", [(MinMaxScaler(), minmax_scale, False), - (QuantileTransformer(n_quantiles=10), quantile_transform, True)] + (QuantileTransformer(n_quantiles=10), quantile_transform, True), + (RobustScaler(), robust_scale, False), + (RobustScaler(with_centering=False), robust_scale, True)] ) def test_missing_value_handling(est, func, support_sparse): # check that the preprocessing method let pass nan From ed251634a1920aef4360178ce4cb2fdd5ff2bd06 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 18 Jun 2018 13:51:41 +0200 Subject: [PATCH 04/13] TST check attributes and corner case sparse matrix --- sklearn/preprocessing/data.py | 4 +-- sklearn/preprocessing/tests/test_data.py | 46 ++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index a5d9f091e819a..ddb0bce33a56b 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -1055,8 +1055,8 @@ def fit(self, X, y=None): The data used to compute the median and quantiles used for later scaling along the features axis. """ - # convert sparse matrices to csc for optimized computation of the - # quantiles. + # at fit, convert sparse matrices to csc for optimized computation of + # the quantiles X = check_array(X, accept_sparse='csc', copy=self.copy, estimator=self, dtype=FLOAT_DTYPES, force_all_finite='allow-nan') diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 800df3bab4b23..2561962bad598 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -892,6 +892,52 @@ def test_scale_input_finiteness_validation(): scale, X) +def test_robust_scaler_error_sparse(): + X_sparse = sparse.rand(1000, 10) + scaler = RobustScaler(with_centering=True) + err_msg = "Cannot center sparse matrices" + with pytest.raises(ValueError, match=err_msg): + scaler.fit(X_sparse) + + +@pytest.mark.parametrize("with_centering", [True, False]) +@pytest.mark.parametrize("with_scaling", [True, False]) +@pytest.mark.parametrize("X", [np.random.randn(10, 3), + sparse.rand(10, 3, density=0.5)]) +def test_robust_scaler_attributes(X, with_centering, with_scaling): + # check consistent type of attributes + if with_centering and sparse.issparse(X): + pytest.skip("RobustScaler cannot center sparse matrix") + + scaler = RobustScaler(with_centering=with_centering, + with_scaling=with_scaling) + scaler.fit(X) + + if with_centering: + assert isinstance(scaler.center_, np.ndarray) + else: + assert scaler.center_ is None + if with_scaling: + assert isinstance(scaler.scale_, np.ndarray) + else: + assert scaler.scale_ is None + + +def test_robust_scaler_col_zero_sparse(): + # check that the scaler is working when there is not data materialized in a + # column of a sparse matrix + X = np.random.randn(10, 5) + X[:, 0] = 0 + X = sparse.csr_matrix(X) + + scaler = RobustScaler(with_centering=False) + scaler.fit(X) + assert scaler.scale_[0] == pytest.approx(1) + + X_trans = scaler.transform(X) + assert_allclose(X[:, 0].toarray(), X_trans[:, 0].toarray()) + + def test_robust_scaler_2d_arrays(): # Test robust scaling of 2d array along first axis rng = np.random.RandomState(0) From d97fee53ac6ab3df656558e6a57641236ec87085 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 18 Jun 2018 13:54:06 +0200 Subject: [PATCH 05/13] DOC whats new entry --- doc/whats_new/v0.20.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 32b4ef3098263..0061e4f18eb49 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -230,6 +230,11 @@ Preprocessing :issue:`10404` and :issue:`11243` by :user:`Lucija Gregov ` and :user:`Guillaume Lemaitre `. +- :class:`preprocessing.RobustScaler` and :func:`preprocessing.robust_scale` + ignore and pass-through NaN values. In addition, the scaler can now be fitted + on sparse matrices. + :issue:`11308` by :user:`Guillaume Lemaitre `. + Model evaluation and meta-estimators - A scorer based on :func:`metrics.brier_score_loss` is also available. From 296bfd0af8aaa23d867880e936e77187bb818a61 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 18 Jun 2018 15:58:41 +0200 Subject: [PATCH 06/13] TST check equivalence between sparse and dense --- sklearn/preprocessing/data.py | 6 ++++-- sklearn/preprocessing/tests/test_data.py | 14 ++++++++++++++ sklearn/utils/estimator_checks.py | 2 +- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index ddb0bce33a56b..038e2da11731d 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -1078,8 +1078,10 @@ def fit(self, X, y=None): quantiles = [] for feature_idx in range(X.shape[1]): if sparse.issparse(X): - column_data = X.data[X.indptr[feature_idx]: - X.indptr[feature_idx + 1]] + column_nnz_data = X.data[X.indptr[feature_idx]: + X.indptr[feature_idx + 1]] + column_data = np.zeros(shape=X.shape[0], dtype=X.dtype) + column_data[:len(column_nnz_data)] = column_nnz_data else: column_data = X[:, feature_idx] diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 2561962bad598..8c768b656c357 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -951,6 +951,20 @@ def test_robust_scaler_2d_arrays(): assert_array_almost_equal(X_scaled.std(axis=0)[0], 0) +def test_robust_scaler_equivalence_dense_sparse(): + # Check the equivalence of the fitting with dense and sparse matrices + X_sparse = sparse.rand(1000, 5, density=0.5).tocsc() + X_dense = X_sparse.toarray() + + scaler_sparse = RobustScaler(with_centering=False) + scaler_dense = RobustScaler(with_centering=False) + + scaler_sparse.fit(X_sparse) + scaler_dense.fit(X_dense) + + assert_allclose(scaler_sparse.scale_, scaler_dense.scale_) + + def test_robust_scaler_transform_one_row_csr(): # Check RobustScaler on transforming csr matrix with one row rng = np.random.RandomState(0) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 2125067c502bf..006cc379dd6f1 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -77,7 +77,7 @@ 'RandomForestRegressor', 'Ridge', 'RidgeCV'] ALLOW_NAN = ['Imputer', 'SimpleImputer', 'MICEImputer', - 'MinMaxScaler', 'QuantileTransformer'] + 'MinMaxScaler', 'QuantileTransformer', 'RobustScaler'] def _yield_non_meta_checks(name, estimator): From b6f1df691b168ca4fb2c69fa5f095f51b2e206ec Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 18 Jun 2018 17:28:09 +0200 Subject: [PATCH 07/13] FIX back-port nanmedian --- sklearn/preprocessing/data.py | 4 ++-- sklearn/utils/fixes.py | 13 +++++++++++++ sklearn/utils/tests/test_fixes.py | 17 +++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 038e2da11731d..c3fe97ee168d2 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -25,7 +25,7 @@ from ..utils import check_array from ..utils.extmath import row_norms from ..utils.extmath import _incremental_mean_and_var -from ..utils.fixes import _argmax, nanpercentile +from ..utils.fixes import _argmax, nanpercentile, nanmedian from ..utils.sparsefuncs_fast import (inplace_csr_row_normalize_l1, inplace_csr_row_normalize_l2) from ..utils.sparsefuncs import (inplace_column_scale, @@ -1070,7 +1070,7 @@ def fit(self, X, y=None): raise ValueError( "Cannot center sparse matrices: use `with_centering=False`" " instead. See docstring for motivation and alternatives.") - self.center_ = np.nanmedian(X, axis=0) + self.center_ = nanmedian(X, axis=0) else: self.center_ = None diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index dae4ce66f16f8..763e2cb320ea2 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -267,3 +267,16 @@ def nanpercentile(a, q): return np.array([np.nan] * size_q) else: from numpy import nanpercentile # noqa + + +if np_version < (1, 9): + def nanmedian(a, axis=None): + if axis is None: + data = a.reshape(-1) + return np.median(np.compress(~np.isnan(data), data)) + else: + data = a.T if not axis else a + return np.array([np.median(np.compress(~np.isnan(row), row)) + for row in data]) +else: + from numpy import nanmedian # noqa diff --git a/sklearn/utils/tests/test_fixes.py b/sklearn/utils/tests/test_fixes.py index 8a55f74a4f6e3..92f954439f797 100644 --- a/sklearn/utils/tests/test_fixes.py +++ b/sklearn/utils/tests/test_fixes.py @@ -14,6 +14,7 @@ from sklearn.utils.fixes import divide from sklearn.utils.fixes import MaskedArray +from sklearn.utils.fixes import nanmedian from sklearn.utils.fixes import nanpercentile @@ -31,6 +32,22 @@ def test_masked_array_obj_dtype_pickleable(): assert_array_equal(marr.mask, marr_pickled.mask) +@pytest.mark.parametrize( + "axis, expected_median", + [(None, 4.0), + (0, np.array([1., 3.5, 3.5, 4., 7., np.nan])), + (1, np.array([1., 6.]))] +) +def test_nanmedian(axis, expected_median): + X = np.array([[1, 1, 1, 2, np.nan, np.nan], + [np.nan, 6, 6, 6, 7, np.nan]]) + median = nanmedian(X, axis=axis) + if axis is None: + assert median == pytest.approx(expected_median) + else: + assert_allclose(median, expected_median) + + @pytest.mark.parametrize( "a, q, expected_percentile", [(np.array([1, 2, 3, np.nan]), [0, 50, 100], np.array([1., 2., 3.])), From db7cb4b99f6953ffad60e55b16f9f2c67c7288d1 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 21 Jun 2018 22:46:11 +0200 Subject: [PATCH 08/13] TST add more test case for sparse matrices --- sklearn/preprocessing/data.py | 3 --- sklearn/preprocessing/tests/test_data.py | 12 ++++++++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 33b51bceee69d..b7e29f623cbdd 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -1137,9 +1137,6 @@ def fit(self, X, y=None): def transform(self, X): """Center and scale the data. - Can be called on sparse input, provided that ``RobustScaler`` has been - fitted to dense input and ``with_centering=False``. - Parameters ---------- X : {array-like, sparse matrix} diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 6e993dc0db78d..aaa3b80f79b71 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -965,9 +965,17 @@ def test_robust_scaler_2d_arrays(): assert_array_almost_equal(X_scaled.std(axis=0)[0], 0) -def test_robust_scaler_equivalence_dense_sparse(): +@pytest.mark.parametrize("density", [0, 0.01, 0.05, 0.1, 0.2, 0.5, 1]) +@pytest.mark.parametrize("strictly_signed", ['positif', 'negatif', 'zeros']) +def test_robust_scaler_equivalence_dense_sparse(density, strictly_signed): # Check the equivalence of the fitting with dense and sparse matrices - X_sparse = sparse.rand(1000, 5, density=0.5).tocsc() + X_sparse = sparse.rand(1000, 5, density=density).tocsc() + if strictly_signed == 'positif': + X_sparse.data += X_sparse.min() + elif strictly_signed == 'negatif': + X_sparse.data -= X_sparse.max() + else: + X_sparse.data = np.zeros(X_sparse.data.shape, dtype=np.float64) X_dense = X_sparse.toarray() scaler_sparse = RobustScaler(with_centering=False) From f884532a61654a8033ab698bcbc93b0d0bdac8e8 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 21 Jun 2018 22:50:25 +0200 Subject: [PATCH 09/13] TST additional test for random sparse matrix --- sklearn/preprocessing/tests/test_data.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index aaa3b80f79b71..c48f32ddc7ec9 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -966,7 +966,8 @@ def test_robust_scaler_2d_arrays(): @pytest.mark.parametrize("density", [0, 0.01, 0.05, 0.1, 0.2, 0.5, 1]) -@pytest.mark.parametrize("strictly_signed", ['positif', 'negatif', 'zeros']) +@pytest.mark.parametrize("strictly_signed", + ['positif', 'negatif', 'zeros', None]) def test_robust_scaler_equivalence_dense_sparse(density, strictly_signed): # Check the equivalence of the fitting with dense and sparse matrices X_sparse = sparse.rand(1000, 5, density=density).tocsc() @@ -974,7 +975,7 @@ def test_robust_scaler_equivalence_dense_sparse(density, strictly_signed): X_sparse.data += X_sparse.min() elif strictly_signed == 'negatif': X_sparse.data -= X_sparse.max() - else: + elif strictly_signed == 'zeros': X_sparse.data = np.zeros(X_sparse.data.shape, dtype=np.float64) X_dense = X_sparse.toarray() From 7f22b3f26d9c814afa9a0a53a0f96ed59718cea7 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 22 Jun 2018 12:47:39 +0200 Subject: [PATCH 10/13] address comments --- sklearn/preprocessing/tests/test_data.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index c48f32ddc7ec9..2ff9dfd776a03 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -965,16 +965,16 @@ def test_robust_scaler_2d_arrays(): assert_array_almost_equal(X_scaled.std(axis=0)[0], 0) -@pytest.mark.parametrize("density", [0, 0.01, 0.05, 0.1, 0.2, 0.5, 1]) +@pytest.mark.parametrize("density", [0, 0.05, 0.1, 0.5, 1]) @pytest.mark.parametrize("strictly_signed", - ['positif', 'negatif', 'zeros', None]) + ['positive', 'negative', 'zeros', None]) def test_robust_scaler_equivalence_dense_sparse(density, strictly_signed): # Check the equivalence of the fitting with dense and sparse matrices X_sparse = sparse.rand(1000, 5, density=density).tocsc() - if strictly_signed == 'positif': - X_sparse.data += X_sparse.min() - elif strictly_signed == 'negatif': - X_sparse.data -= X_sparse.max() + if strictly_signed == 'positive': + X_sparse.data = np.abs(X_sparse.data) + elif strictly_signed == 'negative': + X_sparse.data = - np.abs(X_sparse.data) elif strictly_signed == 'zeros': X_sparse.data = np.zeros(X_sparse.data.shape, dtype=np.float64) X_dense = X_sparse.toarray() From eecb39aa4198199c1cbbdf28e435ef6d1826d681 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 26 Jun 2018 18:18:51 +0200 Subject: [PATCH 11/13] address comments --- doc/whats_new/v0.20.rst | 39 ++++++++++++++++------------------- sklearn/preprocessing/data.py | 2 +- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 48857a806d45f..1b7a90008acfc 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -260,31 +260,28 @@ Preprocessing classes found which are ignored. :issue:`10913` by :user:`Rodrigo Agundez `. -- :class:`preprocessing.QuantileTransformer` handles and ignores NaN values. - :issue:`10404` by :user:`Guillaume Lemaitre `. - -- Updated :class:`preprocessing.MinMaxScaler` and - :func:`preprocessing.minmax_scale` to pass through NaN values. - :issue:`10404` and :issue:`11243` by :user:`Lucija Gregov ` and +- NaN values are ignored and handled in the following preprocessing methods: + :class:`preprocessing.MaxAbsScaler`, + :class:`preprocessing.MinMaxScaler`, + :class:`preprocessing.RobustScaler`, + :class:`preprocessing.StandardScaler`, + :class:`preprocessing.PowerTransformer`, + :class:`preprocessing.QuantileTransformer` classes and + :func:`preprocessing.maxabs_scale`, + :func:`preprocessing.minmax_scale`, + :func:`preprocessing.robust_scale`, + :func:`preprocessing.scale`, + :func:`preprocessing.power_transform`, + :func:`preprocessing.quantile_transform` functions respectively addressed in + issues :issue:`11011`, :issue:`11005`, :issue:`11308`, :issue:`11206`, + :issue:`11306`, and :issue:`10437`. + By :user:`Lucija Gregov ` and :user:`Guillaume Lemaitre `. -- :class:`preprocessing.StandardScaler` and :func:`preprocessing.scale` - ignore and pass-through NaN values. - :issue:`11206` by :user:`Guillaume Lemaitre `. - -- :class:`preprocessing.MaxAbsScaler` and :func:`preprocessing.maxabs_scale` - handles and ignores NaN values. - :issue:`11011` by `Lucija Gregov ` and - :user:`Guillaume Lemaitre ` - -- :class:`preprocessing.PowerTransformer` and - :func:`preprocessing.power_transform` ignore and pass-through NaN values. - :issue:`11306` by :user:`Guillaume Lemaitre `. - :class:`preprocessing.RobustScaler` and :func:`preprocessing.robust_scale` - ignore and pass-through NaN values. In addition, the scaler can now be fitted - on sparse matrices. + can be fitted using sparse matrices. :issue:`11308` by :user:`Guillaume Lemaitre `. - + Model evaluation and meta-estimators - A scorer based on :func:`metrics.brier_score_loss` is also available. diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 0ea75564077d4..3c17b76f0b710 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -1137,7 +1137,7 @@ def fit(self, X, y=None): quantiles = np.transpose(quantiles) - self.scale_ = (quantiles[1] - quantiles[0]) + self.scale_ = quantiles[1] - quantiles[0] self.scale_ = _handle_zeros_in_scale(self.scale_, copy=False) else: self.scale_ = None From 02a981116b1da6c21f7945638e1a7c9f3caec36e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 27 Jun 2018 16:05:30 +0200 Subject: [PATCH 12/13] joel comments --- sklearn/preprocessing/data.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 3c17b76f0b710..ff20a68760782 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -1131,9 +1131,7 @@ def fit(self, X, y=None): else: column_data = X[:, feature_idx] - quantiles.append( - nanpercentile(column_data, self.quantile_range) - if column_data.size else [0, 0]) + quantiles.append(nanpercentile(column_data, self.quantile_range)) quantiles = np.transpose(quantiles) From 86cf70767509f123264afaa6adf8a4c705852f32 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 27 Jun 2018 16:30:36 +0200 Subject: [PATCH 13/13] Update data.py --- sklearn/preprocessing/data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index ff20a68760782..e3c72d6884591 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -1131,7 +1131,8 @@ def fit(self, X, y=None): else: column_data = X[:, feature_idx] - quantiles.append(nanpercentile(column_data, self.quantile_range)) + quantiles.append(nanpercentile(column_data, + self.quantile_range)) quantiles = np.transpose(quantiles)