Skip to content

Commit 81e3e68

Browse files
committed
RFC: Addresses PR comments
1 parent 7bc3be0 commit 81e3e68

File tree

2 files changed

+20
-17
lines changed

2 files changed

+20
-17
lines changed

sklearn/compose/_column_transformer.py

+15-14
Original file line numberDiff line numberDiff line change
@@ -376,45 +376,46 @@ def _fit_transform(self, X, y, func, fitted=False):
376376
else:
377377
raise
378378

379-
def _calculate_inverse_indices(self, X):
379+
def _calculate_inverse_indices(self, X, Xs):
380380
"""
381381
Private function to calcuate indicies for inverse_transform
382382
"""
383383
# checks for overlap
384-
all_indexes = set()
384+
all_indices = set()
385385
input_indices = []
386386
for name, trans, cols in self.transformers:
387387
col_indices = _get_column_indices(X, cols)
388388
col_indices_set = set(col_indices)
389-
if not all_indexes.isdisjoint(col_indices_set):
390-
self._invert_error = ("Unable to invert: transformers "
391-
"contain overlaping columns")
389+
if not all_indices.isdisjoint(col_indices_set):
390+
self._invert_error = (
391+
"transformers contain overlapping columns")
392392
return
393393
if trans == 'drop':
394-
self._invert_error = "'{}' drops columns".format(name)
394+
self._invert_error = ("dropping columns is not supported. "
395+
"'{}' drops columns".format(name))
395396
return
396397
input_indices.append(col_indices)
397-
all_indexes.update(col_indices_set)
398+
all_indices.update(col_indices_set)
398399

399400
# check remainder
400401
remainder_indices = self._remainder.indices
401402
if (remainder_indices is not None
402403
and self._remainder.transformer == 'drop'):
403-
self._invert_error = "remainder drops columns"
404+
self._invert_error = ("dropping columns is not supported. "
405+
"remainder drops columns")
404406
return
405407

406408
if remainder_indices is not None:
407409
input_indices.append(remainder_indices)
408410

409411
self._input_indices = input_indices
410412
self._n_features_in = X.shape[1]
411-
self._X_columns = X.columns if hasattr(X, 'columns') else None
413+
self._X_columns = getattr(X, 'columns', None)
412414
self._X_is_sparse = sparse.issparse(X)
413415
self._invert_error = ""
414416
self._output_indices = []
415417
cur_index = 0
416418

417-
Xs = self._fit_transform(X[0:1], None, _transform_one, fitted=True)
418419
for X_transform in Xs:
419420
X_features = X_transform.shape[-1]
420421
self._output_indices.append(
@@ -492,7 +493,7 @@ def fit_transform(self, X, y=None):
492493
self.sparse_output_ = False
493494

494495
self._update_fitted_transformers(transformers)
495-
self._calculate_inverse_indices(X)
496+
self._calculate_inverse_indices(X, Xs)
496497
self._validate_output(Xs)
497498

498499
return _hstack(list(Xs), self.sparse_output_)
@@ -555,12 +556,12 @@ def inverse_transform(self, X):
555556
trans = FunctionTransformer(
556557
validate=False, accept_sparse=True, check_inverse=False)
557558

558-
inv_transformers.append((name, trans, sub, get_weight(name)))
559+
inv_transformers.append((trans, sub, get_weight(name)))
559560

560561
Xs = Parallel(n_jobs=self.n_jobs)(
561562
delayed(_inverse_transform_one)(
562563
trans, X_sel, weight)
563-
for _, trans, X_sel, weight in inv_transformers)
564+
for trans, X_sel, weight in inv_transformers)
564565

565566
if not Xs:
566567
# All transformers are None
@@ -570,7 +571,7 @@ def inverse_transform(self, X):
570571
inverse_Xs = sparse.lil_matrix((Xs[0].shape[0],
571572
self._n_features_in))
572573
else:
573-
inverse_Xs = np.zeros((Xs[0].shape[0], self._n_features_in))
574+
inverse_Xs = np.empty((Xs[0].shape[0], self._n_features_in))
574575
for indices, inverse_X in zip(self._input_indices, Xs):
575576
if sparse.issparse(inverse_X):
576577
if self._X_is_sparse:

sklearn/compose/tests/test_column_transformer.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ def test_column_transformer_inverse_transform_with_drop():
159159
('trans2', 'drop', [1])])
160160
res = ct.fit_transform(X_array)
161161

162-
error_msg = "Unable to invert: 'trans2' drops columns"
162+
error_msg = ("Unable to invert: dropping columns is "
163+
"not supported. 'trans2' drops columns")
163164
with pytest.raises(ValueError, match=error_msg):
164165
ct.inverse_transform(res)
165166

@@ -170,7 +171,7 @@ def test_column_transformer_inverse_transform_with_overlaping_slices():
170171
('trans2', Trans(), [0])])
171172
res = ct.fit_transform(X_array)
172173

173-
error_msg = "Unable to invert: transformers contain overlaping columns"
174+
error_msg = "Unable to invert: transformers contain overlapping columns"
174175
with pytest.raises(ValueError, match=error_msg):
175176
ct.inverse_transform(res)
176177

@@ -180,7 +181,8 @@ def test_column_transformer_inverse_transform_with_remainder_drops():
180181
ct = ColumnTransformer([('trans1', Trans(), [0])])
181182
res = ct.fit_transform(X_array)
182183

183-
error_msg = "Unable to invert: remainder drops columns"
184+
error_msg = ("Unable to invert: dropping columns is not supported. "
185+
"remainder drops columns")
184186
with pytest.raises(ValueError, match=error_msg):
185187
ct.inverse_transform(res)
186188

0 commit comments

Comments
 (0)