From cc3bb96211acc5b1bde14a54d0c25ee85ed85f9c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 9 Jan 2018 17:19:47 +0100 Subject: [PATCH 01/34] EHN handle NaN value in QuantileTransformer --- sklearn/preprocessing/data.py | 24 +++++++++++++++++++++--- sklearn/preprocessing/tests/test_data.py | 22 ++++++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index edadaa784ff1d..3274b013531fa 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2217,7 +2217,7 @@ def _dense_fit(self, X, random_state): size=self.subsample, replace=False) col = col.take(subsample_idx, mode='clip') - self.quantiles_.append(np.percentile(col, references)) + self.quantiles_.append(np.nanpercentile(col, references)) self.quantiles_ = np.transpose(self.quantiles_) def _sparse_fit(self, X, random_state): @@ -2262,7 +2262,7 @@ def _sparse_fit(self, X, random_state): self.quantiles_.append([0] * len(references)) else: self.quantiles_.append( - np.percentile(column_data, references)) + np.nanpercentile(column_data, references)) self.quantiles_ = np.transpose(self.quantiles_) def fit(self, X, y=None): @@ -2334,6 +2334,9 @@ def _transform_col(self, X_col, quantiles, inverse): # for inverse transform, match a uniform PDF X_col = output_distribution.cdf(X_col) # find index for lower and higher bounds + # FIXME: NaN will raise a RuntimeWarning in the following + # comparison. Comparison with NaN will return False which is also the + # behavior that we want. lower_bounds_idx = (X_col - BOUNDS_THRESHOLD < lower_bound_x) upper_bounds_idx = (X_col + BOUNDS_THRESHOLD > @@ -2369,10 +2372,25 @@ def _transform_col(self, X_col, quantiles, inverse): return X_col + @staticmethod + def _assert_finite_or_nan(X): + """Check that X contain finite or NaN values.""" + X = np.asanyarray(X) + # First try an O(n) time, O(1) space solution for the common case that + # everything is finite; fall back to O(n) space np.isfinite to prevent + # false positives from overflow in sum method. + if (X.dtype.char in np.typecodes['AllFloat'] + and not np.isfinite(X[~np.isnan(X)].sum()) + and not np.isfinite(X[~np.isnan(X)]).all()): + raise ValueError("Input contains infinity" + " or a value too large for %r." % X.dtype) + def _check_inputs(self, X, accept_sparse_negative=False): """Check inputs before fit and transform""" X = check_array(X, accept_sparse='csc', copy=self.copy, - dtype=[np.float64, np.float32]) + dtype=[np.float64, np.float32], force_all_finite=False) + # we accept nan values but not infinite values. + self._assert_finite_or_nan(X.data if sparse.issparse(X) else X) # we only accept positive sparse matrix when ignore_implicit_zeros is # false and that we call fit or transform. 
if (not accept_sparse_negative and not self.ignore_implicit_zeros and diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index abb17142efc77..fd0655b2035d4 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -944,6 +944,28 @@ def test_quantile_transform_check_error(): transformer.transform, 10) +def test_quantile_transform_nan(): + X = np.array([[0, 1], + [0, 0], + [np.nan, 2], + [0, np.nan], + [0, 1]]) + X_sparse = sparse.csc_matrix(X) + + transformer = QuantileTransformer(n_quantiles=5) + X_expected = np.array([[0, 0.5], + [0, 0], + [np.nan, 1], + [0, np.nan], + [0, 0.5]]) + + X_trans = transformer.fit_transform(X) + assert_almost_equal(X_expected, X_trans) + + X_trans = transformer.fit_transform(X_sparse) + assert_almost_equal(X_expected, X_trans.A) + + def test_quantile_transform_sparse_ignore_zeros(): X = np.array([[0, 1], [0, 0], From 76123c89e37f1face776e1667b824497feb294e0 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 9 Jan 2018 17:23:14 +0100 Subject: [PATCH 02/34] DOC add whats new entry --- doc/whats_new/v0.20.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index a2d5e25c6a211..dcce83845d49c 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -72,6 +72,8 @@ Preprocessing :issue:`10210` by :user:`Eric Chang ` and :user:`Maniteja Nandana `. +- :class:`preprocessing.QuantileTransformer` handles and ignores NaN values. + :issue:`10404` by :user:`Guillaume Lemaitre `. Model evaluation From 530c7bf8ce1b15ab500e46d186d22ee25a9acd5e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 9 Jan 2018 18:17:47 +0100 Subject: [PATCH 03/34] TST relax inf/nan common test --- sklearn/utils/estimator_checks.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 398c12cbddb42..66def1d6dc583 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -70,6 +70,7 @@ 'OrthogonalMatchingPursuit', 'PLSCanonical', 'PLSRegression', 'RANSACRegressor', 'RadiusNeighborsRegressor', 'RandomForestRegressor', 'Ridge', 'RidgeCV'] +ACCEPT_NAN = ['QuantileTransformer'] def _yield_non_meta_checks(name, estimator): @@ -971,6 +972,8 @@ def check_estimators_nan_inf(name, estimator_orig): error_string_transform = ("Estimator doesn't check for NaN and inf in" " transform.") for X_train in [X_train_nan, X_train_inf]: + if np.any(np.isnan(X_train)) and name in ACCEPT_NAN: + continue # catch deprecation warnings with ignore_warnings(category=(DeprecationWarning, FutureWarning)): estimator = clone(estimator_orig) From 1f079631996acf458dd5eb762ca90639a10920c3 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 10 Jan 2018 16:21:11 +0100 Subject: [PATCH 04/34] FIX silent warning and raise an error for numpy version --- sklearn/preprocessing/data.py | 52 +++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 3274b013531fa..987525731f4d8 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -24,7 +24,7 @@ from ..utils import check_array from ..utils.extmath import row_norms from ..utils.extmath import _incremental_mean_and_var -from ..utils.fixes import _argmax +from ..utils.fixes import _argmax, _parse_version from ..utils.sparsefuncs_fast import (inplace_csr_row_normalize_l1, 
inplace_csr_row_normalize_l2) from ..utils.sparsefuncs import (inplace_column_scale, @@ -2217,7 +2217,7 @@ def _dense_fit(self, X, random_state): size=self.subsample, replace=False) col = col.take(subsample_idx, mode='clip') - self.quantiles_.append(np.nanpercentile(col, references)) + self.quantiles_.append(self._percentile_func(col, references)) self.quantiles_ = np.transpose(self.quantiles_) def _sparse_fit(self, X, random_state): @@ -2262,7 +2262,7 @@ def _sparse_fit(self, X, random_state): self.quantiles_.append([0] * len(references)) else: self.quantiles_.append( - np.nanpercentile(column_data, references)) + self._percentile_func(column_data, references)) self.quantiles_ = np.transpose(self.quantiles_) def fit(self, X, y=None): @@ -2334,13 +2334,13 @@ def _transform_col(self, X_col, quantiles, inverse): # for inverse transform, match a uniform PDF X_col = output_distribution.cdf(X_col) # find index for lower and higher bounds - # FIXME: NaN will raise a RuntimeWarning in the following - # comparison. Comparison with NaN will return False which is also the - # behavior that we want. - lower_bounds_idx = (X_col - BOUNDS_THRESHOLD < - lower_bound_x) - upper_bounds_idx = (X_col + BOUNDS_THRESHOLD > - upper_bound_x) + # comparison with NaN will raise a warning which we make silent + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lower_bounds_idx = (X_col - BOUNDS_THRESHOLD < + lower_bound_x) + upper_bounds_idx = (X_col + BOUNDS_THRESHOLD > + upper_bound_x) if not inverse: # Interpolate in one direction and in the other and take the @@ -2348,7 +2348,7 @@ def _transform_col(self, X_col, quantiles, inverse): # and hence repeated quantiles # # If we don't do this, only one extreme of the duplicated is - # used (the upper when we do assending, and the + # used (the upper when we do ascending, and the # lower for descending). We take the mean of these two X_col = .5 * (np.interp(X_col, quantiles, self.references_) - np.interp(-X_col, -quantiles[::-1], @@ -2360,7 +2360,10 @@ def _transform_col(self, X_col, quantiles, inverse): X_col[lower_bounds_idx] = lower_bound_y # for forward transform, match the output PDF if not inverse: - X_col = output_distribution.ppf(X_col) + # comparison with NaN will raise a warning which we make silent + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + X_col = output_distribution.ppf(X_col) # find the value to clip the data to avoid mapping to # infinity. Clip such that the inverse transform will be # consistent @@ -2372,8 +2375,7 @@ def _transform_col(self, X_col, quantiles, inverse): return X_col - @staticmethod - def _assert_finite_or_nan(X): + def _assert_finite_or_nan(self, X): """Check that X contain finite or NaN values.""" X = np.asanyarray(X) # First try an O(n) time, O(1) space solution for the common case that @@ -2384,6 +2386,17 @@ def _assert_finite_or_nan(X): and not np.isfinite(X[~np.isnan(X)]).all()): raise ValueError("Input contains infinity" " or a value too large for %r." % X.dtype) + if np.any(np.isnan(X)): + np_version = _parse_version(np.__version__) + if np_version >= (1, 9): + self._percentile_func = np.nanpercentile + else: + raise NotImplementedError( + 'QuantileTransformer does not handle NaN value with' + ' NumPy {}. 
Please upgrade NumPy to 1.9.'.format( + np_version)) + else: + self._percentile_func = np.percentile def _check_inputs(self, X, accept_sparse_negative=False): """Check inputs before fit and transform""" @@ -2393,10 +2406,13 @@ def _check_inputs(self, X, accept_sparse_negative=False): self._assert_finite_or_nan(X.data if sparse.issparse(X) else X) # we only accept positive sparse matrix when ignore_implicit_zeros is # false and that we call fit or transform. - if (not accept_sparse_negative and not self.ignore_implicit_zeros and - (sparse.issparse(X) and np.any(X.data < 0))): - raise ValueError('QuantileTransformer only accepts non-negative' - ' sparse matrices.') + # comparison with NaN will raise a warning which we make silent + with warnings.catch_warnings(): + warnings.filterwarnings('ignore') + if (not accept_sparse_negative and not self.ignore_implicit_zeros + and (sparse.issparse(X) and np.any(X.data < 0))): + raise ValueError('QuantileTransformer only accepts' + ' non-negative sparse matrices.') # check the output PDF if self.output_distribution not in ('normal', 'uniform'): From 91c947e81ec5bcd4ca572a81c3ef014e70659a7f Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 10 Jan 2018 16:58:33 +0100 Subject: [PATCH 05/34] TST ensure that test raise error with older numpy --- sklearn/preprocessing/data.py | 10 +++++----- sklearn/preprocessing/tests/test_data.py | 19 ++++++++++++++++++- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 987525731f4d8..730b1a692cd21 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -13,6 +13,7 @@ import numbers import warnings from itertools import combinations_with_replacement as combinations_w_r +from distutils.version import LooseVersion import numpy as np from scipy import sparse @@ -24,7 +25,7 @@ from ..utils import check_array from ..utils.extmath import row_norms from ..utils.extmath import _incremental_mean_and_var -from ..utils.fixes import _argmax, _parse_version +from ..utils.fixes import _argmax from ..utils.sparsefuncs_fast import (inplace_csr_row_normalize_l1, inplace_csr_row_normalize_l2) from ..utils.sparsefuncs import (inplace_column_scale, @@ -2387,14 +2388,13 @@ def _assert_finite_or_nan(self, X): raise ValueError("Input contains infinity" " or a value too large for %r." % X.dtype) if np.any(np.isnan(X)): - np_version = _parse_version(np.__version__) - if np_version >= (1, 9): + if LooseVersion(np.__version__) >= '1.9': self._percentile_func = np.nanpercentile else: raise NotImplementedError( 'QuantileTransformer does not handle NaN value with' - ' NumPy {}. Please upgrade NumPy to 1.9.'.format( - np_version)) + ' NumPy {}. Please upgrade to NumPy 1.9. 
or higher'.format( + np.__version__)) else: self._percentile_func = np.percentile diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index fd0655b2035d4..19603fc97f501 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -881,7 +881,7 @@ def test_quantile_transform_iris(): assert_array_almost_equal(X_sparse.A, X_sparse_tran_inv.A) -def test_quantile_transform_check_error(): +def test_quantile_transform_check_error(mocker): X = np.transpose([[0, 25, 50, 0, 0, 0, 75, 0, 0, 100], [2, 4, 0, 0, 6, 8, 0, 10, 0, 0], [0, 0, 2.6, 4.1, 0, 0, 2.3, 0, 9.5, 0.1]]) @@ -943,8 +943,25 @@ def test_quantile_transform_check_error(): 'Expected 2D array, got scalar array instead', transformer.transform, 10) + # check that an error is raised when NumPy is < 1.9 + mocker.patch('sklearn.preprocessing.data.LooseVersion', + return_value=LooseVersion('1.8.2')) + X_nan = np.array([[0, 1], + [0, 0], + [np.nan, np.nan], + [0, 2], + [0, 1]]) + transformer = QuantileTransformer() + assert_raises_regex(NotImplementedError, "Please upgrade to NumPy", + transformer.fit_transform, X_nan) + def test_quantile_transform_nan(): + # skip if numpy < 1.9 + if LooseVersion(np.__version__) < LooseVersion('1.9'): + raise SkipTest( + 'NumPy < 1.9 do not implement nanpercentile. Skipping test.') + X = np.array([[0, 1], [0, 0], [np.nan, 2], From 1c406c0f84ee92835bbbaed32b8b0bb6544d1310 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 10 Jan 2018 18:04:20 +0100 Subject: [PATCH 06/34] TST remove mocking --- sklearn/preprocessing/tests/test_data.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 19603fc97f501..b18d55f8e0204 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -881,7 +881,7 @@ def test_quantile_transform_iris(): assert_array_almost_equal(X_sparse.A, X_sparse_tran_inv.A) -def test_quantile_transform_check_error(mocker): +def test_quantile_transform_check_error(): X = np.transpose([[0, 25, 50, 0, 0, 0, 75, 0, 0, 100], [2, 4, 0, 0, 6, 8, 0, 10, 0, 0], [0, 0, 2.6, 4.1, 0, 0, 2.3, 0, 9.5, 0.1]]) @@ -944,16 +944,15 @@ def test_quantile_transform_check_error(mocker): transformer.transform, 10) # check that an error is raised when NumPy is < 1.9 - mocker.patch('sklearn.preprocessing.data.LooseVersion', - return_value=LooseVersion('1.8.2')) - X_nan = np.array([[0, 1], - [0, 0], - [np.nan, np.nan], - [0, 2], - [0, 1]]) - transformer = QuantileTransformer() - assert_raises_regex(NotImplementedError, "Please upgrade to NumPy", - transformer.fit_transform, X_nan) + if LooseVersion(np.__version__) < LooseVersion('1.9'): + X_nan = np.array([[0, 1], + [0, 0], + [np.nan, np.nan], + [0, 2], + [0, 1]]) + transformer = QuantileTransformer() + assert_raises_regex(NotImplementedError, "Please upgrade to NumPy", + transformer.fit_transform, X_nan) def test_quantile_transform_nan(): From ecc5048d2b31b37586ae27cd20e48610b8c9c699 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 11 Jan 2018 15:49:33 +0100 Subject: [PATCH 07/34] EHN accept integer as missing values --- sklearn/preprocessing/data.py | 25 ++++++++++++++++++++++-- sklearn/preprocessing/tests/test_data.py | 20 ++++++++++++------- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 730b1a692cd21..da761ce88f317 
100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -1,3 +1,5 @@ +# coding: utf-8 + # Authors: Alexandre Gramfort # Mathieu Blondel # Olivier Grisel @@ -34,6 +36,7 @@ from ..utils.validation import (check_is_fitted, check_random_state, FLOAT_DTYPES) from .label import LabelEncoder +from .imputation import _get_mask BOUNDS_THRESHOLD = 1e-7 @@ -2150,6 +2153,13 @@ class QuantileTransformer(BaseEstimator, TransformerMixin): Set to False to perform inplace transformation and avoid a copy (if the input is already a numpy array). + missing_values : int or "NaN", optional, (default="NaN) + The placeholder for the missing values. All occurrences of + missing_values will be preserved. For missing values encoded as np.nan, + use the string value “NaN”. + + .. versionadded: 0.20 + Attributes ---------- quantiles_ : ndarray, shape (n_quantiles, n_features) @@ -2187,13 +2197,14 @@ class QuantileTransformer(BaseEstimator, TransformerMixin): def __init__(self, n_quantiles=1000, output_distribution='uniform', ignore_implicit_zeros=False, subsample=int(1e5), - random_state=None, copy=True): + random_state=None, copy=True, missing_values='NaN'): self.n_quantiles = n_quantiles self.output_distribution = output_distribution self.ignore_implicit_zeros = ignore_implicit_zeros self.subsample = subsample self.random_state = random_state self.copy = copy + self.missing_values = missing_values def _dense_fit(self, X, random_state): """Compute percentiles for dense matrices. @@ -2305,8 +2316,10 @@ def fit(self, X, y=None): self.references_ = np.linspace(0, 1, self.n_quantiles, endpoint=True) if sparse.issparse(X): + X.data[self._mask_missing_values] = np.nan self._sparse_fit(X, rng) else: + X[self._mask_missing_values] = np.nan self._dense_fit(X, rng) return self @@ -2387,7 +2400,7 @@ def _assert_finite_or_nan(self, X): and not np.isfinite(X[~np.isnan(X)]).all()): raise ValueError("Input contains infinity" " or a value too large for %r." % X.dtype) - if np.any(np.isnan(X)): + if np.count_nonzero(self._mask_missing_values): if LooseVersion(np.__version__) >= '1.9': self._percentile_func = np.nanpercentile else: @@ -2400,6 +2413,10 @@ def _assert_finite_or_nan(self, X): def _check_inputs(self, X, accept_sparse_negative=False): """Check inputs before fit and transform""" + if sparse.issparse(X): + self._mask_missing_values = _get_mask(X.data, self.missing_values) + else: + self._mask_missing_values = _get_mask(X, self.missing_values) X = check_array(X, accept_sparse='csc', copy=self.copy, dtype=[np.float64, np.float32], force_all_finite=False) # we accept nan values but not infinite values. 
@@ -2451,17 +2468,21 @@ def _transform(self, X, inverse=False): """ if sparse.issparse(X): + X.data[self._mask_missing_values] = np.nan for feature_idx in range(X.shape[1]): column_slice = slice(X.indptr[feature_idx], X.indptr[feature_idx + 1]) X.data[column_slice] = self._transform_col( X.data[column_slice], self.quantiles_[:, feature_idx], inverse) + X.data[self._mask_missing_values] = self.missing_values else: + X[self._mask_missing_values] = np.nan for feature_idx in range(X.shape[1]): X[:, feature_idx] = self._transform_col( X[:, feature_idx], self.quantiles_[:, feature_idx], inverse) + X[self._mask_missing_values] = self.missing_values return X diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index b18d55f8e0204..6c27deaf68328 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -11,6 +11,7 @@ import numpy.linalg as la from scipy import sparse, stats from distutils.version import LooseVersion +import pytest from sklearn.utils import gen_batches @@ -955,7 +956,11 @@ def test_quantile_transform_check_error(): transformer.fit_transform, X_nan) -def test_quantile_transform_nan(): +@pytest.mark.parametrize( + "missing_values, dtype", + [(np.nan, np.float64), + (100, np.int64)]) +def test_quantile_transform_missing_values(missing_values, dtype): # skip if numpy < 1.9 if LooseVersion(np.__version__) < LooseVersion('1.9'): raise SkipTest( @@ -963,16 +968,17 @@ def test_quantile_transform_nan(): X = np.array([[0, 1], [0, 0], - [np.nan, 2], - [0, np.nan], - [0, 1]]) + [missing_values, 2], + [0, missing_values], + [0, 1]], dtype=dtype) X_sparse = sparse.csc_matrix(X) - transformer = QuantileTransformer(n_quantiles=5) + transformer = QuantileTransformer(n_quantiles=5, + missing_values=missing_values) X_expected = np.array([[0, 0.5], [0, 0], - [np.nan, 1], - [0, np.nan], + [missing_values, 1], + [0, missing_values], [0, 0.5]]) X_trans = transformer.fit_transform(X) From 965811faddb45ea560c224cac03e9b365857b8eb Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 12 Jan 2018 12:18:29 +0100 Subject: [PATCH 08/34] address joel comments --- sklearn/preprocessing/data.py | 76 +++++++++++++++--------- sklearn/preprocessing/tests/test_data.py | 16 ----- sklearn/utils/fixes.py | 7 +++ sklearn/utils/tests/test_fixes.py | 10 ++++ 4 files changed, 66 insertions(+), 43 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index da761ce88f317..d5b2ed0133d2c 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -15,7 +15,6 @@ import numbers import warnings from itertools import combinations_with_replacement as combinations_w_r -from distutils.version import LooseVersion import numpy as np from scipy import sparse @@ -27,7 +26,7 @@ from ..utils import check_array from ..utils.extmath import row_norms from ..utils.extmath import _incremental_mean_and_var -from ..utils.fixes import _argmax +from ..utils.fixes import _argmax, nanpercentile from ..utils.sparsefuncs_fast import (inplace_csr_row_normalize_l1, inplace_csr_row_normalize_l2) from ..utils.sparsefuncs import (inplace_column_scale, @@ -36,8 +35,6 @@ from ..utils.validation import (check_is_fitted, check_random_state, FLOAT_DTYPES) from .label import LabelEncoder -from .imputation import _get_mask - BOUNDS_THRESHOLD = 1e-7 @@ -2206,6 +2203,17 @@ def __init__(self, n_quantiles=1000, output_distribution='uniform', self.copy = copy self.missing_values = missing_values + @staticmethod + def 
_nanpercentile_force_finite(a, q): + """Force the output of nanpercentile to be finite.""" + with warnings.catch_warnings(): + warnings.filterwarnings('ignore') + percentile = nanpercentile(a, q) + if np.allclose(percentile, np.nan, equal_nan=True): + return np.zeros(q.size) + else: + return percentile + def _dense_fit(self, X, random_state): """Compute percentiles for dense matrices. @@ -2229,7 +2237,8 @@ def _dense_fit(self, X, random_state): size=self.subsample, replace=False) col = col.take(subsample_idx, mode='clip') - self.quantiles_.append(self._percentile_func(col, references)) + self.quantiles_.append( + self._nanpercentile_force_finite(col, references)) self.quantiles_ = np.transpose(self.quantiles_) def _sparse_fit(self, X, random_state): @@ -2274,9 +2283,29 @@ def _sparse_fit(self, X, random_state): self.quantiles_.append([0] * len(references)) else: self.quantiles_.append( - self._percentile_func(column_data, references)) + self._nanpercentile_force_finite(column_data, references)) self.quantiles_ = np.transpose(self.quantiles_) + @staticmethod + def _get_mask(X, value_to_mask): + """Compute a boolean mask corresponding to the missing value in a + dense array and the data of a sparse matrix.""" + if sparse.issparse(X): + data = X.data + else: + data = X + return np.isclose(data, value_to_mask, equal_nan=True) + + @staticmethod + def _apply_mask(X, mask, value): + """Apply a value to the masked value to a dense array or the data of + a sparse matrix.""" + if sparse.issparse(X): + X.data[mask] = value + else: + X[mask] = value + return X + def fit(self, X, y=None): """Compute the quantiles used for transforming. @@ -2309,17 +2338,23 @@ def fit(self, X, y=None): " and {} samples.".format(self.n_quantiles, self.subsample)) + if self.missing_values == "NaN": + self.missing_values_ = np.nan + else: + self.missing_values_ = self.missing_values + X = self._check_inputs(X) + mask_missing_values = self._get_mask(X, self.missing_values_) + X = self._apply_mask(X, mask_missing_values, np.nan) + rng = check_random_state(self.random_state) # Create the quantiles of reference self.references_ = np.linspace(0, 1, self.n_quantiles, endpoint=True) if sparse.issparse(X): - X.data[self._mask_missing_values] = np.nan self._sparse_fit(X, rng) else: - X[self._mask_missing_values] = np.nan self._dense_fit(X, rng) return self @@ -2400,25 +2435,13 @@ def _assert_finite_or_nan(self, X): and not np.isfinite(X[~np.isnan(X)]).all()): raise ValueError("Input contains infinity" " or a value too large for %r." % X.dtype) - if np.count_nonzero(self._mask_missing_values): - if LooseVersion(np.__version__) >= '1.9': - self._percentile_func = np.nanpercentile - else: - raise NotImplementedError( - 'QuantileTransformer does not handle NaN value with' - ' NumPy {}. Please upgrade to NumPy 1.9. or higher'.format( - np.__version__)) - else: - self._percentile_func = np.percentile def _check_inputs(self, X, accept_sparse_negative=False): """Check inputs before fit and transform""" - if sparse.issparse(X): - self._mask_missing_values = _get_mask(X.data, self.missing_values) - else: - self._mask_missing_values = _get_mask(X, self.missing_values) X = check_array(X, accept_sparse='csc', copy=self.copy, dtype=[np.float64, np.float32], force_all_finite=False) + # FIXME: the following blocks should be removed once #10455 is + # addressed. # we accept nan values but not infinite values. 
self._assert_finite_or_nan(X.data if sparse.issparse(X) else X) # we only accept positive sparse matrix when ignore_implicit_zeros is @@ -2467,24 +2490,23 @@ def _transform(self, X, inverse=False): Projected data """ + mask_missing_values = self._get_mask(X, self.missing_values_) + X = self._apply_mask(X, mask_missing_values, np.nan) + if sparse.issparse(X): - X.data[self._mask_missing_values] = np.nan for feature_idx in range(X.shape[1]): column_slice = slice(X.indptr[feature_idx], X.indptr[feature_idx + 1]) X.data[column_slice] = self._transform_col( X.data[column_slice], self.quantiles_[:, feature_idx], inverse) - X.data[self._mask_missing_values] = self.missing_values else: - X[self._mask_missing_values] = np.nan for feature_idx in range(X.shape[1]): X[:, feature_idx] = self._transform_col( X[:, feature_idx], self.quantiles_[:, feature_idx], inverse) - X[self._mask_missing_values] = self.missing_values - return X + return self._apply_mask(X, mask_missing_values, self.missing_values_) def transform(self, X): """Feature-wise transformation of the data. diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 6c27deaf68328..5eaeaa3ac3ea2 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -944,28 +944,12 @@ def test_quantile_transform_check_error(): 'Expected 2D array, got scalar array instead', transformer.transform, 10) - # check that an error is raised when NumPy is < 1.9 - if LooseVersion(np.__version__) < LooseVersion('1.9'): - X_nan = np.array([[0, 1], - [0, 0], - [np.nan, np.nan], - [0, 2], - [0, 1]]) - transformer = QuantileTransformer() - assert_raises_regex(NotImplementedError, "Please upgrade to NumPy", - transformer.fit_transform, X_nan) - @pytest.mark.parametrize( "missing_values, dtype", [(np.nan, np.float64), (100, np.int64)]) def test_quantile_transform_missing_values(missing_values, dtype): - # skip if numpy < 1.9 - if LooseVersion(np.__version__) < LooseVersion('1.9'): - raise SkipTest( - 'NumPy < 1.9 do not implement nanpercentile. 
Skipping test.') - X = np.array([[0, 1], [0, 0], [missing_values, 2], diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index 3c81a2f86d35b..f6846b2469cbf 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -295,3 +295,10 @@ def __getstate__(self): self._fill_value) else: from numpy.ma import MaskedArray # noqa + + +if np_version < (1, 9): + def nanpercentile(a, q): + return np.percentile(np.compress(~np.isnan(a), a)) +else: + from numpy import nanpercentile diff --git a/sklearn/utils/tests/test_fixes.py b/sklearn/utils/tests/test_fixes.py index 7bdcfc2fc13df..b06239ad77e38 100644 --- a/sklearn/utils/tests/test_fixes.py +++ b/sklearn/utils/tests/test_fixes.py @@ -5,11 +5,15 @@ import pickle +import numpy as np + from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_array_equal +from sklearn.utils.testing import assert_allclose from sklearn.utils.fixes import divide from sklearn.utils.fixes import MaskedArray +from sklearn.utils.fixes import nanpercentile def test_divide(): @@ -24,3 +28,9 @@ def test_masked_array_obj_dtype_pickleable(): marr_pickled = pickle.loads(pickle.dumps(marr)) assert_array_equal(marr.data, marr_pickled.data) assert_array_equal(marr.mask, marr_pickled.mask) + + +def test_nanpercentile(): + X = np.array([0, 1, 2, np.nan]) + percentile = nanpercentile(X, [0, 50, 100]) + assert_allclose(percentile, X[~np.isnan(X)]) From cd288839dd8b36f1ceb02c463b818abcb0f3f36a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 12 Jan 2018 13:53:46 +0100 Subject: [PATCH 09/34] FIX nanpercentile for python 2 --- sklearn/utils/fixes.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index f6846b2469cbf..edf2714159ef3 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -299,6 +299,38 @@ def __getstate__(self): if np_version < (1, 9): def nanpercentile(a, q): - return np.percentile(np.compress(~np.isnan(a), a)) + """ + Compute the qth percentile of the data along the specified axis, + while ignoring nan values. + + Returns the qth percentile(s) of the array elements. + + Parameters + ---------- + a : array_like + Input array or object that can be converted to an array. + q : float in range of [0,100] (or sequence of floats) + Percentile to compute, which must be between 0 and 100 + inclusive. + + Returns + ------- + percentile : scalar or ndarray + If `q` is a single percentile and `axis=None`, then the result + is a scalar. If multiple percentiles are given, first axis of + the result corresponds to the percentiles. The other axes are + the axes that remain after the reduction of `a`. If the input + contains integers or floats smaller than ``float64``, the output + data-type is ``float64``. Otherwise, the output data-type is the + same as that of the input. If `out` is specified, that array is + returned instead. 
+ + """ + q = np.asarray(q) + data = np.compress(~np.isnan(a), a) + if data.size: + return np.percentile(data, q) + else: + return np.array([np.nan] * q.size) else: from numpy import nanpercentile From 3d0c389dad8be7af7cd48c39219e12d4bb103ec9 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 12 Jan 2018 14:05:32 +0100 Subject: [PATCH 10/34] TST test the output under numpy < 1.9 --- sklearn/utils/tests/test_fixes.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/tests/test_fixes.py b/sklearn/utils/tests/test_fixes.py index b06239ad77e38..8a55f74a4f6e3 100644 --- a/sklearn/utils/tests/test_fixes.py +++ b/sklearn/utils/tests/test_fixes.py @@ -6,6 +6,7 @@ import pickle import numpy as np +import pytest from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_array_equal @@ -30,7 +31,12 @@ def test_masked_array_obj_dtype_pickleable(): assert_array_equal(marr.mask, marr_pickled.mask) -def test_nanpercentile(): - X = np.array([0, 1, 2, np.nan]) - percentile = nanpercentile(X, [0, 50, 100]) - assert_allclose(percentile, X[~np.isnan(X)]) +@pytest.mark.parametrize( + "a, q, expected_percentile", + [(np.array([1, 2, 3, np.nan]), [0, 50, 100], np.array([1., 2., 3.])), + (np.array([1, 2, 3, np.nan]), 50, 2.), + (np.array([np.nan, np.nan]), [0, 50], np.array([np.nan, np.nan]))] +) +def test_nanpercentile(a, q, expected_percentile): + percentile = nanpercentile(a, q) + assert_allclose(percentile, expected_percentile) From a217af6ab79e306a536d2eec606c40bdc3afba2a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 12 Jan 2018 15:15:00 +0100 Subject: [PATCH 11/34] FIX nanpercentile numpy 1.8 --- sklearn/preprocessing/data.py | 2 +- sklearn/utils/fixes.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index d5b2ed0133d2c..0692045503a35 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2209,7 +2209,7 @@ def _nanpercentile_force_finite(a, q): with warnings.catch_warnings(): warnings.filterwarnings('ignore') percentile = nanpercentile(a, q) - if np.allclose(percentile, np.nan, equal_nan=True): + if np.all(np.isclose(percentile, np.nan, equal_nan=True)): return np.zeros(q.size) else: return percentile diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index edf2714159ef3..d2a1a64c54401 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -326,11 +326,11 @@ def nanpercentile(a, q): returned instead. 
""" - q = np.asarray(q) data = np.compress(~np.isnan(a), a) if data.size: return np.percentile(data, q) else: - return np.array([np.nan] * q.size) + size_q = 1 if np.isscalar(q) else len(q) + return np.array([np.nan] * size_q) else: from numpy import nanpercentile From 85c62680565042c3bdb73afadcfd96993f401573 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 12 Jan 2018 15:16:21 +0100 Subject: [PATCH 12/34] PEP8 --- sklearn/utils/fixes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index d2a1a64c54401..66f6ebf7abb40 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -333,4 +333,4 @@ def nanpercentile(a, q): size_q = 1 if np.isscalar(q) else len(q) return np.array([np.nan] * size_q) else: - from numpy import nanpercentile + from numpy import nanpercentile # noqa From 1306992bd8d0ec03c15aed92203b52bfd9fd2791 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sat, 13 Jan 2018 19:31:12 +0100 Subject: [PATCH 13/34] TST check all missing values behaviour --- sklearn/preprocessing/data.py | 2 +- sklearn/preprocessing/tests/test_data.py | 46 +++++++++++++----------- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 0692045503a35..6e05520c78467 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2210,7 +2210,7 @@ def _nanpercentile_force_finite(a, q): warnings.filterwarnings('ignore') percentile = nanpercentile(a, q) if np.all(np.isclose(percentile, np.nan, equal_nan=True)): - return np.zeros(q.size) + return np.zeros(len(q)) else: return percentile diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 5eaeaa3ac3ea2..c38ee5d789a0e 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -950,26 +950,32 @@ def test_quantile_transform_check_error(): [(np.nan, np.float64), (100, np.int64)]) def test_quantile_transform_missing_values(missing_values, dtype): - X = np.array([[0, 1], - [0, 0], - [missing_values, 2], - [0, missing_values], - [0, 1]], dtype=dtype) - X_sparse = sparse.csc_matrix(X) - - transformer = QuantileTransformer(n_quantiles=5, - missing_values=missing_values) - X_expected = np.array([[0, 0.5], - [0, 0], - [missing_values, 1], - [0, missing_values], - [0, 0.5]]) - - X_trans = transformer.fit_transform(X) - assert_almost_equal(X_expected, X_trans) - - X_trans = transformer.fit_transform(X_sparse) - assert_almost_equal(X_expected, X_trans.A) + X_some_missing = np.array([[0, 1], + [0, 0], + [missing_values, 2], + [0, missing_values], + [0, 1]], dtype=dtype) + X_all_missing = np.array([[missing_values, missing_values], + [missing_values, missing_values]], dtype=dtype) + X_expected_some_missing = np.array([[0, 0.5], + [0, 0], + [missing_values, 1], + [0, missing_values], + [0, 0.5]]) + X_expected_all_missing = X_all_missing.copy() + + for X, X_expected in zip([X_some_missing, X_all_missing], + [X_expected_some_missing, + X_expected_all_missing]): + transformer = QuantileTransformer(n_quantiles=5, + missing_values=missing_values) + + X_trans = transformer.fit_transform(X) + assert_almost_equal(X_expected, X_trans) + + X_sparse = sparse.csc_matrix(X) + X_trans = transformer.fit_transform(X_sparse) + assert_almost_equal(X_expected, X_trans.A) def test_quantile_transform_sparse_ignore_zeros(): From 73eed7bd11cc60efc80e0224a14d9795607a152e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre 
Date: Sat, 13 Jan 2018 19:33:05 +0100 Subject: [PATCH 14/34] TST change name for consistency --- sklearn/utils/estimator_checks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 66def1d6dc583..072bf24a91a48 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -70,7 +70,7 @@ 'OrthogonalMatchingPursuit', 'PLSCanonical', 'PLSRegression', 'RANSACRegressor', 'RadiusNeighborsRegressor', 'RandomForestRegressor', 'Ridge', 'RidgeCV'] -ACCEPT_NAN = ['QuantileTransformer'] +ALLOW_NAN = ['QuantileTransformer'] def _yield_non_meta_checks(name, estimator): @@ -972,7 +972,7 @@ def check_estimators_nan_inf(name, estimator_orig): error_string_transform = ("Estimator doesn't check for NaN and inf in" " transform.") for X_train in [X_train_nan, X_train_inf]: - if np.any(np.isnan(X_train)) and name in ACCEPT_NAN: + if np.any(np.isnan(X_train)) and name in ALLOW_NAN: continue # catch deprecation warnings with ignore_warnings(category=(DeprecationWarning, FutureWarning)): From d7b6cd9b8004ebeb69ea05a48766703310ec8d78 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 7 Feb 2018 19:02:30 +0100 Subject: [PATCH 15/34] EHN only accept NaN for the moment --- sklearn/preprocessing/data.py | 65 ++---------------------- sklearn/preprocessing/tests/test_data.py | 6 +-- 2 files changed, 6 insertions(+), 65 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index d27b90b937f3a..50fa3798ca290 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -1,5 +1,3 @@ -# coding: utf-8 - # Authors: Alexandre Gramfort # Mathieu Blondel # Olivier Grisel @@ -2165,13 +2163,6 @@ class QuantileTransformer(BaseEstimator, TransformerMixin): Set to False to perform inplace transformation and avoid a copy (if the input is already a numpy array). - missing_values : int or "NaN", optional, (default="NaN) - The placeholder for the missing values. All occurrences of - missing_values will be preserved. For missing values encoded as np.nan, - use the string value “NaN”. - - .. 
versionadded: 0.20 - Attributes ---------- quantiles_ : ndarray, shape (n_quantiles, n_features) @@ -2209,14 +2200,13 @@ class QuantileTransformer(BaseEstimator, TransformerMixin): def __init__(self, n_quantiles=1000, output_distribution='uniform', ignore_implicit_zeros=False, subsample=int(1e5), - random_state=None, copy=True, missing_values='NaN'): + random_state=None, copy=True): self.n_quantiles = n_quantiles self.output_distribution = output_distribution self.ignore_implicit_zeros = ignore_implicit_zeros self.subsample = subsample self.random_state = random_state self.copy = copy - self.missing_values = missing_values @staticmethod def _nanpercentile_force_finite(a, q): @@ -2301,26 +2291,6 @@ def _sparse_fit(self, X, random_state): self._nanpercentile_force_finite(column_data, references)) self.quantiles_ = np.transpose(self.quantiles_) - @staticmethod - def _get_mask(X, value_to_mask): - """Compute a boolean mask corresponding to the missing value in a - dense array and the data of a sparse matrix.""" - if sparse.issparse(X): - data = X.data - else: - data = X - return np.isclose(data, value_to_mask, equal_nan=True) - - @staticmethod - def _apply_mask(X, mask, value): - """Apply a value to the masked value to a dense array or the data of - a sparse matrix.""" - if sparse.issparse(X): - X.data[mask] = value - else: - X[mask] = value - return X - def fit(self, X, y=None): """Compute the quantiles used for transforming. @@ -2352,15 +2322,7 @@ def fit(self, X, y=None): " and {} samples.".format(self.n_quantiles, self.subsample)) - if self.missing_values == "NaN": - self.missing_values_ = np.nan - else: - self.missing_values_ = self.missing_values - X = self._check_inputs(X) - mask_missing_values = self._get_mask(X, self.missing_values_) - X = self._apply_mask(X, mask_missing_values, np.nan) - rng = check_random_state(self.random_state) # Create the quantiles of reference @@ -2438,26 +2400,11 @@ def _transform_col(self, X_col, quantiles, inverse): return X_col - def _assert_finite_or_nan(self, X): - """Check that X contain finite or NaN values.""" - X = np.asanyarray(X) - # First try an O(n) time, O(1) space solution for the common case that - # everything is finite; fall back to O(n) space np.isfinite to prevent - # false positives from overflow in sum method. - if (X.dtype.char in np.typecodes['AllFloat'] - and not np.isfinite(X[~np.isnan(X)].sum()) - and not np.isfinite(X[~np.isnan(X)]).all()): - raise ValueError("Input contains infinity" - " or a value too large for %r." % X.dtype) - def _check_inputs(self, X, accept_sparse_negative=False): """Check inputs before fit and transform""" X = check_array(X, accept_sparse='csc', copy=self.copy, - dtype=[np.float64, np.float32], force_all_finite=False) - # FIXME: the following blocks should be removed once #10455 is - # addressed. - # we accept nan values but not infinite values. - self._assert_finite_or_nan(X.data if sparse.issparse(X) else X) + dtype=[np.float64, np.float32], + force_all_finite='allow-nan') # we only accept positive sparse matrix when ignore_implicit_zeros is # false and that we call fit or transform. 
# comparison with NaN will raise a warning which we make silent @@ -2503,10 +2450,6 @@ def _transform(self, X, inverse=False): X : ndarray, shape (n_samples, n_features) Projected data """ - - mask_missing_values = self._get_mask(X, self.missing_values_) - X = self._apply_mask(X, mask_missing_values, np.nan) - if sparse.issparse(X): for feature_idx in range(X.shape[1]): column_slice = slice(X.indptr[feature_idx], @@ -2520,7 +2463,7 @@ def _transform(self, X, inverse=False): X[:, feature_idx], self.quantiles_[:, feature_idx], inverse) - return self._apply_mask(X, mask_missing_values, self.missing_values_) + return X def transform(self, X): """Feature-wise transformation of the data. diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index f4cbc22148705..1de9383de2db7 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -970,8 +970,7 @@ def test_quantile_transform_check_error(): @pytest.mark.parametrize( "missing_values, dtype", - [(np.nan, np.float64), - (100, np.int64)]) + [(np.nan, np.float64)]) def test_quantile_transform_missing_values(missing_values, dtype): X_some_missing = np.array([[0, 1], [0, 0], @@ -990,8 +989,7 @@ def test_quantile_transform_missing_values(missing_values, dtype): for X, X_expected in zip([X_some_missing, X_all_missing], [X_expected_some_missing, X_expected_all_missing]): - transformer = QuantileTransformer(n_quantiles=5, - missing_values=missing_values) + transformer = QuantileTransformer(n_quantiles=5) X_trans = transformer.fit_transform(X) assert_almost_equal(X_expected, X_trans) From f7bc642549fc5dc56adf6778c97c6470a691d617 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 7 Feb 2018 19:05:27 +0100 Subject: [PATCH 16/34] unecessary change --- sklearn/preprocessing/data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 50fa3798ca290..66e1826eaca2b 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -34,6 +34,7 @@ FLOAT_DTYPES) from .label import LabelEncoder + BOUNDS_THRESHOLD = 1e-7 From 84682b75546dd3ff7d1dcc5bc3fce25eb95cd58c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 7 Feb 2018 19:06:14 +0100 Subject: [PATCH 17/34] unecessary change --- sklearn/preprocessing/data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 66e1826eaca2b..d61ab8bf34d5a 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2451,6 +2451,7 @@ def _transform(self, X, inverse=False): X : ndarray, shape (n_samples, n_features) Projected data """ + if sparse.issparse(X): for feature_idx in range(X.shape[1]): column_slice = slice(X.indptr[feature_idx], From 59dfdbef98d03b6fde8c304d5a5eefb3d17a2720 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 7 Feb 2018 20:28:24 +0100 Subject: [PATCH 18/34] solve issue in numpy 1.8 --- sklearn/preprocessing/data.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index d61ab8bf34d5a..d3a9670c87db3 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2368,6 +2368,7 @@ def _transform_col(self, X_col, quantiles, inverse): upper_bounds_idx = (X_col + BOUNDS_THRESHOLD > upper_bound_x) + mask_finite = np.isnan(X_col) if not inverse: # Interpolate in one direction and in the other and take the # mean. 
This is in case of repeated values in the features @@ -2376,11 +2377,13 @@ def _transform_col(self, X_col, quantiles, inverse): # If we don't do this, only one extreme of the duplicated is # used (the upper when we do ascending, and the # lower for descending). We take the mean of these two - X_col = .5 * (np.interp(X_col, quantiles, self.references_) - - np.interp(-X_col, -quantiles[::-1], - -self.references_[::-1])) + X_col[~mask_finite] = .5 * ( + np.interp(X_col[~mask_finite], quantiles, self.references_) + - np.interp(-X_col[~mask_finite], -quantiles[::-1], + -self.references_[::-1])) else: - X_col = np.interp(X_col, self.references_, quantiles) + X_col[~mask_finite] = np.interp(X_col[~mask_finite], + self.references_, quantiles) X_col[upper_bounds_idx] = upper_bound_y X_col[lower_bounds_idx] = lower_bound_y From d85220692be5561c9367e72998300db88a9cb8c7 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 13 Feb 2018 17:30:15 +0100 Subject: [PATCH 19/34] address ogrisel comments --- sklearn/preprocessing/data.py | 15 ++++++++------- sklearn/preprocessing/tests/test_data.py | 21 ++++++++++++--------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index d3a9670c87db3..ecf7b6862a353 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2216,7 +2216,7 @@ def _nanpercentile_force_finite(a, q): warnings.filterwarnings('ignore') percentile = nanpercentile(a, q) if np.all(np.isclose(percentile, np.nan, equal_nan=True)): - return np.zeros(len(q)) + return np.zeros(len(q), dtype=a.dtype) else: return percentile @@ -2368,7 +2368,8 @@ def _transform_col(self, X_col, quantiles, inverse): upper_bounds_idx = (X_col + BOUNDS_THRESHOLD > upper_bound_x) - mask_finite = np.isnan(X_col) + isfinite_mask = ~np.isnan(X_col) + X_col_finite = X_col[isfinite_mask] if not inverse: # Interpolate in one direction and in the other and take the # mean. This is in case of repeated values in the features @@ -2377,13 +2378,13 @@ def _transform_col(self, X_col, quantiles, inverse): # If we don't do this, only one extreme of the duplicated is # used (the upper when we do ascending, and the # lower for descending). 
We take the mean of these two - X_col[~mask_finite] = .5 * ( - np.interp(X_col[~mask_finite], quantiles, self.references_) - - np.interp(-X_col[~mask_finite], -quantiles[::-1], + X_col[isfinite_mask] = .5 * ( + np.interp(X_col_finite, quantiles, self.references_) + - np.interp(-X_col_finite, -quantiles[::-1], -self.references_[::-1])) else: - X_col[~mask_finite] = np.interp(X_col[~mask_finite], - self.references_, quantiles) + X_col[isfinite_mask] = np.interp(X_col_finite, + self.references_, quantiles) X_col[upper_bounds_idx] = upper_bound_y X_col[lower_bounds_idx] = lower_bound_y diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 1de9383de2db7..cbf55916b869f 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -969,20 +969,21 @@ def test_quantile_transform_check_error(): @pytest.mark.parametrize( - "missing_values, dtype", - [(np.nan, np.float64)]) -def test_quantile_transform_missing_values(missing_values, dtype): + "missing_value, dtype", + [(np.nan, np.float64), + (np.nan, np.float32)]) +def test_quantile_transform_missing_value(missing_value, dtype): X_some_missing = np.array([[0, 1], [0, 0], - [missing_values, 2], - [0, missing_values], + [missing_value, 2], + [0, missing_value], [0, 1]], dtype=dtype) - X_all_missing = np.array([[missing_values, missing_values], - [missing_values, missing_values]], dtype=dtype) + X_all_missing = np.array([[missing_value, missing_value], + [missing_value, missing_value]], dtype=dtype) X_expected_some_missing = np.array([[0, 0.5], [0, 0], - [missing_values, 1], - [0, missing_values], + [missing_value, 1], + [0, missing_value], [0, 0.5]]) X_expected_all_missing = X_all_missing.copy() @@ -998,6 +999,8 @@ def test_quantile_transform_missing_values(missing_values, dtype): X_trans = transformer.fit_transform(X_sparse) assert_almost_equal(X_expected, X_trans.A) + assert X_trans.dtype == dtype + def test_quantile_transform_sparse_ignore_zeros(): X = np.array([[0, 1], From a20ac59146dc7be05c8ec6d8b82a659ddb047721 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 14 Feb 2018 12:03:44 +0100 Subject: [PATCH 20/34] Address some comments --- sklearn/preprocessing/data.py | 22 ++++++++-------------- sklearn/preprocessing/tests/test_data.py | 21 ++++++++++++++++----- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index ecf7b6862a353..3da95f61c80ef 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2210,13 +2210,13 @@ def __init__(self, n_quantiles=1000, output_distribution='uniform', self.copy = copy @staticmethod - def _nanpercentile_force_finite(a, q): + def _nanpercentile_force_finite(column_data, percentiles): """Force the output of nanpercentile to be finite.""" - with warnings.catch_warnings(): - warnings.filterwarnings('ignore') - percentile = nanpercentile(a, q) + percentile = nanpercentile(column_data, percentiles) + with np.errstate(invalid='ignore'): # hide NaN comparison warnings if np.all(np.isclose(percentile, np.nan, equal_nan=True)): - return np.zeros(len(q), dtype=a.dtype) + warnings.warn("All samples in a column of X are NaN.") + return np.zeros(len(percentiles), dtype=column_data.dtype) else: return percentile @@ -2360,9 +2360,7 @@ def _transform_col(self, X_col, quantiles, inverse): # for inverse transform, match a uniform PDF X_col = output_distribution.cdf(X_col) # find index for lower and higher bounds - # comparison with 
NaN will raise a warning which we make silent - with warnings.catch_warnings(): - warnings.simplefilter("ignore") + with np.errstate(invalid='ignore'): # hide NaN comparison warnings lower_bounds_idx = (X_col - BOUNDS_THRESHOLD < lower_bound_x) upper_bounds_idx = (X_col + BOUNDS_THRESHOLD > @@ -2390,9 +2388,7 @@ def _transform_col(self, X_col, quantiles, inverse): X_col[lower_bounds_idx] = lower_bound_y # for forward transform, match the output PDF if not inverse: - # comparison with NaN will raise a warning which we make silent - with warnings.catch_warnings(): - warnings.simplefilter("ignore") + with np.errstate(invalid='ignore'): # hide NaN comparison warnings X_col = output_distribution.ppf(X_col) # find the value to clip the data to avoid mapping to # infinity. Clip such that the inverse transform will be @@ -2412,9 +2408,7 @@ def _check_inputs(self, X, accept_sparse_negative=False): force_all_finite='allow-nan') # we only accept positive sparse matrix when ignore_implicit_zeros is # false and that we call fit or transform. - # comparison with NaN will raise a warning which we make silent - with warnings.catch_warnings(): - warnings.filterwarnings('ignore') + with np.errstate(invalid='ignore'): # hide NaN comparison warnings if (not accept_sparse_negative and not self.ignore_implicit_zeros and (sparse.issparse(X) and np.any(X.data < 0))): raise ValueError('QuantileTransformer only accepts' diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index cbf55916b869f..cd5315e5d25dd 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -987,16 +987,27 @@ def test_quantile_transform_missing_value(missing_value, dtype): [0, 0.5]]) X_expected_all_missing = X_all_missing.copy() - for X, X_expected in zip([X_some_missing, X_all_missing], - [X_expected_some_missing, - X_expected_all_missing]): + for X, X_expected, all_nan in zip([X_some_missing, X_all_missing], + [X_expected_some_missing, + X_expected_all_missing], + [False, True]): transformer = QuantileTransformer(n_quantiles=5) - X_trans = transformer.fit_transform(X) + if all_nan: + X_trans = assert_warns_message(UserWarning, + "samples in a column of X are NaN", + transformer.fit_transform, X) + else: + X_trans = assert_no_warnings(transformer.fit_transform, X) assert_almost_equal(X_expected, X_trans) X_sparse = sparse.csc_matrix(X) - X_trans = transformer.fit_transform(X_sparse) + if all_nan: + X_trans = assert_warns_message(UserWarning, + "samples in a column of X are NaN", + transformer.fit_transform, X_sparse) + else: + X_trans = assert_no_warnings(transformer.fit_transform, X_sparse) assert_almost_equal(X_expected, X_trans.A) assert X_trans.dtype == dtype From f8dd6a4b4851a5a04a26de6650f72e0816ef6c8d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 2 Mar 2018 13:12:03 +0100 Subject: [PATCH 21/34] TST fix common test --- sklearn/utils/estimator_checks.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 533739b26fd38..b1a821822a58a 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -70,7 +70,7 @@ 'OrthogonalMatchingPursuit', 'PLSCanonical', 'PLSRegression', 'RANSACRegressor', 'RadiusNeighborsRegressor', 'RandomForestRegressor', 'Ridge', 'RidgeCV'] -ALLOW_NAN = ['QuantileTransformer'] +ALLOW_NAN = ['QuantileTransformer', 'Imputer'] def _yield_non_meta_checks(name, estimator): @@ -92,7 +92,7 @@ def 
_yield_non_meta_checks(name, estimator): # cross-decomposition's "transform" returns X and Y yield check_pipeline_consistency - if name not in ['Imputer']: + if name not in ALLOW_NAN: # Test that all estimators check their input for NaN's and infs yield check_estimators_nan_inf @@ -1025,8 +1025,6 @@ def check_estimators_nan_inf(name, estimator_orig): error_string_transform = ("Estimator doesn't check for NaN and inf in" " transform.") for X_train in [X_train_nan, X_train_inf]: - if np.any(np.isnan(X_train)) and name in ALLOW_NAN: - continue # catch deprecation warnings with ignore_warnings(category=(DeprecationWarning, FutureWarning)): estimator = clone(estimator_orig) From 8aa605940135f9d368bd500f67b2c10f7489a309 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 15 Mar 2018 14:16:51 +0100 Subject: [PATCH 22/34] TST common test for transformer letting pass nan --- sklearn/preprocessing/tests/test_data.py | 84 +++++++++++------------- 1 file changed, 39 insertions(+), 45 deletions(-) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index cd5315e5d25dd..b695b3f020c89 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -62,6 +62,7 @@ from sklearn.pipeline import Pipeline from sklearn.model_selection import cross_val_predict +from sklearn.model_selection import train_test_split from sklearn.svm import SVR from sklearn import datasets @@ -102,6 +103,44 @@ def assert_correct_incr(i, batch_start, batch_stop, n, chunk_size, n_samples_seen) +def _generate_tuple_transformer_missing_value(): + trans_handling_nan = [QuantileTransformer()] + return [(trans, iris.data.copy(), 15) + for trans in trans_handling_nan] + + +@pytest.mark.parametrize( + "est, X, n_missing", + _generate_tuple_transformer_missing_value() +) +def test_missing_value_handling(est, X, n_missing): + # check that the preprocessing method let pass nan + rng = np.random.RandomState(42) + X[rng.randint(X.shape[0], size=n_missing), + rng.randint(X.shape[1], size=n_missing)] = np.nan + X_train, X_test = train_test_split(X) + # sanity check + assert not np.all(np.isnan(X_train), axis=0).any() + X_test[:, 0] = np.nan # make sure this boundary case is tested + + Xt = est.fit(X_train).transform(X_test) + # missing values should still be missing, and only them + assert_array_equal(np.isnan(Xt), np.isnan(X_test)) + + for i in range(X.shape[1]): + # train only on non-NaN + est.fit(X_train[:, [i]][~np.isnan(X_train[:, i])]) + # check transforming with NaN works even when training without NaN + Xt_col = est.transform(X_test[:, [i]]) + assert_array_equal(Xt_col, Xt[:, [i]]) + # check non-NaN is handled as before - the 1st column is all nan + if not np.isnan(X_test[:, i]).all(): + Xt_col_nonan = est.transform( + X_test[:, [i]][~np.isnan(X_test[:, i])]) + assert_array_equal(Xt_col_nonan, + Xt_col[~np.isnan(Xt_col.squeeze())]) + + def test_polynomial_features(): # Test Polynomial Features X1 = np.arange(6)[:, np.newaxis] @@ -968,51 +1007,6 @@ def test_quantile_transform_check_error(): transformer.transform, 10) -@pytest.mark.parametrize( - "missing_value, dtype", - [(np.nan, np.float64), - (np.nan, np.float32)]) -def test_quantile_transform_missing_value(missing_value, dtype): - X_some_missing = np.array([[0, 1], - [0, 0], - [missing_value, 2], - [0, missing_value], - [0, 1]], dtype=dtype) - X_all_missing = np.array([[missing_value, missing_value], - [missing_value, missing_value]], dtype=dtype) - X_expected_some_missing = np.array([[0, 0.5], 
- [0, 0], - [missing_value, 1], - [0, missing_value], - [0, 0.5]]) - X_expected_all_missing = X_all_missing.copy() - - for X, X_expected, all_nan in zip([X_some_missing, X_all_missing], - [X_expected_some_missing, - X_expected_all_missing], - [False, True]): - transformer = QuantileTransformer(n_quantiles=5) - - if all_nan: - X_trans = assert_warns_message(UserWarning, - "samples in a column of X are NaN", - transformer.fit_transform, X) - else: - X_trans = assert_no_warnings(transformer.fit_transform, X) - assert_almost_equal(X_expected, X_trans) - - X_sparse = sparse.csc_matrix(X) - if all_nan: - X_trans = assert_warns_message(UserWarning, - "samples in a column of X are NaN", - transformer.fit_transform, X_sparse) - else: - X_trans = assert_no_warnings(transformer.fit_transform, X_sparse) - assert_almost_equal(X_expected, X_trans.A) - - assert X_trans.dtype == dtype - - def test_quantile_transform_sparse_ignore_zeros(): X = np.array([[0, 1], [0, 0], From c745eaba144a374176ef8f778a3eb239286b8434 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sun, 18 Mar 2018 12:15:36 +0100 Subject: [PATCH 23/34] TST add separate commont tests --- sklearn/preprocessing/tests/test_common.py | 54 ++++++++++++++++++++++ sklearn/preprocessing/tests/test_data.py | 39 ---------------- 2 files changed, 54 insertions(+), 39 deletions(-) create mode 100644 sklearn/preprocessing/tests/test_common.py diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py new file mode 100644 index 0000000000000..7e2ffd811063f --- /dev/null +++ b/sklearn/preprocessing/tests/test_common.py @@ -0,0 +1,54 @@ +import pytest +import numpy as np + + +from sklearn.datasets import load_iris + +from sklearn.model_selection import train_test_split + +from sklearn.preprocessing import QuantileTransformer + +from sklearn.utils.testing import assert_array_equal + +iris = load_iris() + +TRANSFORMER_HANDLING_NAN = ['QuantileTransformer'] + + +def _generate_tuple_transformer_missing_value(): + trans_handling_nan = [globals()[trans_name]() + for trans_name in TRANSFORMER_HANDLING_NAN] + return [(trans, iris.data.copy(), 15) + for trans in trans_handling_nan] + + +@pytest.mark.parametrize( + "est, X, n_missing", + _generate_tuple_transformer_missing_value() +) +def test_missing_value_handling(est, X, n_missing): + # check that the preprocessing method let pass nan + rng = np.random.RandomState(42) + X[rng.randint(X.shape[0], size=n_missing), + rng.randint(X.shape[1], size=n_missing)] = np.nan + X_train, X_test = train_test_split(X) + # sanity check + assert not np.all(np.isnan(X_train), axis=0).any() + X_test[:, 0] = np.nan # make sure this boundary case is tested + + Xt = est.fit(X_train).transform(X_test) + # missing values should still be missing, and only them + assert_array_equal(np.isnan(Xt), np.isnan(X_test)) + + for i in range(X.shape[1]): + # train only on non-NaN + est.fit(X_train[:, [i]][~np.isnan(X_train[:, i])]) + # check transforming with NaN works even when training without NaN + Xt_col = est.transform(X_test[:, [i]]) + assert_array_equal(Xt_col, Xt[:, [i]]) + # check non-NaN is handled as before - the 1st column is all nan + if not np.isnan(X_test[:, i]).all(): + Xt_col_nonan = est.transform( + X_test[:, [i]][~np.isnan(X_test[:, i])]) + assert_array_equal(Xt_col_nonan, + Xt_col[~np.isnan(Xt_col.squeeze())]) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index dfa5807c2e2ee..51c37097adca2 100644 --- 
a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -62,7 +62,6 @@ from sklearn.pipeline import Pipeline from sklearn.model_selection import cross_val_predict -from sklearn.model_selection import train_test_split from sklearn.svm import SVR from sklearn import datasets @@ -103,44 +102,6 @@ def assert_correct_incr(i, batch_start, batch_stop, n, chunk_size, n_samples_seen) -def _generate_tuple_transformer_missing_value(): - trans_handling_nan = [QuantileTransformer()] - return [(trans, iris.data.copy(), 15) - for trans in trans_handling_nan] - - -@pytest.mark.parametrize( - "est, X, n_missing", - _generate_tuple_transformer_missing_value() -) -def test_missing_value_handling(est, X, n_missing): - # check that the preprocessing method let pass nan - rng = np.random.RandomState(42) - X[rng.randint(X.shape[0], size=n_missing), - rng.randint(X.shape[1], size=n_missing)] = np.nan - X_train, X_test = train_test_split(X) - # sanity check - assert not np.all(np.isnan(X_train), axis=0).any() - X_test[:, 0] = np.nan # make sure this boundary case is tested - - Xt = est.fit(X_train).transform(X_test) - # missing values should still be missing, and only them - assert_array_equal(np.isnan(Xt), np.isnan(X_test)) - - for i in range(X.shape[1]): - # train only on non-NaN - est.fit(X_train[:, [i]][~np.isnan(X_train[:, i])]) - # check transforming with NaN works even when training without NaN - Xt_col = est.transform(X_test[:, [i]]) - assert_array_equal(Xt_col, Xt[:, [i]]) - # check non-NaN is handled as before - the 1st column is all nan - if not np.isnan(X_test[:, i]).all(): - Xt_col_nonan = est.transform( - X_test[:, [i]][~np.isnan(X_test[:, i])]) - assert_array_equal(Xt_col_nonan, - Xt_col[~np.isnan(Xt_col.squeeze())]) - - def test_polynomial_features(): # Test Polynomial Features X1 = np.arange(6)[:, np.newaxis] From ad878fad0796606861eabe77766407ad982f09e2 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sun, 18 Mar 2018 13:09:21 +0100 Subject: [PATCH 24/34] TST improve testing --- sklearn/preprocessing/tests/test_common.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index 7e2ffd811063f..d0154c3bf1428 100644 --- a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -12,19 +12,14 @@ iris = load_iris() -TRANSFORMER_HANDLING_NAN = ['QuantileTransformer'] - - -def _generate_tuple_transformer_missing_value(): - trans_handling_nan = [globals()[trans_name]() - for trans_name in TRANSFORMER_HANDLING_NAN] - return [(trans, iris.data.copy(), 15) - for trans in trans_handling_nan] - @pytest.mark.parametrize( - "est, X, n_missing", - _generate_tuple_transformer_missing_value() + "est", + [QuantileTransformer()] +) +@pytest.mark.parametrize( + "X, n_missing", + [(iris.data.copy(), 15)] ) def test_missing_value_handling(est, X, n_missing): # check that the preprocessing method let pass nan @@ -34,6 +29,8 @@ def test_missing_value_handling(est, X, n_missing): X_train, X_test = train_test_split(X) # sanity check assert not np.all(np.isnan(X_train), axis=0).any() + assert np.any(X_train, axis=0).all() + assert np.any(X_test, axis=0).all() X_test[:, 0] = np.nan # make sure this boundary case is tested Xt = est.fit(X_train).transform(X_test) From 6784c3b2999791aa984b4e93c428b3f99afdf891 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sun, 18 Mar 2018 13:10:55 +0100 Subject: [PATCH 25/34] TST 
remove parametrization on X and n_missing --- sklearn/preprocessing/tests/test_common.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index d0154c3bf1428..ae4c904b9757c 100644 --- a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -17,13 +17,11 @@ "est", [QuantileTransformer()] ) -@pytest.mark.parametrize( - "X, n_missing", - [(iris.data.copy(), 15)] -) -def test_missing_value_handling(est, X, n_missing): +def test_missing_value_handling(est): # check that the preprocessing method let pass nan rng = np.random.RandomState(42) + X = iris.data.copy() + n_missing = 15 X[rng.randint(X.shape[0], size=n_missing), rng.randint(X.shape[1], size=n_missing)] = np.nan X_train, X_test = train_test_split(X) From daa40daff5adb5c1a45cba2c31ef6ddf701e80ce Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Mar 2018 08:34:32 +0100 Subject: [PATCH 26/34] address joel comments --- sklearn/preprocessing/data.py | 6 ++++++ sklearn/preprocessing/tests/test_common.py | 10 +++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 5395396ef5a45..308f8f50295b2 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2195,6 +2195,9 @@ class QuantileTransformer(BaseEstimator, TransformerMixin): Notes ----- + NaNs are treated as missing values: disregarded in fit, and maintained in + transform. + For a comparison of the different scalers, transformers, and normalizers, see :ref:`examples/preprocessing/plot_all_scaling.py `. @@ -2603,6 +2606,9 @@ def quantile_transform(X, axis=0, n_quantiles=1000, Notes ----- + NaNs are treated as missing values: disregarded in fit, and maintained in + transform. + For a comparison of the different scalers, transformers, and normalizers, see :ref:`examples/preprocessing/plot_all_scaling.py `. 
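[Aside from the editor, not part of the patch: the behaviour documented in the new Notes entries above can be shown with a minimal sketch. It assumes a scikit-learn build that already contains this series, i.e. a 0.20-era QuantileTransformer that accepts NaN input.]

import numpy as np
from sklearn.preprocessing import QuantileTransformer

X = np.array([[0, 1],
              [0, 0],
              [np.nan, 2],
              [0, np.nan],
              [0, 1]])

qt = QuantileTransformer(n_quantiles=5)
Xt = qt.fit_transform(X)

# fit disregards the NaN entries when computing the per-feature quantiles,
# and transform propagates them untouched, so the missing-value mask survives.
assert np.array_equal(np.isnan(Xt), np.isnan(X))

# inverse_transform preserves the NaN entries as well.
X_back = qt.inverse_transform(Xt)
assert np.array_equal(np.isnan(X_back), np.isnan(X))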
diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index ae4c904b9757c..35a285f691000 100644 --- a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -1,13 +1,9 @@ import pytest import numpy as np - from sklearn.datasets import load_iris - from sklearn.model_selection import train_test_split - from sklearn.preprocessing import QuantileTransformer - from sklearn.utils.testing import assert_array_equal iris = load_iris() @@ -21,14 +17,14 @@ def test_missing_value_handling(est): # check that the preprocessing method let pass nan rng = np.random.RandomState(42) X = iris.data.copy() - n_missing = 15 + n_missing = 30 X[rng.randint(X.shape[0], size=n_missing), rng.randint(X.shape[1], size=n_missing)] = np.nan X_train, X_test = train_test_split(X) # sanity check assert not np.all(np.isnan(X_train), axis=0).any() - assert np.any(X_train, axis=0).all() - assert np.any(X_test, axis=0).all() + assert np.any(np.isnan(X_train), axis=0).all() + assert np.any(np.isnan(X_test), axis=0).all() X_test[:, 0] = np.nan # make sure this boundary case is tested Xt = est.fit(X_train).transform(X_test) From 2c0ceb3f4384b36ded2c6ecc7d72686fc7168645 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Mar 2018 09:09:13 +0100 Subject: [PATCH 27/34] fix random state for the split training testing --- sklearn/preprocessing/tests/test_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index 35a285f691000..3198d739c86d5 100644 --- a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -17,10 +17,10 @@ def test_missing_value_handling(est): # check that the preprocessing method let pass nan rng = np.random.RandomState(42) X = iris.data.copy() - n_missing = 30 + n_missing = 50 X[rng.randint(X.shape[0], size=n_missing), rng.randint(X.shape[1], size=n_missing)] = np.nan - X_train, X_test = train_test_split(X) + X_train, X_test = train_test_split(X, random_state=0) # sanity check assert not np.all(np.isnan(X_train), axis=0).any() assert np.any(np.isnan(X_train), axis=0).all() From 004b0e352dbfafe0ff73ae78a1f288af58af37a3 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 20 Mar 2018 12:43:54 +0100 Subject: [PATCH 28/34] do not force percentile to be finite --- sklearn/preprocessing/data.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 308f8f50295b2..3f003d8a28c88 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2213,16 +2213,16 @@ def __init__(self, n_quantiles=1000, output_distribution='uniform', self.random_state = random_state self.copy = copy - @staticmethod - def _nanpercentile_force_finite(column_data, percentiles): - """Force the output of nanpercentile to be finite.""" - percentile = nanpercentile(column_data, percentiles) - with np.errstate(invalid='ignore'): # hide NaN comparison warnings - if np.all(np.isclose(percentile, np.nan, equal_nan=True)): - warnings.warn("All samples in a column of X are NaN.") - return np.zeros(len(percentiles), dtype=column_data.dtype) - else: - return percentile + # @staticmethod + # def _nanpercentile_force_finite(column_data, percentiles): + # """Force the output of nanpercentile to be finite.""" + # percentile = nanpercentile(column_data, percentiles) + # if 
np.all(np.isnan(percentile)): + # print(percentile) + # warnings.warn("All samples in a column of X are NaN.") + # return np.array([np.nan] * len(percentiles)) # np.zeros(len(percentiles), dtype=column_data.dtype) + # else: + # return percentile def _dense_fit(self, X, random_state): """Compute percentiles for dense matrices. @@ -2249,8 +2249,7 @@ def _dense_fit(self, X, random_state): size=self.subsample, replace=False) col = col.take(subsample_idx, mode='clip') - self.quantiles_.append( - self._nanpercentile_force_finite(col, references)) + self.quantiles_.append(self.nanpercentile(col, references)) self.quantiles_ = np.transpose(self.quantiles_) def _sparse_fit(self, X, random_state): @@ -2296,7 +2295,7 @@ def _sparse_fit(self, X, random_state): self.quantiles_.append([0] * len(references)) else: self.quantiles_.append( - self._nanpercentile_force_finite(column_data, references)) + self.nanpercentile(column_data, references)) self.quantiles_ = np.transpose(self.quantiles_) def fit(self, X, y=None): From 9ab77b67c5ebbb440598a138f3e4bf66dc231405 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 20 Mar 2018 12:51:28 +0100 Subject: [PATCH 29/34] fix --- sklearn/preprocessing/data.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 3f003d8a28c88..95b91a39d98c7 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2249,7 +2249,7 @@ def _dense_fit(self, X, random_state): size=self.subsample, replace=False) col = col.take(subsample_idx, mode='clip') - self.quantiles_.append(self.nanpercentile(col, references)) + self.quantiles_.append(nanpercentile(col, references)) self.quantiles_ = np.transpose(self.quantiles_) def _sparse_fit(self, X, random_state): @@ -2294,8 +2294,7 @@ def _sparse_fit(self, X, random_state): # quantiles. Force the quantiles to be zeros. self.quantiles_.append([0] * len(references)) else: - self.quantiles_.append( - self.nanpercentile(column_data, references)) + self.quantiles_.append(nanpercentile(column_data, references)) self.quantiles_ = np.transpose(self.quantiles_) def fit(self, X, y=None): From 33cc416a0badcdebc4d99477a3668c9723ea22b4 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 20 Mar 2018 13:37:46 +0100 Subject: [PATCH 30/34] TST add test for quantile transformer --- sklearn/preprocessing/data.py | 11 ----------- sklearn/preprocessing/tests/test_common.py | 10 ++++++++-- sklearn/preprocessing/tests/test_data.py | 14 ++++++++++++++ 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 95b91a39d98c7..681bee318016a 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2213,17 +2213,6 @@ def __init__(self, n_quantiles=1000, output_distribution='uniform', self.random_state = random_state self.copy = copy - # @staticmethod - # def _nanpercentile_force_finite(column_data, percentiles): - # """Force the output of nanpercentile to be finite.""" - # percentile = nanpercentile(column_data, percentiles) - # if np.all(np.isnan(percentile)): - # print(percentile) - # warnings.warn("All samples in a column of X are NaN.") - # return np.array([np.nan] * len(percentiles)) # np.zeros(len(percentiles), dtype=column_data.dtype) - # else: - # return percentile - def _dense_fit(self, X, random_state): """Compute percentiles for dense matrices. 
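[Side note, not part of this commit: the commented-out force-finite helper can be dropped because nanpercentile already behaves acceptably on its own. A minimal sketch, assuming numpy's np.nanpercentile with >= 1.11 semantics; older numpy goes through the backport in sklearn/utils/fixes.py that later commits in this series adjust.]

import numpy as np

references = [0., 25., 50., 75., 100.]

# NaN entries are skipped when the percentiles are computed...
col = np.array([0.1, np.nan, 0.7, 0.3, np.nan])
print(np.nanpercentile(col, references))      # percentiles of [0.1, 0.3, 0.7]

# ...and an all-NaN column yields all-NaN quantiles (with a RuntimeWarning)
# instead of raising, so no "force finite" fallback is needed.
all_nan = np.full(4, np.nan)
print(np.nanpercentile(all_nan, references))  # [nan nan nan nan nan]

[The all-NaN case is what the test added to test_data.py in this commit asserts on quantiles_.]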
diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index 3198d739c86d5..a7ec2e8da242e 100644 --- a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -5,13 +5,14 @@ from sklearn.model_selection import train_test_split from sklearn.preprocessing import QuantileTransformer from sklearn.utils.testing import assert_array_equal +from sklearn.utils.testing import assert_allclose iris = load_iris() @pytest.mark.parametrize( "est", - [QuantileTransformer()] + [QuantileTransformer(n_quantiles=10, random_state=42)] ) def test_missing_value_handling(est): # check that the preprocessing method let pass nan @@ -20,7 +21,7 @@ def test_missing_value_handling(est): n_missing = 50 X[rng.randint(X.shape[0], size=n_missing), rng.randint(X.shape[1], size=n_missing)] = np.nan - X_train, X_test = train_test_split(X, random_state=0) + X_train, X_test = train_test_split(X, random_state=1) # sanity check assert not np.all(np.isnan(X_train), axis=0).any() assert np.any(np.isnan(X_train), axis=0).all() @@ -31,6 +32,11 @@ def test_missing_value_handling(est): # missing values should still be missing, and only them assert_array_equal(np.isnan(Xt), np.isnan(X_test)) + # check that the inverse transform keep NaN + Xt_inv = est.inverse_transform(Xt) + assert_array_equal(np.isnan(Xt_inv), np.isnan(X_test)) + assert_allclose(Xt_inv, X_test, equal_nan=True) + for i in range(X.shape[1]): # train only on non-NaN est.fit(X_train[:, [i]][~np.isnan(X_train[:, i])]) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 51c37097adca2..e3bf4096750de 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -1210,6 +1210,20 @@ def test_quantile_transform_and_inverse(): assert_array_almost_equal(X, X_trans_inv) +def test_quantile_transform_nan(): + X = np.array([[np.nan, 0, 0, 1], + [np.nan, np.nan, 0, 0.5], + [np.nan, 1, 1, 0]]) + + transformer = QuantileTransformer(n_quantiles=10, random_state=42) + transformer.fit_transform(X) + + # check that the quantile of the first column is all NaN + assert np.isnan(transformer.quantiles_[:, 0]).all() + # all other column should not contain NaN + assert not np.isnan(transformer.quantiles_[:, 1:]).any() + + def test_robust_scaler_invalid_range(): for range_ in [ (-1, 90), From f58dcee9ee34b39d4590bcb9f531a2872f30a4b7 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 20 Mar 2018 14:07:02 +0100 Subject: [PATCH 31/34] TST fix for older numpy version --- sklearn/preprocessing/tests/test_common.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/preprocessing/tests/test_common.py b/sklearn/preprocessing/tests/test_common.py index a7ec2e8da242e..2d329d0475789 100644 --- a/sklearn/preprocessing/tests/test_common.py +++ b/sklearn/preprocessing/tests/test_common.py @@ -35,7 +35,9 @@ def test_missing_value_handling(est): # check that the inverse transform keep NaN Xt_inv = est.inverse_transform(Xt) assert_array_equal(np.isnan(Xt_inv), np.isnan(X_test)) - assert_allclose(Xt_inv, X_test, equal_nan=True) + # FIXME: we can introduce equal_nan=True in recent version of numpy. + # For the moment which just check that non-NaN values are almost equal. 
+ assert_allclose(Xt_inv[~np.isnan(Xt_inv)], X_test[~np.isnan(X_test)]) for i in range(X.shape[1]): # train only on non-NaN From 0f0348510bc6e86bea1a534638f17a18b6f5e0cc Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 20 Mar 2018 15:31:40 +0100 Subject: [PATCH 32/34] FIX for to use nanpercentile up to 1.11 for consistent behaviour --- sklearn/utils/fixes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index ce99f82a70ce0..7d6b9cca2d9f8 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -297,7 +297,7 @@ def __getstate__(self): from numpy.ma import MaskedArray # noqa -if np_version < (1, 9): +if np_version < (1, 10): def nanpercentile(a, q): """ Compute the qth percentile of the data along the specified axis, From d0a88bdcb06e7ced4d29fa93c990e349b0097fe6 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 20 Mar 2018 15:55:29 +0100 Subject: [PATCH 33/34] my mistake --- sklearn/utils/fixes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index 7d6b9cca2d9f8..f7d9d6a29f9f6 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -297,7 +297,7 @@ def __getstate__(self): from numpy.ma import MaskedArray # noqa -if np_version < (1, 10): +if np_version < (1, 11): def nanpercentile(a, q): """ Compute the qth percentile of the data along the specified axis, From 1bb000685a91736069f6f2955568fda14f64bdb1 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sat, 21 Apr 2018 18:31:15 +0200 Subject: [PATCH 34/34] Roman comments --- sklearn/preprocessing/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index adf1518672746..9909138475d7a 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -2395,7 +2395,7 @@ def _transform_col(self, X_col, quantiles, inverse): def _check_inputs(self, X, accept_sparse_negative=False): """Check inputs before fit and transform""" X = check_array(X, accept_sparse='csc', copy=self.copy, - dtype=[np.float64, np.float32], + dtype=FLOAT_DTYPES, force_all_finite='allow-nan') # we only accept positive sparse matrix when ignore_implicit_zeros is # false and that we call fit or transform.
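[For reference, outside the patch series: the final _check_inputs change leans entirely on check_array. A minimal sketch of what force_all_finite='allow-nan' accepts and rejects, assuming a 0.20-era scikit-learn where check_array takes that keyword and FLOAT_DTYPES is the float dtype tuple from sklearn.utils.validation.]

import numpy as np
from sklearn.utils import check_array

# NaN is let through once force_all_finite is relaxed to 'allow-nan'...
X_nan = np.array([[0., np.nan],
                  [1., 2.]])
check_array(X_nan, dtype=[np.float64, np.float32], force_all_finite='allow-nan')

# ...but infinite values are still rejected with a ValueError, so infinity
# never reaches the quantile computation.
X_inf = np.array([[0., np.inf],
                  [1., 2.]])
try:
    check_array(X_inf, force_all_finite='allow-nan')
except ValueError as exc:
    print(exc)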