Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ Enhancements
Bug fixes
.........

- Fixed a bug where :func:`sklearn.model_selection.BaseSearchCV.inverse_transform`
returns self.best_estimator_.transform() instead of self.best_estimator_.inverse_transform()
:issue:`8344` by :user:`Akshay Gupta <Akshay0724>`


- Fixed a bug where :class:`sklearn.linear_model.RandomizedLasso` and
:class:`sklearn.linear_model.RandomizedLogisticRegression` breaks for
sparse input.
Expand Down
2 changes: 1 addition & 1 deletion sklearn/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def inverse_transform(self, Xt):

"""
self._check_is_fitted('inverse_transform')
return self.best_estimator_.transform(Xt)
return self.best_estimator_.inverse_transform(Xt)

@property
def classes_(self):
Expand Down
17 changes: 15 additions & 2 deletions sklearn/model_selection/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,15 @@ def fit(self, X, Y):
def predict(self, T):
return T.shape[0]

def transform(self, X):
return X + self.foo_param

def inverse_transform(self, X):
return X - self.foo_param

predict_proba = predict
predict_log_proba = predict
decision_function = predict
transform = predict
inverse_transform = predict

def score(self, X=None, Y=None):
if self.foo_param > 1:
Expand Down Expand Up @@ -1305,3 +1309,12 @@ def _pop_time_keys(cv_results):
per_param_scores[1])
assert_array_almost_equal(per_param_scores[2],
per_param_scores[3])


def test_transform_inverse_transform_round_trip():
clf = MockClassifier()
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, verbose=3)

grid_search.fit(X, y)
X_round_trip = grid_search.inverse_transform(grid_search.transform(X))
assert_array_equal(X, X_round_trip)