Skip to content

[MRG] ENH: Adds inverse_transform to ColumnTransformer #11639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 55 commits into
base: main
Choose a base branch
from

Conversation

thomasjpfan
Copy link
Member

@thomasjpfan thomasjpfan commented Jul 20, 2018

Reference Issues/PRs

Fixes #11463

What does this implement/fix? Explain your changes.

  1. Running inverse_transform with overlap or drop will raise a ValueError
  2. _calculate_inverse_indices is used to connect indices from the output space back to the input space.

Copy link
Member

@jnothman jnothman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not yet reviewed.

for name, trans, cols in self.transformers:
col_indices = _get_column_indices(X, cols)
if not all_indexes.isdisjoint(set(col_indices)):
self._invertible = (False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't appear to be covered by tests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_column_transformer_inverse_transform_with_overlaping_slices should cover this, I added the ValueError message in the assertion.


if not Xs:
# All transformers are None
return np.zeros((X.shape[0], 0))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not covered by tests

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added test_column_transformer_inverse_transform_all_transformers_drop to cover this.

inverse_Xs[:, indices] = inverse_X

if self._X_is_sparse:
return sparse.csr_matrix(inverse_Xs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not covered by tests

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added an assert to test_column_transformer_sparse_array to cover this.

Returns
-------
Xt : array-like, shape = [n_samples, n_features]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that it is a pandas DataFrame when ..

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not resolved

Copy link
Member

@jnothman jnothman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not yet a full review

try:
import pandas as pd
return pd.DataFrame(inverse_Xs, columns=self._X_columns)
except ImportError:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes no sense. Either we promise pandas or we don't. Changing the return type on the basis of not having a dependency can't happen.

input_indices = []
for name, trans, cols in self.transformers:
col_indices = _get_column_indices(X, cols)
if not all_indexes.isdisjoint(set(col_indices)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm probably missing something. I can't see where you update all_indexes to be non-empty. If I'm right, add tests

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! It turns out pytest.raises(..., match=...) was need to properly test the Exception.

input_indices.append(col_indices)

# check remainder
remainder_indices = self._remainder[2]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should make _remainder a namedtuple so that this is more legible

input_indices.append(remainder_indices)

self._input_indices = input_indices
self._X_features = X.shape[1]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps name this _n_features_in

@thomasjpfan thomasjpfan changed the title ENH: Adds inverse_transform to ColumnTransformer [MRG] ENH: Adds inverse_transform to ColumnTransformer Jul 27, 2018
Copy link
Member

@jorisvandenbossche jorisvandenbossche left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on it!
Added a few comments (didn't look at the tests yet)

col_indices_set = set(col_indices)
if not all_indexes.isdisjoint(col_indices_set):
self._invert_error = ("Unable to invert: transformers "
"contain overlaping columns")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overlaping -> overlapping

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also think the "Unable to invert" is already included into the message in inverse_transform ?

"contain overlaping columns")
return
if trans == 'drop':
self._invert_error = "'{}' drops columns".format(name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add something explicitly saying that dropping columns is not supported

Private function to calcuate indicies for inverse_transform
"""
# checks for overlap
all_indexes = set()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor nit, but maybe also use 'indices' instead of 'indexes' since all other variables do that?


self._input_indices = input_indices
self._n_features_in = X.shape[1]
self._X_columns = X.columns if hasattr(X, 'columns') else None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe getattr(X, 'columns', None) ?

self._output_indices = []
cur_index = 0

Xs = self._fit_transform(X[0:1], None, _transform_one, fitted=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of transforming here again, we could also save the dimensions of the outputs in self.fit_transform itself?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or actually, you can also pass Xs if we want to keep the code here ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a great idea! Thank you!

trans = FunctionTransformer(
validate=False, accept_sparse=True, check_inverse=False)

inv_transformers.append((name, trans, sub, get_weight(name)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems name is not used below, so not needed to pass it?

inverse_Xs = sparse.lil_matrix((Xs[0].shape[0],
self._n_features_in))
else:
inverse_Xs = np.zeros((Xs[0].shape[0], self._n_features_in))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

zeros -> empty?

else:
inverse_Xs[:, indices] = inverse_X.toarray()
else:
if inverse_X.ndim == 1:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this possible?

Copy link
Member Author

@thomasjpfan thomasjpfan Aug 23, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you referring to inverse_Xs[:, indices] = ... or inverse_x.ndim == 1?

  1. inverse_Xs[:, indices] = ... runs when inverse_X is sparse and X is not sparse. test_column_transformer_sparse_array was updated to test this.
  2. inverse_x.ndim == 1 runs when inverse_X is not sparse and only has one dimension. test_column_transformer_sparse_stacking tests for this use case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant the second about inverse_x being 1-dimensional. I would think that any valid sklearn transformer should always return 2D output? (but maybe I am overlooking something). And so I would expect input for inverse_transform to have the same constraint.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

preprocessing.LabelEncoder's transform and inverse_transform returns 1-D arrays.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LabelEncoder shouldn't be used on X

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR was updated to remove the check for 1D outputs.

@sklearn-lgtm
Copy link

This pull request introduces 1 alert when merging 9d0cd00 into 7166cd5 - view on LGTM.com

new alerts:

  • 1 for Unreachable code

Comment posted by LGTM.com

Copy link
Member

@jnothman jnothman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we could come up with a helpful example of this. I'm not sure it's necessary. Noting the available functionality in the user guide might be worthwhile.

Only a partial review.

else:
inverse_Xs[:, indices] = inverse_X.toarray()
else:
if inverse_X.ndim == 1:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LabelEncoder shouldn't be used on X

@@ -609,6 +609,12 @@ def _transform_one(transformer, X, y, weight, **fit_params):
return res * weight


def _inverse_transform_one(transformer, X, weight, **fit_params):
weight = weight or 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather make 1 if weight is None else weight more explicitly... bool(weight) is not a good way to check for None

@thomasjpfan
Copy link
Member Author

A motivating example would be nice. I will see if there is a place to add it in compose.rst.

@thomasjpfan thomasjpfan changed the title [MRG] ENH: Adds inverse_transform to ColumnTransformer [WIP] ENH: Adds inverse_transform to ColumnTransformer Oct 12, 2018
@jnothman
Copy link
Member

jnothman commented May 1, 2019

This modifies whats_new/v0.20.rst and should be updated

@glemaitre
Copy link
Member

I have a use case internally where I want to use the preprocessor to process X. From this X_proc, we compute some quantiles which I would like to inverse to get the original scale.

@thomasjpfan
Copy link
Member Author

@glemaitre thank you for the idea! Let’s see if there is a way to integrate it into one of our examples 🤔

@amueller
Copy link
Member

amueller commented Aug 6, 2019

do you want to merge with master for another round of reviews?

@thomasjpfan
Copy link
Member Author

I have a use case internally where I want to use the preprocessor to process X. From this X_proc, we compute some quantiles which I would like to inverse to get the original scale.

@glemaitre What kind of insights do you get when you have the quantiles in the original scale?

@@ -456,6 +461,60 @@ def _fit_transform(self, X, y, func, fitted=False):
else:
raise

def _calculate_inverse_indices(self, X, Xs):
"""
Private function to calcuate indices for inverse_transform
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor typo:

Suggested change
Private function to calcuate indices for inverse_transform
Private function to calculate indices for inverse_transform

@glemaitre
Copy link
Member

@glemaitre What kind of insights do you get when you have the quantiles in the original scale?

Ouch. I don't remember what was my use case now. I should have directly given what I was programming at that time. It could have been linked to some neuroscience stuff but I am unsure now.

@judahrand
Copy link

Is this merge-able?

@thomasjpfan
Copy link
Member Author

@Jude188 Not right now, do you have an example where this feature would be useful to you?

@judahrand
Copy link

judahrand commented Aug 7, 2020

@thomasjpfan I've got a case where I'm using Sklearn for only data preprocessing and not using an estimator. This is because the model that I have doesn't really have a y or target at 'learning' time. However, I want to be able to push the data that comes out of my model back through the preprocessing pipeline in order to recover the 'true' results. This basically boils down to, I have a matrix of values some of which I want to log transform some of which I do not but I also need to be able to undo that transform since the whole matrix comes out of my simulation model.

I'm sure I did a terrible job of explaining that!

@MarcBresson
Copy link
Contributor

hello, any news on that PR?

I'm working on interpretability and started working on adding an inverse_transform for ColumnTransformer. I stumbled upon that PR and it would be extremely helpful to me.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support inverse_transform in ColumnTransformer