Skip to content

[MRG] EHN handle NaN value in QuantileTransformer #10437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 38 commits into from
Apr 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
cc3bb96
EHN handle NaN value in QuantileTransformer
glemaitre Jan 9, 2018
76123c8
DOC add whats new entry
glemaitre Jan 9, 2018
530c7bf
TST relax inf/nan common test
glemaitre Jan 9, 2018
1f07963
FIX silent warning and raise an error for numpy version
glemaitre Jan 10, 2018
91c947e
TST ensure that test raise error with older numpy
glemaitre Jan 10, 2018
1c406c0
TST remove mocking
glemaitre Jan 10, 2018
ecc5048
EHN accept integer as missing values
glemaitre Jan 11, 2018
965811f
address joel comments
glemaitre Jan 12, 2018
cd28883
FIX nanpercentile for python 2
glemaitre Jan 12, 2018
3d0c389
TST test the output under numpy < 1.9
glemaitre Jan 12, 2018
a217af6
FIX nanpercentile numpy 1.8
glemaitre Jan 12, 2018
85c6268
PEP8
glemaitre Jan 12, 2018
1306992
TST check all missing values behaviour
glemaitre Jan 13, 2018
73eed7b
TST change name for consistency
glemaitre Jan 13, 2018
ecdc675
Merge remote-tracking branch 'origin/master' into is/10404
glemaitre Feb 7, 2018
d7b6cd9
EHN only accept NaN for the moment
glemaitre Feb 7, 2018
f7bc642
unecessary change
glemaitre Feb 7, 2018
84682b7
unecessary change
glemaitre Feb 7, 2018
59dfdbe
solve issue in numpy 1.8
glemaitre Feb 7, 2018
d852206
address ogrisel comments
glemaitre Feb 13, 2018
a20ac59
Address some comments
glemaitre Feb 14, 2018
f8dd6a4
TST fix common test
glemaitre Mar 2, 2018
8aa6059
TST common test for transformer letting pass nan
glemaitre Mar 15, 2018
e9b9855
Merge remote-tracking branch 'origin/master' into is/10404
glemaitre Mar 15, 2018
c745eab
TST add separate commont tests
glemaitre Mar 18, 2018
daa3a91
Merge remote-tracking branch 'origin/master' into is/10404
glemaitre Mar 18, 2018
ad878fa
TST improve testing
glemaitre Mar 18, 2018
6784c3b
TST remove parametrization on X and n_missing
glemaitre Mar 18, 2018
daa40da
address joel comments
glemaitre Mar 19, 2018
2c0ceb3
fix random state for the split training testing
glemaitre Mar 19, 2018
004b0e3
do not force percentile to be finite
glemaitre Mar 20, 2018
9ab77b6
fix
glemaitre Mar 20, 2018
33cc416
TST add test for quantile transformer
glemaitre Mar 20, 2018
f58dcee
TST fix for older numpy version
glemaitre Mar 20, 2018
0f03485
FIX for to use nanpercentile up to 1.11 for consistent behaviour
glemaitre Mar 20, 2018
d0a88bd
my mistake
glemaitre Mar 20, 2018
d554f8e
Merge branch 'master' into is/10404
glemaitre Apr 16, 2018
1bb0006
Roman comments
glemaitre Apr 21, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats_new/v0.20.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ Preprocessing
:issue:`10210` by :user:`Eric Chang <ericchang00>` and
:user:`Maniteja Nandana <maniteja123>`.

- :class:`preprocessing.QuantileTransformer` handles and ignores NaN values.
:issue:`10404` by :user:`Guillaume Lemaitre <glemaitre>`.

- Added the :class:`compose.TransformedTargetRegressor` which transforms
the target y before fitting a regression model. The predictions are mapped
back to the original space via an inverse transform. :issue:`9041` by
Expand Down
53 changes: 33 additions & 20 deletions sklearn/preprocessing/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ..utils import check_array
from ..utils.extmath import row_norms
from ..utils.extmath import _incremental_mean_and_var
from ..utils.fixes import _argmax
from ..utils.fixes import _argmax, nanpercentile
from ..utils.sparsefuncs_fast import (inplace_csr_row_normalize_l1,
inplace_csr_row_normalize_l2)
from ..utils.sparsefuncs import (inplace_column_scale,
Expand Down Expand Up @@ -2194,6 +2194,9 @@ class QuantileTransformer(BaseEstimator, TransformerMixin):

Notes
-----
NaNs are treated as missing values: disregarded in fit, and maintained in
transform.

For a comparison of the different scalers, transformers, and normalizers,
see :ref:`examples/preprocessing/plot_all_scaling.py
<sphx_glr_auto_examples_preprocessing_plot_all_scaling.py>`.
Expand Down Expand Up @@ -2234,7 +2237,7 @@ def _dense_fit(self, X, random_state):
size=self.subsample,
replace=False)
col = col.take(subsample_idx, mode='clip')
self.quantiles_.append(np.percentile(col, references))
self.quantiles_.append(nanpercentile(col, references))
self.quantiles_ = np.transpose(self.quantiles_)

