[MRG] Add KNN strategy for imputation #4844

Closed
wants to merge 17 commits
19 changes: 12 additions & 7 deletions doc/modules/preprocessing.rst
@@ -388,9 +388,10 @@ values. However, this comes at the price of losing data which may be valuable
i.e., to infer them from the known part of the data.

The :class:`Imputer` class provides basic strategies for imputing missing
values, either using the mean, the median or the most frequent value of
the row or column in which the missing values are located. This class
also allows for different missing values encodings.
values. It can use the mean, the median or the most frequent value of
the row or column in which the missing values are located. Alternatively,
it can fill each missing entry with the mean of its k-nearest neighbors,
computed using only the samples without missing values. The placeholder
for missing values is configurable.

The following snippet demonstrates how to replace missing values,
encoded as ``np.nan``, using the mean value of the columns (axis 0)
@@ -399,8 +400,8 @@ that contain the missing values::
>>> import numpy as np
>>> from sklearn.preprocessing import Imputer
>>> imp = Imputer(missing_values='NaN', strategy='mean', axis=0)
>>> imp.fit([[1, 2], [np.nan, 3], [7, 6]])
Imputer(axis=0, copy=True, missing_values='NaN', strategy='mean', verbose=0)
>>> imp.fit([[1, 2], [np.nan, 3], [7, 6]]) # doctest: +NORMALIZE_WHITESPACE
Imputer(axis=0, copy=True, missing_values='NaN', n_neighbors=1, strategy='mean', verbose=0)
>>> X = [[np.nan, 2], [6, np.nan], [7, 6]]
>>> print(imp.transform(X)) # doctest: +ELLIPSIS
[[ 4. 2. ]
@@ -412,8 +413,8 @@ The :class:`Imputer` class also supports sparse matrices::
>>> import scipy.sparse as sp
>>> X = sp.csc_matrix([[1, 2], [0, 3], [7, 6]])
>>> imp = Imputer(missing_values=0, strategy='mean', axis=0)
>>> imp.fit(X)
Imputer(axis=0, copy=True, missing_values=0, strategy='mean', verbose=0)
>>> imp.fit(X) # doctest: +NORMALIZE_WHITESPACE
Imputer(axis=0, copy=True, missing_values=0, n_neighbors=1, strategy='mean', verbose=0)
>>> X_test = sp.csc_matrix([[0, 2], [6, 0], [7, 6]])
>>> print(imp.transform(X_test)) # doctest: +ELLIPSIS
[[ 4. 2. ]
@@ -424,5 +425,9 @@ Note that, here, missing values are encoded by 0 and are thus implicitly stored
in the matrix. This format is thus suitable when there are many more missing
values than observed values.

When using ``strategy=knn``, only samples without any missing features will
be used for imputation. If all samples have missing features, this strategy
will fail.
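
For instance, knn imputation on a small dense matrix could look like the
following sketch (it assumes this branch's ``n_neighbors`` parameter and
NaN-encoded missing values; with ``n_neighbors=2``, each missing entry is
filled with the mean of the two nearest complete samples)::

    >>> imp = Imputer(missing_values='NaN', strategy='knn', n_neighbors=2, axis=0)
    >>> imp.fit([[1, 2], [3, 4], [5, 6]])  # doctest: +NORMALIZE_WHITESPACE
    Imputer(axis=0, copy=True, missing_values='NaN', n_neighbors=2,
            strategy='knn', verbose=0)
    >>> print(imp.transform([[np.nan, 2], [5, np.nan]]))
    [[ 2.  2.]
     [ 5.  5.]]

Here the first row's missing entry becomes the mean of 1 and 3, taken from
its two nearest complete samples, and the second row's becomes the mean of
6 and 4.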

:class:`Imputer` can be used in a Pipeline as a way to build a composite
estimator that supports imputation. See :ref:`example_missing_values.py`
60 changes: 38 additions & 22 deletions examples/missing_values.py
@@ -8,65 +8,81 @@
Imputing does not always improve the predictions, so please check via cross-validation.
Sometimes dropping rows or using marker values is more effective.

Missing values can be replaced by the mean, the median or the most frequent
value using the ``strategy`` hyper-parameter.
Missing values can be replaced by the mean, the median, the most frequent
value, or the mean of the values of the k-nearest neighbors, using the
``strategy`` hyper-parameter.
The median is a more robust estimator for data with high magnitude variables
which could dominate results (otherwise known as a 'long tail').

Script output::

Score with the entire dataset = 0.56
Score without the samples containing missing values = 0.48
Score after imputation of the missing values = 0.55
Score with the entire dataset = 0.43
Score without the samples containing missing values = 0.35
Score after mean imputation of the missing values = 0.42
Score after knn imputation with 7 neighbors of the missing values = 0.43

In this case, imputing helps the regressor get close to the original score.

"""
import numpy as np

from sklearn.datasets import load_boston
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import Imputer
from sklearn.cross_validation import cross_val_score

rng = np.random.RandomState(0)

dataset = load_boston()
dataset = load_diabetes()
X_full, y_full = dataset.data, dataset.target
n_samples = X_full.shape[0]
n_features = X_full.shape[1]

# Create a random matrix used to decide which entries become missing
missing_matrix = rng.rand(n_samples, n_features)

# Each sample has probability (1 - th) ** n_features of having no missing features
th = 0.14
mask = missing_matrix < th
missing_samples = mask.any(axis=1)
full_percentage = (n_samples - missing_samples.sum()) / float(n_samples)
print("Percentage of samples with full features: %f" % full_percentage)

# Estimate the score on the entire dataset, with no missing values

estimator = RandomForestRegressor(random_state=0, n_estimators=100)
score = cross_val_score(estimator, X_full, y_full).mean()
print("Score with the entire dataset = %.2f" % score)

# Add missing values in 75% of the lines
missing_rate = 0.75
n_missing_samples = np.floor(n_samples * missing_rate)
missing_samples = np.hstack((np.zeros(n_samples - n_missing_samples,
dtype=np.bool),
np.ones(n_missing_samples,
dtype=np.bool)))
rng.shuffle(missing_samples)
missing_features = rng.randint(0, n_features, n_missing_samples)

# Estimate the score without the lines containing missing values

X_filtered = X_full[~missing_samples, :]
y_filtered = y_full[~missing_samples]
estimator = RandomForestRegressor(random_state=0, n_estimators=100)
score = cross_val_score(estimator, X_filtered, y_filtered).mean()
print("Score without the samples containing missing values = %.2f" % score)

# Estimate the score after imputation of the missing values
# Estimate the score after mean imputation of the missing values

X_missing = X_full.copy()
X_missing[np.where(missing_samples)[0], missing_features] = 0
X_missing[mask] = np.nan

Member: any particular reason for this change?

Contributor Author: There is no missing_features anymore; every entry now has a probability of being missing, which can be set via th. mask has the same shape as the dataset, with True at the entries that will be missing, so I think X_missing[mask] is easier to use.

Member: Right. Sorry, I meant from zero to np.nan.

Contributor Author: Oh I see. The original dataset may contain entries equal to 0, and if I labeled missing entries as 0, those real entries would also be imputed.

Member: Makes sense.

y_missing = y_full.copy()
estimator = Pipeline([("imputer", Imputer(missing_values=0,
strategy="mean",

estimator = Pipeline([("imputer", Imputer(strategy="mean",
axis=0)),
("forest", RandomForestRegressor(random_state=0,
n_estimators=100))])
score = cross_val_score(estimator, X_missing, y_missing).mean()
print("Score after imputation of the missing values = %.2f" % score)
print("Score after mean imputation of the missing values = %.2f" % score)

# Estimate the score after knn imputation of the missing values

neigh = 7 # Number of neighbors to be used
estimator2 = Pipeline([("imputer", Imputer(strategy="knn",
axis=0, n_neighbors=neigh)),
("forest", RandomForestRegressor(random_state=0,
n_estimators=100))])
score = cross_val_score(estimator2, X_missing, y_missing).mean()
print("Score after knn imputation with %d neighbors of the missing values ="
" %.2f" % (neigh, score))
107 changes: 93 additions & 14 deletions sklearn/preprocessing/imputation.py
@@ -2,7 +2,6 @@
# License: BSD 3 clause

import warnings

import numpy as np
import numpy.ma as ma
from scipy import sparse
@@ -14,7 +13,7 @@
from ..utils.fixes import astype
from ..utils.sparsefuncs import _get_median
from ..utils.validation import check_is_fitted

from ..utils import gen_batches
from ..externals import six

zip = six.moves.zip
@@ -61,6 +60,7 @@ def _most_frequent(array, extra_value, n_repeat):
return extra_value



class Imputer(BaseEstimator, TransformerMixin):
"""Imputation transformer for completing missing values.

@@ -82,12 +82,15 @@ class Imputer(BaseEstimator, TransformerMixin):
the axis.
- If "most_frequent", then replace missing using the most frequent
value along the axis.
- If "knn", then replace missing using the mean of the k-nearest
neighbors along the axis. Only samples with no missing values are
considered as neighbors.

axis : integer, optional (default=0)
The axis along which to impute.

- If `axis=0`, then impute along columns.
- If `axis=1`, then impute along rows.
- If ``axis=0``, then impute along columns.
- If ``axis=1``, then impute along rows.

verbose : integer, optional (default=0)
Controls the verbosity of the imputer.
@@ -99,13 +102,18 @@

- If X is not an array of floating values;
- If X is sparse and `missing_values=0`;
- If `axis=0` and X is encoded as a CSR matrix;
- If `axis=1` and X is encoded as a CSC matrix.
- If ``axis=0`` and X is encoded as a CSR matrix;
- If ``axis=1`` and X is encoded as a CSC matrix.

n_neighbors : int, optional (default=1)
Controls the number of nearest neighbors used to compute the mean
along the axis. Only used when ``strategy=knn``.

Attributes
----------
statistics_ : array of shape (n_features,)
The imputation fill value for each feature if axis == 0.
If ``strategy=knn``, it instead contains the samples that have no missing values.

Notes
-----
@@ -114,14 +122,16 @@
- When ``axis=1``, an exception is raised if there are rows for which it is
not possible to fill in the missing values (e.g., because they only
contain missing values).
- The knn strategy currently does not support sparse matrices.
"""
def __init__(self, missing_values="NaN", strategy="mean",
axis=0, verbose=0, copy=True):
axis=0, verbose=0, copy=True, n_neighbors=1):
self.missing_values = missing_values
self.strategy = strategy
self.axis = axis
self.verbose = verbose
self.copy = copy
self.n_neighbors = n_neighbors

def fit(self, X, y=None):
"""Fit the imputer on X.
@@ -138,7 +148,7 @@ def fit(self, X, y=None):
Returns self.
"""
# Check parameters
allowed_strategies = ["mean", "median", "most_frequent"]
allowed_strategies = ["mean", "median", "most_frequent", "knn"]
if self.strategy not in allowed_strategies:
raise ValueError("Can only use these strategies: {0} "
" got strategy={1}".format(allowed_strategies,
@@ -248,6 +258,12 @@ def _sparse_fit(self, X, strategy, missing_values, axis):

return most_frequent

# KNN
elif strategy == "knn":
raise ValueError("strategy='knn' does not support sparse "
"matrix input")


def _dense_fit(self, X, strategy, missing_values, axis):
"""Fit the transformer on dense data."""
X = check_array(X, force_all_finite=False)
@@ -299,6 +315,27 @@ def _dense_fit(self, X, strategy, missing_values, axis):

return most_frequent

# KNN
elif strategy == "knn":

if axis == 1:
X = X.transpose()
mask = mask.transpose()

# Get samples with complete features
full_data = X[np.logical_not(mask.any(axis=1))]
if full_data.size == 0:
raise ValueError("There is no sample with complete data.")
if full_data.shape[0] < self.n_neighbors:
raise ValueError("There are only %d complete samples, "
"but n_neighbors=%d."
% (full_data.shape[0], self.n_neighbors))
# Transpose back
if axis == 1:
full_data = full_data.transpose()

return full_data
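
# A small illustration of the branch above (a sketch, assuming NaN-encoded
# missing values): fitting [[1, 2], [np.nan, 3], [7, 6]] with
# strategy="knn" keeps only the two complete rows, so statistics_ becomes
# array([[1., 2.], [7., 6.]]).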

def transform(self, X):
"""Impute all missing values in X.

@@ -341,7 +378,9 @@ def transform(self, X):
valid_mask = np.logical_not(invalid_mask)
valid_statistics = statistics[valid_mask]
valid_statistics_indexes = np.where(valid_mask)[0]
missing = np.arange(X.shape[not self.axis])[invalid_mask]

if self.strategy != "knn":
missing = np.arange(X.shape[not self.axis])[invalid_mask]

if self.axis == 0 and invalid_mask.any():
if self.verbose:
@@ -366,13 +405,53 @@

mask = _get_mask(X, self.missing_values)
n_missing = np.sum(mask, axis=self.axis)
values = np.repeat(valid_statistics, n_missing)

if self.axis == 0:
coordinates = np.where(mask.transpose())[::-1]
if self.strategy == 'knn':
if self.axis == 1:
X = X.transpose()
mask = mask.transpose()
statistics = statistics.transpose()

batch_size = 1 # set batch size for block query

Member: You can probably simplify the code if we stay with batch-size=1, right?

missing_index = np.where(mask.any(axis=1))[0]
# Preallocate the output buffer for np.multiply(impute_dist, impute_dist,
# out=D2) inside the loop below.
D2 = np.empty([batch_size, statistics.shape[0], statistics.shape[1]])

for sl in gen_batches(len(missing_index), batch_size):
X_sl = X[missing_index[sl]]
mask_sl = mask[missing_index[sl]]
X_sl[mask_sl] = np.nan
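# Broadcasting (batch, 1, n_features) against (n_complete, n_features)
# gives pairwise differences of shape (batch, n_complete, n_features).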
impute_dist = X_sl[:, np.newaxis, :] - statistics

# For the last slice, the length may not be the same
# as batch_size
if impute_dist.shape != D2.shape:

Member: If we have batch-size=1 this doesn't happen, right?

Contributor Author: Yes.

D2 = np.empty_like(impute_dist)

np.multiply(impute_dist, impute_dist, out=D2)
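# Differences involving a missing feature are NaN; zeroing them makes the
# squared distance accumulate over the observed features only.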
D2[np.isnan(D2)] = 0
missing_row, missing_col = np.where(np.isnan(X_sl))
sqdist = D2.sum(axis=2)
target_index = np.argsort(sqdist, axis=1)[:, :self.n_neighbors]
means = np.mean(statistics[target_index], axis=1)
X_sl[missing_row, missing_col] = means[missing_row, missing_col]
X[missing_index[sl]] = X_sl

if self.axis == 1:
X = X.transpose()

else:
coordinates = mask
values = np.repeat(valid_statistics, n_missing)

X[coordinates] = values
if self.axis == 0:
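# np.where on the transposed mask enumerates the missing entries column
# by column, matching the per-column ordering of values; the [::-1]
# swaps the resulting (column, row) index arrays back to (row, column)
# order for indexing into X.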
coordinates = np.where(mask.transpose())[::-1]

Member: Why is there a [::-1]?

else:
coordinates = mask

X[coordinates] = values

return X
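
As the review thread above suggests, keeping ``batch_size=1`` lets the block
query collapse into a per-sample loop. A minimal sketch of that variant
(``knn_impute_row`` is a hypothetical helper, not part of this branch; it
assumes NaN-encoded missing values and that ``complete_samples`` holds the
rows stored by ``fit``):

    import numpy as np

    def knn_impute_row(row, complete_samples, n_neighbors):
        """Fill the NaN entries of ``row`` with the per-feature mean of
        its ``n_neighbors`` nearest complete samples."""
        missing = np.isnan(row)
        # Squared distances are computed over the observed features only.
        diff = row[~missing] - complete_samples[:, ~missing]
        nearest = np.argsort((diff ** 2).sum(axis=1))[:n_neighbors]
        imputed = row.copy()
        imputed[missing] = complete_samples[nearest][:, missing].mean(axis=0)
        return imputed

For example, ``knn_impute_row(np.array([np.nan, 2.]), np.array([[1., 2.],
[3., 4.], [5., 6.]]), 2)`` fills the first entry with (1 + 3) / 2 = 2.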

Loading