From fe12132efb5fd0781b590b9e0ee83a2540154a70 Mon Sep 17 00:00:00 2001 From: Xuefeng Xu Date: Tue, 21 May 2024 13:54:56 +0800 Subject: [PATCH 1/2] ENH avoid checking columns where training data is all nan in KNNImputer --- sklearn/impute/_knn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/impute/_knn.py b/sklearn/impute/_knn.py index 64f55693356d6..f690cadacbf8b 100644 --- a/sklearn/impute/_knn.py +++ b/sklearn/impute/_knn.py @@ -279,7 +279,7 @@ def transform(self, X): X_indicator = super()._transform_indicator(mask) # Removes columns where the training data is all nan - if not np.any(mask): + if not np.any(mask[:, valid_mask]): # No missing values in X if self.keep_empty_features: Xc = X From 4f6bed6a3e814d8b9dffc2679a2cf1f00b9692cb Mon Sep 17 00:00:00 2001 From: Xuefeng Xu Date: Sun, 2 Jun 2024 16:06:44 +0800 Subject: [PATCH 2/2] row_missing_idx also remove invalid column mask --- sklearn/impute/_knn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/impute/_knn.py b/sklearn/impute/_knn.py index f690cadacbf8b..9f014ef27cec1 100644 --- a/sklearn/impute/_knn.py +++ b/sklearn/impute/_knn.py @@ -293,7 +293,7 @@ def transform(self, X): # of columns, regardless of whether missing values exist in X or not. return super()._concatenate_indicator(Xc, X_indicator) - row_missing_idx = np.flatnonzero(mask.any(axis=1)) + row_missing_idx = np.flatnonzero(mask[:, valid_mask].any(axis=1)) non_missing_fix_X = np.logical_not(mask_fit_X)