def _sparse_fit(self, X, random_state):
Expand Down Expand Up @@ -2279,8 +2282,7 @@ def _sparse_fit(self, X, random_state):
# quantiles. Force the quantiles to be zeros.
self.quantiles_.append([0] * len(references))
else:
self.quantiles_.append(
np.percentile(column_data, references))
self.quantiles_.append(nanpercentile(column_data, references))
self.quantiles_ = np.transpose(self.quantiles_)

def fit(self, X, y=None):
Expand Down Expand Up @@ -2349,30 +2351,36 @@ def _transform_col(self, X_col, quantiles, inverse):
# for inverse transform, match a uniform PDF
X_col = output_distribution.cdf(X_col)
# find index for lower and higher bounds
lower_bounds_idx = (X_col - BOUNDS_THRESHOLD <
lower_bound_x)
upper_bounds_idx = (X_col + BOUNDS_THRESHOLD >
upper_bound_x)

with np.errstate(invalid='ignore'): # hide NaN comparison warnings
lower_bounds_idx = (X_col - BOUNDS_THRESHOLD <
lower_bound_x)
upper_bounds_idx = (X_col + BOUNDS_THRESHOLD >
upper_bound_x)

isfinite_mask = ~np.isnan(X_col)
X_col_finite = X_col[isfinite_mask]
if not inverse:
# Interpolate in one direction and in the other and take the
# mean. This is in case of repeated values in the features
# and hence repeated quantiles
#
# If we don't do this, only one extreme of the duplicated is
# used (the upper when we do assending, and the
# used (the upper when we do ascending, and the
# lower for descending). We take the mean of these two
X_col = .5 * (np.interp(X_col, quantiles, self.references_)
- np.interp(-X_col, -quantiles[::-1],
-self.references_[::-1]))
X_col[isfinite_mask] = .5 * (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fwiw, it's possible that np.ma would handle the none-missing case more efficiently than using an ad-hoc hoc mask. I've not checked.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Playing around, I think that it will trigger the same number of copy.

np.interp(X_col_finite, quantiles, self.references_)
- np.interp(-X_col_finite, -quantiles[::-1],
-self.references_[::-1]))
else:
X_col = np.interp(X_col, self.references_, quantiles)
X_col[isfinite_mask] = np.interp(X_col_finite,
self.references_, quantiles)

X_col[upper_bounds_idx] = upper_bound_y
X_col[lower_bounds_idx] = lower_bound_y
# for forward transform, match the output PDF
if not inverse:
X_col = output_distribution.ppf(X_col)
with np.errstate(invalid='ignore'): # hide NaN comparison warnings
X_col = output_distribution.ppf(X_col)
# find the value to clip the data to avoid mapping to
# infinity. Clip such that the inverse transform will be
# consistent
Expand All @@ -2387,13 +2395,15 @@ def _transform_col(self, X_col, quantiles, inverse):
def _check_inputs(self, X, accept_sparse_negative=False):
"""Check inputs before fit and transform"""
X = check_array(X, accept_sparse='csc', copy=self.copy,
dtype=[np.float64, np.float32])
dtype=FLOAT_DTYPES,
force_all_finite='allow-nan')
# we only accept positive sparse matrix when ignore_implicit_zeros is
# false and that we call fit or transform.
if (not accept_sparse_negative and not self.ignore_implicit_zeros and
(sparse.issparse(X) and np.any(X.data < 0))):
raise ValueError('QuantileTransformer only accepts non-negative'
' sparse matrices.')
with np.errstate(invalid='ignore'): # hide NaN comparison warnings
if (not accept_sparse_negative and not self.ignore_implicit_zeros
and (sparse.issparse(X) and np.any(X.data < 0))):
raise ValueError('QuantileTransformer only accepts'
' non-negative sparse matrices.')

