Skip to content

[MRG+1] Do not transform y #9180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 22, 2017
6 changes: 6 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,12 @@ API changes summary
- ``utils.stats.rankdata``
- ``neighbors.approximate.LSHForest``

- Deprecate the ``y`` parameter in ``transform`` and ``inverse_transform``.
These methods should not accept a ``y`` parameter, as they are used at prediction time.
:issue:`8174` by :user:`Tahar Zanouda <tzano>`, `Alexandre Gramfort`_
and `Raghav RV`_.


.. _changes_0_18_1:

Version 0.18.1
Expand Down
5 changes: 2 additions & 3 deletions sklearn/decomposition/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def fit(X, y=None):
Returns the instance itself.
"""


def transform(self, X, y=None):
def transform(self, X):
"""Apply dimensionality reduction to X.

X is projected on the first principal components previously extracted
Expand Down Expand Up @@ -134,7 +133,7 @@ def transform(self, X, y=None):
X_transformed /= np.sqrt(self.explained_variance_)
return X_transformed

def inverse_transform(self, X, y=None):
def inverse_transform(self, X):
"""Transform data back to its original space.

In other words, return an input X_original whose transform would be X.
Expand Down
14 changes: 12 additions & 2 deletions sklearn/decomposition/fastica_.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
# Authors: Pierre Lafaye de Micheaux, Stefan van der Walt, Gael Varoquaux,
# Bertrand Thirion, Alexandre Gramfort, Denis A. Engemann
# License: BSD 3 clause

import warnings

import numpy as np
from scipy import linalg

from ..base import BaseEstimator, TransformerMixin
from ..externals import six
from ..externals.six import moves
from ..externals.six import string_types
from ..utils import check_array, as_float_array, check_random_state
from ..utils.validation import check_is_fitted
from ..utils.validation import FLOAT_DTYPES
Expand Down Expand Up @@ -528,22 +531,29 @@ def fit(self, X, y=None):
self._fit(X, compute_sources=False)
return self

def transform(self, X, y=None, copy=True):
def transform(self, X, y='deprecated', copy=True):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why deprecation here and not elsewhere?

Copy link
Member Author

@raghavrv raghavrv Jun 21, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for that, I will push a commit fixing this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because there is a ``copy`` param here, we cannot just remove ``y`` from the signature as was done above.

"""Recover the sources from X (apply the unmixing matrix).

Parameters
----------
X : array-like, shape (n_samples, n_features)
Data to transform, where n_samples is the number of samples
and n_features is the number of features.

copy : bool (optional)
If False, data passed to fit are overwritten. Defaults to True.
y : (ignored)
.. deprecated:: 0.19
This parameter will be removed in 0.21.

Returns
-------
X_new : array-like, shape (n_samples, n_components)
"""
if not isinstance(y, string_types) or y != 'deprecated':
warnings.warn("The parameter y on transform() is "
"deprecated since 0.19 and will be removed in 0.21",
DeprecationWarning)

check_is_fitted(self, 'mixing_')

X = check_array(X, copy=copy, dtype=FLOAT_DTYPES)
Expand Down
2 changes: 1 addition & 1 deletion sklearn/decomposition/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ def fit_transform(self, X, y=None):
X = self._fit(X)
return np.dot(X, self.components_.T)

def inverse_transform(self, X, y=None):
def inverse_transform(self, X):
"""Transform data back to its original space.

Returns an array X_original whose transform would be X.
Expand Down
3 changes: 1 addition & 2 deletions sklearn/feature_extraction/dict_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def inverse_transform(self, X, dict_type=dict):

return dicts

def transform(self, X, y=None):
def transform(self, X):
"""Transform feature->value dicts to array or sparse matrix.

Named features not encountered during fit or fit_transform will be
Expand All @@ -281,7 +281,6 @@ def transform(self, X, y=None):
X : Mapping or iterable over Mappings, length = n_samples
Dict(s) or Mapping(s) from feature names (arbitrary Python
objects) to feature values (strings or convertible to dtype).
y : (ignored)

Returns
-------
Expand Down
3 changes: 1 addition & 2 deletions sklearn/feature_extraction/hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def fit(self, X=None, y=None):
self._validate_params(self.n_features, self.input_type)
return self

def transform(self, raw_X, y=None):
def transform(self, raw_X):
"""Transform a sequence of instances to a scipy.sparse matrix.

Parameters
Expand All @@ -137,7 +137,6 @@ def transform(self, raw_X, y=None):
the input_type constructor argument) which will be hashed.
raw_X need not support the len function, so it can be the result
of a generator; n_samples is determined on the fly.
y : (ignored)

