From fb354422cfc41b2c6a190bc2cb8490146a3fad02 Mon Sep 17 00:00:00 2001 From: Lucija Gregov Date: Sun, 22 Apr 2018 12:53:53 +0100 Subject: [PATCH 1/6] Passing NaN values through MaxAbsScaler --- sklearn/preprocessing/data.py | 12 +++-- sklearn/preprocessing/tests/test_common.py | 56 ++++++++++++++++++++++ sklearn/utils/estimator_checks.py | 2 +- 3 files changed, 65 insertions(+), 5 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 6cb69d216266b..9890178c7db3f 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -842,7 +842,8 @@ def partial_fit(self, X, y=None): Ignored """ X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy, - estimator=self, dtype=FLOAT_DTYPES) + estimator=self, dtype=FLOAT_DTYPES, + force_all_finite='allow-nan') if sparse.issparse(X): mins, maxs = min_max_axis(X, axis=0) @@ -872,7 +873,8 @@ def transform(self, X): """ check_is_fitted(self, 'scale_') X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy, - estimator=self, dtype=FLOAT_DTYPES) + estimator=self, dtype=FLOAT_DTYPES, + force_all_finite='allow-nan') if sparse.issparse(X): inplace_column_scale(X, 1.0 / self.scale_) @@ -890,7 +892,8 @@ def inverse_transform(self, X): """ check_is_fitted(self, 'scale_') X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy, - estimator=self, dtype=FLOAT_DTYPES) + estimator=self, dtype=FLOAT_DTYPES, + force_all_finite='allow-nan') if sparse.issparse(X): inplace_column_scale(X, self.scale_) @@ -936,7 +939,8 @@ def maxabs_scale(X, axis=0, copy=True): # If copy is required, it will be done inside the scaler object. 
X = check_array(X, accept_sparse=('csr', 'csc'), copy=False, - ensure_2d=False, dtype=FLOAT_DTYPES) + ensure_2d=False, dtype=FLOAT_DTYPES, + force_all_finite='allow-nan') original_ndim = X.ndim if original_ndim == 1: diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index f3b6707602f09..35d8ca5983198 100644 --- a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -1,10 +1,12 @@ import pytest import numpy as np +from scipy.sparse import csr_matrix from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.preprocessing import QuantileTransformer from sklearn.preprocessing import MinMaxScaler +from sklearn.preprocessing import MaxAbsScaler from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_allclose @@ -14,6 +16,7 @@ @pytest.mark.parametrize( "est", [MinMaxScaler(), + MaxAbsScaler(), QuantileTransformer(n_quantiles=10, random_state=42)] ) def test_missing_value_handling(est): @@ -28,6 +31,58 @@ def test_missing_value_handling(est): assert not np.all(np.isnan(X_train), axis=0).any() assert np.any(np.isnan(X_train), axis=0).all() assert np.any(np.isnan(X_test), axis=0).all() + + X_test[:, 0] = np.nan # make sure this boundary case is tested + + Xt = est.fit(X_train).transform(X_test) + # missing values should still be missing, and only them + assert_array_equal(np.isnan(Xt), np.isnan(X_test)) + + # check that the inverse transform keeps NaN + Xt_inv = est.inverse_transform(Xt) + assert_array_equal(np.isnan(Xt_inv), np.isnan(X_test)) + # FIXME: we can introduce equal_nan=True in recent versions of numpy. + # For the moment we just check that non-NaN values are almost equal.
+ assert_allclose(Xt_inv[~np.isnan(Xt_inv)], X_test[~np.isnan(X_test)]) + + for i in range(X.shape[1]): + # train only on non-NaN + est.fit(X_train[:, [i]][~np.isnan(X_train[:, i])]) + # check transforming with NaN works even when training without NaN + Xt_col = est.transform(X_test[:, [i]]) + assert_array_equal(Xt_col, Xt[:, [i]]) + # check non-NaN is handled as before - the 1st column is all nan + if not np.isnan(X_test[:, i]).all(): + Xt_col_nonan = est.transform( + X_test[:, [i]][~np.isnan(X_test[:, i])]) + assert_array_equal(Xt_col_nonan, + Xt_col[~np.isnan(Xt_col.squeeze())]) + +@pytest.mark.parametrize( + "est", + [MaxAbsScaler(), + QuantileTransformer(n_quantiles=10, random_state=42)] +) +@pytest.mark.parametrize( + "sparse_format_func", + [csc_matrix, + csr_matrix] +) +def test_missing_value_handling_sparse(est, sparse_format_func): + # check that the preprocessing method let pass nan + rng = np.random.RandomState(42) + X = iris.data.copy() + n_missing = 50 + X[rng.randint(X.shape[0], size=n_missing), + rng.randint(X.shape[1], size=n_missing)] = np.nan + X_train, X_test = train_test_split(X, random_state=1) + # sanity check + assert not np.all(np.isnan(X_train), axis=0).any() + assert np.any(np.isnan(X_train), axis=0).all() + assert np.any(np.isnan(X_test), axis=0).all() + + X = sparse_format_func(X) + X_test[:, 0] = np.nan # make sure this boundary case is tested Xt = est.fit(X_train).transform(X_test) @@ -53,3 +108,4 @@ def test_missing_value_handling(est): X_test[:, [i]][~np.isnan(X_test[:, i])]) assert_array_equal(Xt_col_nonan, Xt_col[~np.isnan(Xt_col.squeeze())]) + diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 7de1175618f4c..6fa6e6499b15a 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -74,7 +74,7 @@ 'RandomForestRegressor', 'Ridge', 'RidgeCV'] ALLOW_NAN = ['Imputer', 'SimpleImputer', 'MICEImputer', - 'MinMaxScaler', 'QuantileTransformer'] + 'MinMaxScaler', 
'MaxAbsScaler', 'QuantileTransformer'] def _yield_non_meta_checks(name, estimator): From fbd4049da1dc36641652e8b46e4e9ba6536aa5fb Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sun, 17 Jun 2018 00:24:20 +0200 Subject: [PATCH 2/6] TST check maxabs_scale and MaxAbsScaler for ignoring NaN --- doc/whats_new/v0.20.rst | 5 +++++ sklearn/preprocessing/data.py | 10 ++++++++-- sklearn/preprocessing/tests/test_common.py | 7 +++++-- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 87569d8649d86..34ac597a3023e 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -230,6 +230,11 @@ Preprocessing :issue:`10404` and :issue:`11243` by :user:`Lucija Gregov ` and :user:`Guillaume Lemaitre `. +- :class:`preprocessing.MaxAbsScaler` and :func:`preprocessingmaxabs_scale` + handles and ignores NaN values. + :issue:`11011` by `Lucija Gregov ` and + :user:`Guillaume Lemaitre ` + Model evaluation and meta-estimators - A scorer based on :func:`metrics.brier_score_loss` is also available. diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 012269ed1344f..2aaeee0a56f30 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -800,6 +800,9 @@ class MaxAbsScaler(BaseEstimator, TransformerMixin): Notes ----- + NaNs are treated as missing values: disregarded in fit, and maintained in + transform. + For a comparison of the different scalers, transformers, and normalizers, see :ref:`examples/preprocessing/plot_all_scaling.py `. 
@@ -855,10 +858,10 @@ def partial_fit(self, X, y=None): force_all_finite='allow-nan') if sparse.issparse(X): - mins, maxs = min_max_axis(X, axis=0) + mins, maxs = min_max_axis(X, axis=0, ignore_nan=True) max_abs = np.maximum(np.abs(mins), np.abs(maxs)) else: - max_abs = np.abs(X).max(axis=0) + max_abs = np.nanmax(np.abs(X), axis=0) # First pass if not hasattr(self, 'n_samples_seen_'): @@ -940,6 +943,9 @@ def maxabs_scale(X, axis=0, copy=True): Notes ----- + NaNs are treated as missing values: disregarded to compute the statistics, + and maintained during the data transformation. + For a comparison of the different scalers, transformers, and normalizers, see :ref:`examples/preprocessing/plot_all_scaling.py `. diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index 1476b5d673148..252873cb17de5 100644 --- a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -8,9 +8,11 @@ from sklearn.base import clone +from sklearn.preprocessing import maxabs_scale from sklearn.preprocessing import minmax_scale from sklearn.preprocessing import quantile_transform +from sklearn.preprocessing import MaxAbsScaler from sklearn.preprocessing import MinMaxScaler from sklearn.preprocessing import QuantileTransformer @@ -27,7 +29,8 @@ def _get_valid_samples_by_column(X, col): @pytest.mark.parametrize( "est, func, support_sparse", - [(MinMaxScaler(), minmax_scale, False), + [(MaxAbsScaler(), maxabs_scale, True), + (MinMaxScaler(), minmax_scale, False), (QuantileTransformer(n_quantiles=10), quantile_transform, True)] ) def test_missing_value_handling(est, func, support_sparse): @@ -89,4 +92,4 @@ def test_missing_value_handling(est, func, support_sparse): .transform(sparse_constructor(X_test))) assert_allclose(Xt_sparse.A, Xt_dense) Xt_inv_sparse = est_sparse.inverse_transform(Xt_sparse) - assert_allclose(Xt_inv_sparse.A, Xt_inv_dense) \ No newline at end of file + assert_allclose(Xt_inv_sparse.A, 
Xt_inv_dense) From 7909057a3b3852ecd2707714625cb8096bc92c9a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sun, 17 Jun 2018 00:29:38 +0200 Subject: [PATCH 3/6] fix spelling --- doc/whats_new/v0.20.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 34ac597a3023e..dca639472612f 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -230,7 +230,7 @@ Preprocessing :issue:`10404` and :issue:`11243` by :user:`Lucija Gregov ` and :user:`Guillaume Lemaitre `. -- :class:`preprocessing.MaxAbsScaler` and :func:`preprocessingmaxabs_scale` +- :class:`preprocessing.MaxAbsScaler` and :func:`preprocessing.maxabs_scale` handles and ignores NaN values. :issue:`11011` by `Lucija Gregov ` and :user:`Guillaume Lemaitre ` From 49e23e4781f284d031bc0ce7f1bba8198f818cf6 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sun, 17 Jun 2018 12:38:40 +0200 Subject: [PATCH 4/6] TST check that we do not raise warnings --- sklearn/preprocessing/data.py | 3 +- sklearn/preprocessing/tests/test_common.py | 35 +++++++++++++++------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 2aaeee0a56f30..cd76005df0298 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2380,7 +2380,8 @@ def _transform_col(self, X_col, quantiles, inverse): lower_bound_y = quantiles[0] upper_bound_y = quantiles[-1] # for inverse transform, match a uniform PDF - X_col = output_distribution.cdf(X_col) + with np.errstate(invalid='ignore'): # hide NaN comparison warnings + X_col = output_distribution.cdf(X_col) # find index for lower and higher bounds with np.errstate(invalid='ignore'): # hide NaN comparison warnings lower_bounds_idx = (X_col - BOUNDS_THRESHOLD < diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index 252873cb17de5..9d65782b57458 100644 --- 
a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -47,12 +47,17 @@ def test_missing_value_handling(est, func, support_sparse): assert np.any(np.isnan(X_test), axis=0).all() X_test[:, 0] = np.nan # make sure this boundary case is tested - Xt = est.fit(X_train).transform(X_test) + with pytest.warns(None) as records: + Xt = est.fit(X_train).transform(X_test) + # ensure no warnings are raised + assert len(records) == 0 # missing values should still be missing, and only them assert_array_equal(np.isnan(Xt), np.isnan(X_test)) # check that the function leads to the same results as the class - Xt_class = est.transform(X_train) + with pytest.warns(None) as records: + Xt_class = est.transform(X_train) + assert len(records) == 0 Xt_func = func(X_train, **est.get_params()) assert_array_equal(np.isnan(Xt_func), np.isnan(Xt_class)) assert_allclose(Xt_func[~np.isnan(Xt_func)], Xt_class[~np.isnan(Xt_class)]) @@ -68,7 +73,9 @@ def test_missing_value_handling(est, func, support_sparse): # train only on non-NaN est.fit(_get_valid_samples_by_column(X_train, i)) # check transforming with NaN works even when training without NaN - Xt_col = est.transform(X_test[:, [i]]) + with pytest.warns(None) as records: + Xt_col = est.transform(X_test[:, [i]]) + assert len(records) == 0 assert_array_equal(Xt_col, Xt[:, [i]]) # check non-NaN is handled as before - the 1st column is all nan if not np.isnan(X_test[:, i]).all(): @@ -81,15 +88,23 @@ def test_missing_value_handling(est, func, support_sparse): est_dense = clone(est) est_sparse = clone(est) - Xt_dense = est_dense.fit(X_train).transform(X_test) - Xt_inv_dense = est_dense.inverse_transform(Xt_dense) + with pytest.warns(None) as records: + Xt_dense = est_dense.fit(X_train).transform(X_test) + Xt_inv_dense = est_dense.inverse_transform(Xt_dense) + assert len(records) == 0 for sparse_constructor in (sparse.csr_matrix, sparse.csc_matrix, sparse.bsr_matrix, sparse.coo_matrix, sparse.dia_matrix, 
sparse.dok_matrix, sparse.lil_matrix): # check that the dense and sparse inputs lead to the same results - Xt_sparse = (est_sparse.fit(sparse_constructor(X_train)) - .transform(sparse_constructor(X_test))) - assert_allclose(Xt_sparse.A, Xt_dense) - Xt_inv_sparse = est_sparse.inverse_transform(Xt_sparse) - assert_allclose(Xt_inv_sparse.A, Xt_inv_dense) + # precompute the matrix to avoid catching side warnings + X_train_sp = sparse_constructor(X_train) + X_test_sp = sparse_constructor(X_test) + with pytest.warns(None) as records: + Xt_sp = est_sparse.fit(X_train_sp).transform(X_test_sp) + assert len(records) == 0 + assert_allclose(Xt_sp.A, Xt_dense) + with pytest.warns(None) as records: + Xt_inv_sp = est_sparse.inverse_transform(Xt_sp) + assert len(records) == 0 + assert_allclose(Xt_inv_sp.A, Xt_inv_dense) From 2497932eb4c0b700bda0c474eea273e5daa0df8c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 21 Jun 2018 17:07:21 +0200 Subject: [PATCH 5/6] FIX: avoid warning raising with some nan division --- sklearn/utils/extmath.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index f049430d23a65..218733145a0de 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -710,7 +710,7 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count): new_unnormalized_variance = np.nanvar(X, axis=0) * new_sample_count last_unnormalized_variance = last_variance * last_sample_count - with np.errstate(divide='ignore'): + with np.errstate(divide='ignore', invalid='ignore'): last_over_new_count = last_sample_count / new_sample_count updated_unnormalized_variance = ( last_unnormalized_variance + new_unnormalized_variance + From 9668eb36985b04cfe4140ee2b4fba064386d622e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 22 Jun 2018 12:28:24 +0200 Subject: [PATCH 6/6] FIX ignore all NaN with nanmin --- sklearn/preprocessing/data.py | 10 +++++++--- 1 file changed, 7 
insertions(+), 3 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 753ab5f5181dd..7c014a07481be 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2568,9 +2568,13 @@ def _check_input(self, X, check_positive=False, check_shape=False, X = check_array(X, ensure_2d=True, dtype=FLOAT_DTYPES, copy=self.copy, force_all_finite='allow-nan') - if check_positive and self.method == 'box-cox' and np.nanmin(X) <= 0: - raise ValueError("The Box-Cox transformation can only be applied " - "to strictly positive data") + with np.warnings.catch_warnings(): + np.warnings.filterwarnings( + 'ignore', r'All-NaN (slice|axis) encountered') + if (check_positive and self.method == 'box-cox' and + np.nanmin(X) <= 0): + raise ValueError("The Box-Cox transformation can only be " + "applied to strictly positive data") if check_shape and not X.shape[1] == len(self.lambdas_): raise ValueError("Input data has a different number of features "