Skip to content

Commit ecc5b47

Browse files
committed
Validate missing_values and convert "NaN" to np.nan; check for inf
1 parent d46694c commit ecc5b47

File tree

2 files changed

+38
-12
lines changed

2 files changed

+38
-12
lines changed

sklearn/preprocessing/imputation.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
def _get_mask(X, value_to_mask):
2929
"""Compute the boolean mask X == missing_values."""
30-
if value_to_mask == "NaN" or np.isnan(value_to_mask):
30+
if value_to_mask is np.nan:
3131
return np.isnan(X)
3232
else:
3333
return X == value_to_mask
@@ -148,22 +148,29 @@ def fit(self, X, y=None):
148148
raise ValueError("Can only impute missing values on axis 0 and 1, "
149149
" got axis={0}".format(self.axis))
150150

151+
# Validate missing_values and convert from "NaN" to np.nan
152+
if (isinstance(self.missing_values, six.string_types) and
153+
self.missing_values == "NaN"):
154+
missing_values = np.nan
155+
else:
156+
missing_values = self.missing_values
157+
151158
# Since two different arrays can be provided in fit(X) and
152159
# transform(X), the imputation data will be computed in transform()
153160
# when the imputation is done per sample (i.e., when axis=1).
154161
if self.axis == 0:
155162
X = check_array(X, accept_sparse='csc', dtype=np.float64,
156-
force_all_finite=False)
163+
allow_nan=True, force_all_finite=True)
157164

158165
if sparse.issparse(X):
159166
self.statistics_ = self._sparse_fit(X,
160167
self.strategy,
161-
self.missing_values,
168+
missing_values,
162169
self.axis)
163170
else:
164171
self.statistics_ = self._dense_fit(X,
165172
self.strategy,
166-
self.missing_values,
173+
missing_values,
167174
self.axis)
168175

169176
return self
@@ -250,7 +257,7 @@ def _sparse_fit(self, X, strategy, missing_values, axis):
250257

251258
def _dense_fit(self, X, strategy, missing_values, axis):
252259
"""Fit the transformer on dense data."""
253-
X = check_array(X, force_all_finite=False)
260+
X = check_array(X, allow_nan=True, force_all_finite=True)
254261
mask = _get_mask(X, missing_values)
255262
masked_X = ma.masked_array(X, mask=mask)
256263

@@ -307,10 +314,18 @@ def transform(self, X):
307314
X : {array-like, sparse matrix}, shape = [n_samples, n_features]
308315
The input data to complete.
309316
"""
317+
# Validate missing_values and convert from "NaN" to np.nan
318+
if (isinstance(self.missing_values, six.string_types) and
319+
self.missing_values == "NaN"):
320+
missing_values = np.nan
321+
else:
322+
missing_values = self.missing_values
323+
310324
if self.axis == 0:
311325
check_is_fitted(self, 'statistics_')
312326
X = check_array(X, accept_sparse='csc', dtype=FLOAT_DTYPES,
313-
force_all_finite=False, copy=self.copy)
327+
allow_nan=True, force_all_finite=True,
328+
copy=self.copy)
314329
statistics = self.statistics_
315330
if X.shape[1] != statistics.shape[0]:
316331
raise ValueError("X has %d features per sample, expected %d"
@@ -321,18 +336,19 @@ def transform(self, X):
321336
# when the imputation is done per sample
322337
else:
323338
X = check_array(X, accept_sparse='csr', dtype=FLOAT_DTYPES,
324-
force_all_finite=False, copy=self.copy)
339+
allow_nan=True, force_all_finite=True,
340+
copy=self.copy)
325341

326342
if sparse.issparse(X):
327343
statistics = self._sparse_fit(X,
328344
self.strategy,
329-
self.missing_values,
345+
missing_values,
330346
self.axis)
331347

332348
else:
333349
statistics = self._dense_fit(X,
334350
self.strategy,
335-
self.missing_values,
351+
missing_values,
336352
self.axis)
337353

338354
# Delete the invalid rows/columns
@@ -352,8 +368,8 @@ def transform(self, X):
352368
"missing values: %s" % missing)
353369

354370
# Do actual imputation
355-
if sparse.issparse(X) and self.missing_values != 0:
356-
mask = _get_mask(X.data, self.missing_values)
371+
if sparse.issparse(X) and missing_values != 0:
372+
mask = _get_mask(X.data, missing_values)
357373
indexes = np.repeat(np.arange(len(X.indptr) - 1, dtype=np.int),
358374
np.diff(X.indptr))[mask]
359375

@@ -363,7 +379,7 @@ def transform(self, X):
363379
if sparse.issparse(X):
364380
X = X.toarray()
365381

366-
mask = _get_mask(X, self.missing_values)
382+
mask = _get_mask(X, missing_values)
367383
n_missing = np.sum(mask, axis=self.axis)
368384
values = np.repeat(valid_statistics, n_missing)
369385

sklearn/preprocessing/tests/test_imputation.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sklearn.utils.testing import assert_equal
66
from sklearn.utils.testing import assert_array_equal
77
from sklearn.utils.testing import assert_raises
8+
from sklearn.utils.testing import assert_raise_message
89
from sklearn.utils.testing import assert_false
910

1011
from sklearn.preprocessing.imputation import Imputer
@@ -357,3 +358,12 @@ def test_imputation_copy():
357358

358359
# Note: If X is sparse and if missing_values=0, then a (dense) copy of X is
359360
# made, even if copy=False.
361+
362+
# Raise a proper error message if input contains infinity
363+
X = [[np.inf, 8, 9, np.nan], [np.nan, 10, 10, 0], [10, 11, 9, 11]]
364+
assert_raise_message(ValueError, "Input contains infinity",
365+
Imputer(axis=0, missing_values="NaN").fit_transform,
366+
X)
367+
assert_raise_message(ValueError, "Input contains infinity",
368+
Imputer(axis=1, missing_values=np.nan).fit_transform,
369+
X)

0 commit comments

Comments
 (0)