# check the output PDF
if self.output_distribution not in ('normal', 'uniform'):
Expand Down Expand Up @@ -2582,6 +2592,9 @@ def quantile_transform(X, axis=0, n_quantiles=1000,

Notes
-----
NaNs are treated as missing values: disregarded in fit, and maintained in
transform.

For a comparison of the different scalers, transformers, and normalizers,
see :ref:`examples/preprocessing/plot_all_scaling.py
<sphx_glr_auto_examples_preprocessing_plot_all_scaling.py>`.
Expand Down
53 changes: 53 additions & 0 deletions sklearn/preprocessing/tests/test_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
import numpy as np

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import QuantileTransformer
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_allclose

iris = load_iris()


@pytest.mark.parametrize(
"est",
[QuantileTransformer(n_quantiles=10, random_state=42)]
)
def test_missing_value_handling(est):
# check that the preprocessing method let pass nan
rng = np.random.RandomState(42)
X = iris.data.copy()
n_missing = 50
X[rng.randint(X.shape[0], size=n_missing),
rng.randint(X.shape[1], size=n_missing)] = np.nan
X_train, X_test = train_test_split(X, random_state=1)
# sanity check
assert not np.all(np.isnan(X_train), axis=0).any()
assert np.any(np.isnan(X_train), axis=0).all()
assert np.any(np.isnan(X_test), axis=0).all()
X_test[:, 0] = np.nan # make sure this boundary case is tested

Xt = est.fit(X_train).transform(X_test)
# missing values should still be missing, and only them
assert_array_equal(np.isnan(Xt), np.isnan(X_test))

# check that the inverse transform keep NaN
Xt_inv = est.inverse_transform(Xt)
assert_array_equal(np.isnan(Xt_inv), np.isnan(X_test))
# FIXME: we can introduce equal_nan=True in recent version of numpy.
# For the moment which just check that non-NaN values are almost equal.
assert_allclose(Xt_inv[~np.isnan(Xt_inv)], X_test[~np.isnan(X_test)])

for i in range(X.shape[1]):
# train only on non-NaN
est.fit(X_train[:, [i]][~np.isnan(X_train[:, i])])
# check transforming with NaN works even when training without NaN
Xt_col = est.transform(X_test[:, [i]])
assert_array_equal(Xt_col, Xt[:, [i]])
# check non-NaN is handled as before - the 1st column is all nan
if not np.isnan(X_test[:, i]).all():
Xt_col_nonan = est.transform(
X_test[:, [i]][~np.isnan(X_test[:, i])])
assert_array_equal(Xt_col_nonan,
Xt_col[~np.isnan(Xt_col.squeeze())])
14 changes: 14 additions & 0 deletions sklearn/preprocessing/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,20 @@ def test_quantile_transform_and_inverse():
assert_array_almost_equal(X, X_trans_inv)


def test_quantile_transform_nan():
X = np.array([[np.nan, 0, 0, 1],
[np.nan, np.nan, 0, 0.5],
[np.nan, 1, 1, 0]])

transformer = QuantileTransformer(n_quantiles=10, random_state=42)
transformer.fit_transform(X)

# check that the quantile of the first column is all NaN
assert np.isnan(transformer.quantiles_[:, 0]).all()
# all other column should not contain NaN
assert not np.isnan(transformer.quantiles_[:, 1:]).any()


def test_robust_scaler_invalid_range():
for range_ in [
(-1, 90),
Expand Down
4 changes: 3 additions & 1 deletion sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@
'OrthogonalMatchingPursuit', 'PLSCanonical', 'PLSRegression',
'RANSACRegressor', 'RadiusNeighborsRegressor',
'RandomForestRegressor', 'Ridge', 'RidgeCV']
ALLOW_NAN = ['Imputer', 'SimpleImputer', 'MICEImputer']

ALLOW_NAN = ['QuantileTransformer', 'Imputer', 'SimpleImputer', 'MICEImputer']


def _yield_non_meta_checks(name, estimator):
yield check_estimators_dtypes
Expand Down
39 changes: 39 additions & 0 deletions sklearn/utils/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,42 @@ def __getstate__(self):
self._fill_value)
else:
from numpy.ma import MaskedArray # noqa


if np_version < (1, 11):
def nanpercentile(a, q):
"""
Compute the qth percentile of the data along the specified axis,
while ignoring nan values.

Returns the qth percentile(s) of the array elements.

Parameters
----------
a : array_like
Input array or object that can be converted to an array.
q : float in range of [0,100] (or sequence of floats)
Percentile to compute, which must be between 0 and 100
inclusive.

Returns
-------
percentile : scalar or ndarray
If `q` is a single percentile and `axis=None`, then the result
is a scalar. If multiple percentiles are given, first axis of
the result corresponds to the percentiles. The other axes are
the axes that remain after the reduction of `a`. If the input
contains integers or floats smaller than ``float64``, the output
data-type is ``float64``. Otherwise, the output data-type is the
same as that of the input. If `out` is specified, that array is
returned instead.

"""
data = np.compress(~np.isnan(a), a)
if data.size:
return np.percentile(data, q)
else:
size_q = 1 if np.isscalar(q) else len(q)
return np.array([np.nan] * size_q)
else:
from numpy import nanpercentile # noqa
16 changes: 16 additions & 0 deletions sklearn/utils/tests/test_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,16 @@

import pickle

import numpy as np
import pytest

from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_allclose

from sklearn.utils.fixes import divide
from sklearn.utils.fixes import MaskedArray
from sklearn.utils.fixes import nanpercentile


def test_divide():
Expand All @@ -24,3 +29,14 @@ def test_masked_array_obj_dtype_pickleable():
marr_pickled = pickle.loads(pickle.dumps(marr))
assert_array_equal(marr.data, marr_pickled.data)
assert_array_equal(marr.mask, marr_pickled.mask)


@pytest.mark.parametrize(
"a, q, expected_percentile",
[(np.array([1, 2, 3, np.nan]), [0, 50, 100], np.array([1., 2., 3.])),
(np.array([1, 2, 3, np.nan]), 50, 2.),
(np.array([np.nan, np.nan]), [0, 50], np.array([np.nan, np.nan]))]
)
def test_nanpercentile(a, q, expected_percentile):
percentile = nanpercentile(a, q)
assert_allclose(percentile, expected_percentile)