Returns
-------
Expand Down
1 change: 0 additions & 1 deletion sklearn/feature_extraction/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,6 @@ def transform(self, X):
-------
X : scipy.sparse matrix, shape = (n_samples, self.n_features)
Document-term matrix.

"""
if isinstance(X, six.string_types):
raise ValueError(
Expand Down
6 changes: 3 additions & 3 deletions sklearn/kernel_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def fit(self, X, y=None):
size=self.n_components)
return self

def transform(self, X, y=None):
def transform(self, X):
"""Apply the approximate feature map to X.

Parameters
Expand Down Expand Up @@ -178,7 +178,7 @@ def fit(self, X, y=None):
size=self.n_components)
return self

def transform(self, X, y=None):
def transform(self, X):
"""Apply the approximate feature map to X.

Parameters
Expand Down Expand Up @@ -278,7 +278,7 @@ def fit(self, X, y=None):
self.sample_interval_ = self.sample_interval
return self

def transform(self, X, y=None):
def transform(self, X):
"""Apply approximate feature map to X.

Parameters
Expand Down
2 changes: 1 addition & 1 deletion sklearn/mixture/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def score(self, X, y=None):
"""
return self.score_samples(X).mean()

def predict(self, X, y=None):
def predict(self, X):
"""Predict the labels for the data samples in X using trained model.

Parameters
Expand Down
2 changes: 1 addition & 1 deletion sklearn/neighbors/approximate.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def fit_transform(self, X, y=None):
self.fit(X)
return self.transform(X)

def transform(self, X, y=None):
def transform(self, X):
return self._to_hash(super(ProjectionToHashMixin, self).transform(X))


Expand Down
42 changes: 36 additions & 6 deletions sklearn/preprocessing/_function_transformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import warnings

from ..base import BaseEstimator, TransformerMixin
from ..utils import check_array
from ..externals.six import string_types


def _identity(X):
Expand Down Expand Up @@ -54,6 +57,8 @@ class FunctionTransformer(BaseEstimator, TransformerMixin):
Indicate that transform should forward the y argument to the
inner callable.

.. deprecated:: 0.19

kw_args : dict, optional
Dictionary of additional keyword arguments to pass to func.

Expand All @@ -62,7 +67,7 @@ class FunctionTransformer(BaseEstimator, TransformerMixin):

"""
def __init__(self, func=None, inverse_func=None, validate=True,
accept_sparse=False, pass_y=False,
accept_sparse=False, pass_y='deprecated',
kw_args=None, inv_kw_args=None):
self.func = func
self.inverse_func = inverse_func
Expand Down Expand Up @@ -90,35 +95,51 @@ def fit(self, X, y=None):
check_array(X, self.accept_sparse)
return self

def transform(self, X, y=None):
def transform(self, X, y='deprecated'):
"""Transform X using the forward function.

Parameters
----------
X : array-like, shape (n_samples, n_features)
Input array.

y : (ignored)
.. deprecated:: 0.19

Returns
-------
X_out : array-like, shape (n_samples, n_features)
Transformed input.
"""
return self._transform(X, y, self.func, self.kw_args)
if not isinstance(y, string_types) or y != 'deprecated':
warnings.warn("The parameter y on transform() is "
"deprecated since 0.19 and will be removed in 0.21",
DeprecationWarning)

def inverse_transform(self, X, y=None):
return self._transform(X, y=y, func=self.func, kw_args=self.kw_args)

def inverse_transform(self, X, y='deprecated'):
"""Transform X using the inverse function.

Parameters
----------
X : array-like, shape (n_samples, n_features)
Input array.

y : (ignored)
.. deprecated:: 0.19

Returns
-------
X_out : array-like, shape (n_samples, n_features)
Transformed input.
"""
return self._transform(X, y, self.inverse_func, self.inv_kw_args)
if not isinstance(y, string_types) or y != 'deprecated':
warnings.warn("The parameter y on inverse_transform() is "
"deprecated since 0.19 and will be removed in 0.21",
DeprecationWarning)
return self._transform(X, y=y, func=self.inverse_func,
kw_args=self.inv_kw_args)

def _transform(self, X, y=None, func=None, kw_args=None):
if self.validate:
Expand All @@ -127,5 +148,14 @@ def _transform(self, X, y=None, func=None, kw_args=None):
if func is None:
func = _identity

return func(X, *((y,) if self.pass_y else ()),
if (not isinstance(self.pass_y, string_types) or
self.pass_y != 'deprecated'):
# We do this to know if pass_y was set to False / True
pass_y = self.pass_y
warnings.warn("The parameter pass_y is deprecated since 0.19 and "
"will be removed in 0.21", DeprecationWarning)
else:
pass_y = False

return func(X, *((y,) if pass_y else ()),
**(kw_args if kw_args else {}))
Loading