27
27
28
28
def _get_mask(X, value_to_mask):
    """Compute the boolean mask X == missing_values.

    Parameters
    ----------
    X : array-like
        Values to test; only ``np.isnan(X)`` and ``X == value_to_mask``
        are applied to it.
    value_to_mask : scalar
        The placeholder that marks a missing entry. Any floating-point
        NaN (``np.nan``, ``float("nan")``, ``np.float64("nan")``, ...)
        selects the NaN entries of ``X``.

    Returns
    -------
    mask : boolean array
        True where ``X`` holds the missing-value placeholder.
    """
    # NaN never compares equal to itself, so ``X == nan`` is all False, and
    # an identity test (``is np.nan``) only recognises the np.nan singleton.
    # Detect NaN by value instead; the isinstance guard keeps np.isnan from
    # being called on non-float placeholders such as strings.
    if (isinstance(value_to_mask, (float, np.floating)) and
            np.isnan(value_to_mask)):
        return np.isnan(X)
    return X == value_to_mask
@@ -148,22 +148,29 @@ def fit(self, X, y=None):
148
148
raise ValueError ("Can only impute missing values on axis 0 and 1, "
149
149
" got axis={0}" .format (self .axis ))
150
150
151
+ # Validate missing_values and convert from "NaN" to np.nan
152
+ if (isinstance (self .missing_values , six .string_types ) and
153
+ self .missing_values == "NaN" ):
154
+ missing_values = np .nan
155
+ else :
156
+ missing_values = self .missing_values
157
+
151
158
# Since two different arrays can be provided in fit(X) and
152
159
# transform(X), the imputation data will be computed in transform()
153
160
# when the imputation is done per sample (i.e., when axis=1).
154
161
if self .axis == 0 :
155
162
X = check_array (X , accept_sparse = 'csc' , dtype = np .float64 ,
156
- force_all_finite = False )
163
+ allow_nan = True , force_all_finite = True )
157
164
158
165
if sparse .issparse (X ):
159
166
self .statistics_ = self ._sparse_fit (X ,
160
167
self .strategy ,
161
- self . missing_values ,
168
+ missing_values ,
162
169
self .axis )
163
170
else :
164
171
self .statistics_ = self ._dense_fit (X ,
165
172
self .strategy ,
166
- self . missing_values ,
173
+ missing_values ,
167
174
self .axis )
168
175
169
176
return self
@@ -250,7 +257,7 @@ def _sparse_fit(self, X, strategy, missing_values, axis):
250
257
251
258
def _dense_fit (self , X , strategy , missing_values , axis ):
252
259
"""Fit the transformer on dense data."""
253
- X = check_array (X , force_all_finite = False )
260
+ X = check_array (X , allow_nan = True , force_all_finite = True )
254
261
mask = _get_mask (X , missing_values )
255
262
masked_X = ma .masked_array (X , mask = mask )
256
263
@@ -307,10 +314,18 @@ def transform(self, X):
307
314
X : {array-like, sparse matrix}, shape = [n_samples, n_features]
308
315
The input data to complete.
309
316
"""
317
+ # Validate missing_values and convert from "NaN" to np.nan
318
+ if (isinstance (self .missing_values , six .string_types ) and
319
+ self .missing_values == "NaN" ):
320
+ missing_values = np .nan
321
+ else :
322
+ missing_values = self .missing_values
323
+
310
324
if self .axis == 0 :
311
325
check_is_fitted (self , 'statistics_' )
312
326
X = check_array (X , accept_sparse = 'csc' , dtype = FLOAT_DTYPES ,
313
- force_all_finite = False , copy = self .copy )
327
+ allow_nan = True , force_all_finite = True ,
328
+ copy = self .copy )
314
329
statistics = self .statistics_
315
330
if X .shape [1 ] != statistics .shape [0 ]:
316
331
raise ValueError ("X has %d features per sample, expected %d"
@@ -321,18 +336,19 @@ def transform(self, X):
321
336
# when the imputation is done per sample
322
337
else :
323
338
X = check_array (X , accept_sparse = 'csr' , dtype = FLOAT_DTYPES ,
324
- force_all_finite = False , copy = self .copy )
339
+ allow_nan = True , force_all_finite = True ,
340
+ copy = self .copy )
325
341
326
342
if sparse .issparse (X ):
327
343
statistics = self ._sparse_fit (X ,
328
344
self .strategy ,
329
- self . missing_values ,
345
+ missing_values ,
330
346
self .axis )
331
347
332
348
else :
333
349
statistics = self ._dense_fit (X ,
334
350
self .strategy ,
335
- self . missing_values ,
351
+ missing_values ,
336
352
self .axis )
337
353
338
354
# Delete the invalid rows/columns
@@ -352,8 +368,8 @@ def transform(self, X):
352
368
"missing values: %s" % missing )
353
369
354
370
# Do actual imputation
355
- if sparse .issparse (X ) and self . missing_values != 0 :
356
- mask = _get_mask (X .data , self . missing_values )
371
+ if sparse .issparse (X ) and missing_values != 0 :
372
+ mask = _get_mask (X .data , missing_values )
357
373
indexes = np .repeat (np .arange (len (X .indptr ) - 1 , dtype = np .int ),
358
374
np .diff (X .indptr ))[mask ]
359
375
@@ -363,7 +379,7 @@ def transform(self, X):
363
379
if sparse .issparse (X ):
364
380
X = X .toarray ()
365
381
366
- mask = _get_mask (X , self . missing_values )
382
+ mask = _get_mask (X , missing_values )
367
383
n_missing = np .sum (mask , axis = self .axis )
368
384
values = np .repeat (valid_statistics , n_missing )
369
385
0 commit comments