diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst
index fad5db0e7a9a9..1b7a90008acfc 100644
--- a/doc/whats_new/v0.20.rst
+++ b/doc/whats_new/v0.20.rst
@@ -260,27 +260,28 @@ Preprocessing
   classes found which are ignored.
   :issue:`10913` by :user:`Rodrigo Agundez <rragundez>`.
 
-- :class:`preprocessing.QuantileTransformer` handles and ignores NaN values.
-  :issue:`10404` by :user:`Guillaume Lemaitre <glemaitre>`.
-
-- Updated :class:`preprocessing.MinMaxScaler` and
-  :func:`preprocessing.minmax_scale` to pass through NaN values.
-  :issue:`10404` and :issue:`11243` by :user:`Lucija Gregov <LucijaGregov>` and
+- NaN values are ignored and handled in the following preprocessing methods:
+  :class:`preprocessing.MaxAbsScaler`,
+  :class:`preprocessing.MinMaxScaler`,
+  :class:`preprocessing.RobustScaler`,
+  :class:`preprocessing.StandardScaler`,
+  :class:`preprocessing.PowerTransformer`,
+  :class:`preprocessing.QuantileTransformer` classes and
+  :func:`preprocessing.maxabs_scale`,
+  :func:`preprocessing.minmax_scale`,
+  :func:`preprocessing.robust_scale`,
+  :func:`preprocessing.scale`,
+  :func:`preprocessing.power_transform`,
+  :func:`preprocessing.quantile_transform` functions, addressed respectively
+  in issues :issue:`11011`, :issue:`11005`, :issue:`11308`, :issue:`11206`,
+  :issue:`11306`, and :issue:`10437`.
+  By :user:`Lucija Gregov <LucijaGregov>` and
   :user:`Guillaume Lemaitre <glemaitre>`.
 
-- :class:`preprocessing.StandardScaler` and :func:`preprocessing.scale`
-  ignore and pass-through NaN values.
-  :issue:`11206` by :user:`Guillaume Lemaitre <glemaitre>`.
-
-- :class:`preprocessing.MaxAbsScaler` and :func:`preprocessing.maxabs_scale`
-  handles and ignores NaN values.
-  :issue:`11011` by `Lucija Gregov <LucijaGregov>` and
-  :user:`Guillaume Lemaitre <glemaitre>`
-
-- :class:`preprocessing.PowerTransformer` and
-  :func:`preprocessing.power_transform` ignore and pass-through NaN values.
-  :issue:`11306` by :user:`Guillaume Lemaitre <glemaitre>`.
-
+- :class:`preprocessing.RobustScaler` and :func:`preprocessing.robust_scale`
+  can be fitted using sparse matrices.
+  :issue:`11308` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 Model evaluation and meta-estimators
 
 - A scorer based on :func:`metrics.brier_score_loss` is also available.
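The changelog entry above describes NaN pass-through for the scalers. As an illustrative sketch (not part of the patch, assuming a scikit-learn build that includes these changes), this is the user-visible behavior for `RobustScaler`:

```python
import numpy as np
from sklearn.preprocessing import RobustScaler

X = np.array([[1.0, np.nan],
              [2.0, 5.0],
              [3.0, 7.0]])

# The median and quantiles are computed on the non-NaN entries only...
scaler = RobustScaler().fit(X)
X_scaled = scaler.transform(X)
# ...and the NaN itself is passed through rather than rejected at validation.
assert np.isnan(X_scaled[0, 1])
```
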
""" - if sparse.issparse(X): - raise TypeError("RobustScaler cannot be fitted on sparse inputs") - X = self._check_array(X, self.copy) + # at fit, convert sparse matrices to csc for optimized computation of + # the quantiles + X = check_array(X, accept_sparse='csc', copy=self.copy, estimator=self, + dtype=FLOAT_DTYPES, force_all_finite='allow-nan') + + q_min, q_max = self.quantile_range + if not 0 <= q_min <= q_max <= 100: + raise ValueError("Invalid quantile range: %s" % + str(self.quantile_range)) + if self.with_centering: - self.center_ = np.median(X, axis=0) + if sparse.issparse(X): + raise ValueError( + "Cannot center sparse matrices: use `with_centering=False`" + " instead. See docstring for motivation and alternatives.") + self.center_ = nanmedian(X, axis=0) + else: + self.center_ = None if self.with_scaling: - q_min, q_max = self.quantile_range - if not 0 <= q_min <= q_max <= 100: - raise ValueError("Invalid quantile range: %s" % - str(self.quantile_range)) + quantiles = [] + for feature_idx in range(X.shape[1]): + if sparse.issparse(X): + column_nnz_data = X.data[X.indptr[feature_idx]: + X.indptr[feature_idx + 1]] + column_data = np.zeros(shape=X.shape[0], dtype=X.dtype) + column_data[:len(column_nnz_data)] = column_nnz_data + else: + column_data = X[:, feature_idx] - q = np.percentile(X, self.quantile_range, axis=0) - self.scale_ = (q[1] - q[0]) + quantiles.append(nanpercentile(column_data, + self.quantile_range)) + + quantiles = np.transpose(quantiles) + + self.scale_ = quantiles[1] - quantiles[0] self.scale_ = _handle_zeros_in_scale(self.scale_, copy=False) + else: + self.scale_ = None + return self def transform(self, X): """Center and scale the data. - Can be called on sparse input, provided that ``RobustScaler`` has been - fitted to dense input and ``with_centering=False``. - Parameters ---------- X : {array-like, sparse matrix} The data used to scale along the specified axis. """ - if self.with_centering: - check_is_fitted(self, 'center_') - if self.with_scaling: - check_is_fitted(self, 'scale_') - X = self._check_array(X, self.copy) + check_is_fitted(self, 'center_', 'scale_') + X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy, + estimator=self, dtype=FLOAT_DTYPES, + force_all_finite='allow-nan') if sparse.issparse(X): if self.with_scaling: @@ -1165,11 +1174,10 @@ def inverse_transform(self, X): X : array-like The data used to scale along the specified axis. """ - if self.with_centering: - check_is_fitted(self, 'center_') - if self.with_scaling: - check_is_fitted(self, 'scale_') - X = self._check_array(X, self.copy) + check_is_fitted(self, 'center_', 'scale_') + X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy, + estimator=self, dtype=FLOAT_DTYPES, + force_all_finite='allow-nan') if sparse.issparse(X): if self.with_scaling: @@ -1242,7 +1250,8 @@ def robust_scale(X, axis=0, with_centering=True, with_scaling=True, (e.g. as part of a preprocessing :class:`sklearn.pipeline.Pipeline`). 
""" X = check_array(X, accept_sparse=('csr', 'csc'), copy=False, - ensure_2d=False, dtype=FLOAT_DTYPES) + ensure_2d=False, dtype=FLOAT_DTYPES, + force_all_finite='allow-nan') original_ndim = X.ndim if original_ndim == 1: diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index 4abc73d6ef445..e8cd8d9d18db6 100644 --- a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -13,12 +13,14 @@ from sklearn.preprocessing import scale from sklearn.preprocessing import power_transform from sklearn.preprocessing import quantile_transform +from sklearn.preprocessing import robust_scale from sklearn.preprocessing import MaxAbsScaler from sklearn.preprocessing import MinMaxScaler from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import PowerTransformer from sklearn.preprocessing import QuantileTransformer +from sklearn.preprocessing import RobustScaler from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_allclose @@ -38,7 +40,9 @@ def _get_valid_samples_by_column(X, col): (StandardScaler(), scale, False, False), (StandardScaler(with_mean=False), scale, True, False), (PowerTransformer(), power_transform, False, True), - (QuantileTransformer(n_quantiles=10), quantile_transform, True, False)] + (QuantileTransformer(n_quantiles=10), quantile_transform, True, False), + (RobustScaler(), robust_scale, False, False), + (RobustScaler(with_centering=False), robust_scale, True, False)] ) def test_missing_value_handling(est, func, support_sparse, strictly_positive): # check that the preprocessing method let pass nan diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index f90fbee278c05..2ff9dfd776a03 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -906,6 +906,52 @@ def test_scale_input_finiteness_validation(): scale, X) +def test_robust_scaler_error_sparse(): + X_sparse = sparse.rand(1000, 10) + scaler = RobustScaler(with_centering=True) + err_msg = "Cannot center sparse matrices" + with pytest.raises(ValueError, match=err_msg): + scaler.fit(X_sparse) + + +@pytest.mark.parametrize("with_centering", [True, False]) +@pytest.mark.parametrize("with_scaling", [True, False]) +@pytest.mark.parametrize("X", [np.random.randn(10, 3), + sparse.rand(10, 3, density=0.5)]) +def test_robust_scaler_attributes(X, with_centering, with_scaling): + # check consistent type of attributes + if with_centering and sparse.issparse(X): + pytest.skip("RobustScaler cannot center sparse matrix") + + scaler = RobustScaler(with_centering=with_centering, + with_scaling=with_scaling) + scaler.fit(X) + + if with_centering: + assert isinstance(scaler.center_, np.ndarray) + else: + assert scaler.center_ is None + if with_scaling: + assert isinstance(scaler.scale_, np.ndarray) + else: + assert scaler.scale_ is None + + +def test_robust_scaler_col_zero_sparse(): + # check that the scaler is working when there is not data materialized in a + # column of a sparse matrix + X = np.random.randn(10, 5) + X[:, 0] = 0 + X = sparse.csr_matrix(X) + + scaler = RobustScaler(with_centering=False) + scaler.fit(X) + assert scaler.scale_[0] == pytest.approx(1) + + X_trans = scaler.transform(X) + assert_allclose(X[:, 0].toarray(), X_trans[:, 0].toarray()) + + def test_robust_scaler_2d_arrays(): # Test robust scaling of 2d array along first axis rng = np.random.RandomState(0) @@ -919,6 +965,29 @@ def 
diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py
index f90fbee278c05..2ff9dfd776a03 100644
--- a/sklearn/preprocessing/tests/test_data.py
+++ b/sklearn/preprocessing/tests/test_data.py
@@ -906,6 +906,52 @@ def test_scale_input_finiteness_validation():
                          scale, X)
 
 
+def test_robust_scaler_error_sparse():
+    X_sparse = sparse.rand(1000, 10)
+    scaler = RobustScaler(with_centering=True)
+    err_msg = "Cannot center sparse matrices"
+    with pytest.raises(ValueError, match=err_msg):
+        scaler.fit(X_sparse)
+
+
+@pytest.mark.parametrize("with_centering", [True, False])
+@pytest.mark.parametrize("with_scaling", [True, False])
+@pytest.mark.parametrize("X", [np.random.randn(10, 3),
+                               sparse.rand(10, 3, density=0.5)])
+def test_robust_scaler_attributes(X, with_centering, with_scaling):
+    # check consistent type of attributes
+    if with_centering and sparse.issparse(X):
+        pytest.skip("RobustScaler cannot center sparse matrix")
+
+    scaler = RobustScaler(with_centering=with_centering,
+                          with_scaling=with_scaling)
+    scaler.fit(X)
+
+    if with_centering:
+        assert isinstance(scaler.center_, np.ndarray)
+    else:
+        assert scaler.center_ is None
+    if with_scaling:
+        assert isinstance(scaler.scale_, np.ndarray)
+    else:
+        assert scaler.scale_ is None
+
+
+def test_robust_scaler_col_zero_sparse():
+    # check that the scaler is working when there is no data materialized in
+    # a column of a sparse matrix
+    X = np.random.randn(10, 5)
+    X[:, 0] = 0
+    X = sparse.csr_matrix(X)
+
+    scaler = RobustScaler(with_centering=False)
+    scaler.fit(X)
+    assert scaler.scale_[0] == pytest.approx(1)
+
+    X_trans = scaler.transform(X)
+    assert_allclose(X[:, 0].toarray(), X_trans[:, 0].toarray())
+
+
 def test_robust_scaler_2d_arrays():
     # Test robust scaling of 2d array along first axis
     rng = np.random.RandomState(0)
@@ -919,6 +965,29 @@ def test_robust_scaler_2d_arrays():
     assert_array_almost_equal(X_scaled.std(axis=0)[0], 0)
 
 
+@pytest.mark.parametrize("density", [0, 0.05, 0.1, 0.5, 1])
+@pytest.mark.parametrize("strictly_signed",
+                         ['positive', 'negative', 'zeros', None])
+def test_robust_scaler_equivalence_dense_sparse(density, strictly_signed):
+    # Check the equivalence of the fitting with dense and sparse matrices
+    X_sparse = sparse.rand(1000, 5, density=density).tocsc()
+    if strictly_signed == 'positive':
+        X_sparse.data = np.abs(X_sparse.data)
+    elif strictly_signed == 'negative':
+        X_sparse.data = - np.abs(X_sparse.data)
+    elif strictly_signed == 'zeros':
+        X_sparse.data = np.zeros(X_sparse.data.shape, dtype=np.float64)
+    X_dense = X_sparse.toarray()
+
+    scaler_sparse = RobustScaler(with_centering=False)
+    scaler_dense = RobustScaler(with_centering=False)
+
+    scaler_sparse.fit(X_sparse)
+    scaler_dense.fit(X_dense)
+
+    assert_allclose(scaler_sparse.scale_, scaler_dense.scale_)
+
+
 def test_robust_scaler_transform_one_row_csr():
     # Check RobustScaler on transforming csr matrix with one row
     rng = np.random.RandomState(0)
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 6c8fffd103d40..02d91ee80791b 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -79,7 +79,7 @@
                 'RandomForestRegressor', 'Ridge', 'RidgeCV']
 
 ALLOW_NAN = ['Imputer', 'SimpleImputer', 'ChainedImputer',
-             'MaxAbsScaler', 'MinMaxScaler', 'StandardScaler',
+             'MaxAbsScaler', 'MinMaxScaler', 'RobustScaler', 'StandardScaler',
              'PowerTransformer', 'QuantileTransformer']
 
 
diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py
index 0117770084177..56d56aff2321d 100644
--- a/sklearn/utils/fixes.py
+++ b/sklearn/utils/fixes.py
@@ -280,6 +280,19 @@ def nanpercentile(a, q):
     from numpy import nanpercentile  # noqa
 
 
+if np_version < (1, 9):
+    def nanmedian(a, axis=None):
+        if axis is None:
+            data = a.reshape(-1)
+            return np.median(np.compress(~np.isnan(data), data))
+        else:
+            data = a.T if not axis else a
+            return np.array([np.median(np.compress(~np.isnan(row), row))
+                             for row in data])
+else:
+    from numpy import nanmedian  # noqa
+
+
 # Fix for behavior inconsistency on numpy.equal for object dtypes.
 # For numpy versions < 1.13, numpy.equal tests element-wise identity of objects
 # instead of equality. This fix returns the mask of NaNs in an array of
diff --git a/sklearn/utils/tests/test_fixes.py b/sklearn/utils/tests/test_fixes.py
index 8a55f74a4f6e3..92f954439f797 100644
--- a/sklearn/utils/tests/test_fixes.py
+++ b/sklearn/utils/tests/test_fixes.py
@@ -14,6 +14,7 @@
 from sklearn.utils.fixes import divide
 from sklearn.utils.fixes import MaskedArray
+from sklearn.utils.fixes import nanmedian
 from sklearn.utils.fixes import nanpercentile
 
 
@@ -31,6 +32,22 @@ def test_masked_array_obj_dtype_pickleable():
     assert_array_equal(marr.mask, marr_pickled.mask)
 
 
+@pytest.mark.parametrize(
+    "axis, expected_median",
+    [(None, 4.0),
+     (0, np.array([1., 3.5, 3.5, 4., 7., np.nan])),
+     (1, np.array([1., 6.]))]
+)
+def test_nanmedian(axis, expected_median):
+    X = np.array([[1, 1, 1, 2, np.nan, np.nan],
+                  [np.nan, 6, 6, 6, 7, np.nan]])
+    median = nanmedian(X, axis=axis)
+    if axis is None:
+        assert median == pytest.approx(expected_median)
+    else:
+        assert_allclose(median, expected_median)
+
+
 @pytest.mark.parametrize(
     "a, q, expected_percentile",
     [(np.array([1, 2, 3, np.nan]), [0, 50, 100], np.array([1., 2., 3.])),
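The `nanmedian` fallback added to `fixes.py` only runs on numpy < 1.9; on newer versions it is numpy's own function. A quick standalone check that the backported logic agrees with `np.nanmedian` (illustrative only; `nanmedian_backport` is a local copy of the fallback, and running it requires a numpy that already ships `nanmedian`):

```python
import numpy as np

def nanmedian_backport(a, axis=None):
    # Mirrors the numpy < 1.9 code path from sklearn/utils/fixes.py above.
    if axis is None:
        data = a.reshape(-1)
        return np.median(np.compress(~np.isnan(data), data))
    data = a.T if not axis else a
    return np.array([np.median(np.compress(~np.isnan(row), row))
                     for row in data])

X = np.array([[1., 1., 1., 2., np.nan, np.nan],
              [np.nan, 6., 6., 6., 7., np.nan]])
# Flattened: median of [1, 1, 1, 2, 6, 6, 6, 7] is 4.0.
assert nanmedian_backport(X) == np.nanmedian(X) == 4.0
np.testing.assert_allclose(nanmedian_backport(X, axis=1),
                           np.nanmedian(X, axis=1))  # [1., 6.]
```
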