diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 2792ba8484664..8e43f23af5d63 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -1206,6 +1206,7 @@ Model validation
    preprocessing.QuantileTransformer
    preprocessing.RobustScaler
    preprocessing.StandardScaler
+   preprocessing.TransformedTargetRegressor
 
 .. autosummary::
    :toctree: generated/
diff --git a/doc/modules/preprocessing_targets.rst b/doc/modules/preprocessing_targets.rst
index 88663a55fa0d4..276f15cc52e74 100644
--- a/doc/modules/preprocessing_targets.rst
+++ b/doc/modules/preprocessing_targets.rst
@@ -1,4 +1,3 @@
-
 .. currentmodule:: sklearn.preprocessing
 
 .. _preprocessing_targets:
@@ -7,6 +6,76 @@
 Transforming the prediction target (``y``)
 ==========================================
 
+Transforming target in regression
+---------------------------------
+
+:class:`TransformedTargetRegressor` transforms the targets ``y`` before fitting
+a regression model. The predictions are mapped back to the original space via
+an inverse transform. It takes as arguments the regressor that will be used for
+prediction, and the transformer that will be applied to the target variable::
+
+    >>> import numpy as np
+    >>> from sklearn.datasets import load_boston
+    >>> from sklearn.preprocessing import (TransformedTargetRegressor,
+    ...                                    QuantileTransformer)
+    >>> from sklearn.linear_model import LinearRegression
+    >>> from sklearn.model_selection import train_test_split
+    >>> boston = load_boston()
+    >>> X = boston.data
+    >>> y = boston.target
+    >>> transformer = QuantileTransformer(output_distribution='normal')
+    >>> regressor = LinearRegression()
+    >>> regr = TransformedTargetRegressor(regressor=regressor,
+    ...                                   transformer=transformer)
+    >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+    >>> regr.fit(X_train, y_train)  # doctest: +ELLIPSIS
+    TransformedTargetRegressor(...)
+    >>> print('R2 score: {0:.2f}'.format(regr.score(X_test, y_test)))
+    R2 score: 0.67
+    >>> raw_target_regr = LinearRegression().fit(X_train, y_train)
+    >>> print('R2 score: {0:.2f}'.format(raw_target_regr.score(X_test, y_test)))
+    R2 score: 0.64
+
+For simple transformations, instead of a Transformer object, a pair of
+functions can be passed, defining the transformation and its inverse mapping::
+
+    >>> from __future__ import division
+    >>> def func(x):
+    ...     return np.log(x)
+    >>> def inverse_func(x):
+    ...     return np.exp(x)
+
+Subsequently, the object is created as::
+
+    >>> regr = TransformedTargetRegressor(regressor=regressor,
+    ...                                   func=func,
+    ...                                   inverse_func=inverse_func)
+    >>> regr.fit(X_train, y_train)  # doctest: +ELLIPSIS
+    TransformedTargetRegressor(...)
+    >>> print('R2 score: {0:.2f}'.format(regr.score(X_test, y_test)))
+    R2 score: 0.65
+
+By default, the provided functions are checked at each fit to be the inverse of
+each other. However, it is possible to bypass this check by setting
+``check_inverse`` to ``False``::
+
+    >>> def inverse_func(x):
+    ...     return x
+    >>> regr = TransformedTargetRegressor(regressor=regressor,
+    ...                                   func=func,
+    ...                                   inverse_func=inverse_func,
+    ...                                   check_inverse=False)
+    >>> regr.fit(X_train, y_train)  # doctest: +ELLIPSIS
+    TransformedTargetRegressor(...)
+    >>> print('R2 score: {0:.2f}'.format(regr.score(X_test, y_test)))
+    R2 score: -4.50
+
+.. note::
+
+   The transformation can be specified by setting either ``transformer`` or
+   the pair of functions ``func`` and ``inverse_func``. However, setting both
+   options will raise an error, as shown in the sketch below.
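+
+The following doctest is an illustrative sketch (reusing the objects defined
+above); the error message is the one raised by the meta-estimator when both
+options are set::
+
+    >>> regr = TransformedTargetRegressor(regressor=regressor,
+    ...                                   transformer=transformer,
+    ...                                   func=func,
+    ...                                   inverse_func=inverse_func)
+    >>> regr.fit(X_train, y_train)
+    Traceback (most recent call last):
+        ...
+    ValueError: 'transformer' and functions 'func'/'inverse_func' cannot both be set.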
+
 Label binarization
 ------------------
diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst
index 000d7ab15135c..caae4c9a1645d 100644
--- a/doc/whats_new/v0.20.rst
+++ b/doc/whats_new/v0.20.rst
@@ -77,6 +77,11 @@ Model evaluation
 - Added :class:`multioutput.RegressorChain` for multi-target regression.
   :issue:`9257` by :user:`Kumar Ashutosh `.
 
+- Added the :class:`preprocessing.TransformedTargetRegressor` which transforms
+  the target ``y`` before fitting a regression model. The predictions are
+  mapped back to the original space via an inverse transform. :issue:`9041` by
+  `Andreas Müller`_ and :user:`Guillaume Lemaitre `.
+
 Enhancements
 ............
diff --git a/examples/preprocessing/plot_transformed_target.py b/examples/preprocessing/plot_transformed_target.py
new file mode 100755
index 0000000000000..c53adecbeb737
--- /dev/null
+++ b/examples/preprocessing/plot_transformed_target.py
@@ -0,0 +1,205 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+========================================================
+Effect of transforming the targets in a regression model
+========================================================
+
+In this example, we give an overview of the
+:class:`sklearn.preprocessing.TransformedTargetRegressor`. Two examples
+illustrate the benefit of transforming the targets before learning a linear
+regression model. The first example uses synthetic data while the second
+example is based on the Boston housing data set.
+
+"""
+
+# Author: Guillaume Lemaitre 
+# License: BSD 3 clause
+
+from __future__ import print_function, division
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+print(__doc__)
+
+###############################################################################
+# Synthetic example
+###############################################################################
+
+from sklearn.datasets import make_regression
+from sklearn.model_selection import train_test_split
+from sklearn.linear_model import RidgeCV
+from sklearn.preprocessing import TransformedTargetRegressor
+from sklearn.metrics import median_absolute_error, r2_score

+###############################################################################
+# A synthetic random regression problem is generated. The targets ``y`` are
+# modified by: (i) translating all targets such that all entries are
+# non-negative and (ii) applying an exponential function to obtain non-linear
+# targets which cannot be fitted using a simple linear model.
+#
+# Therefore, a logarithmic and an exponential function will be used to
+# transform the targets before training a linear regression model and using it
+# for prediction.
+
+
+def log_transform(x):
+    return np.log(x + 1)
+
+
+def exp_transform(x):
+    return np.exp(x) - 1
+
+
+X, y = make_regression(n_samples=10000, noise=100, random_state=0)
+y = np.exp((y + abs(y.min())) / 200)
+y_trans = log_transform(y)
+
+###############################################################################
+# The following illustrates the probability density function of the target
+# before and after applying the logarithmic function.
+
+f, (ax0, ax1) = plt.subplots(1, 2)
+
+ax0.hist(y, bins='auto', normed=True)
+ax0.set_xlim([0, 2000])
+ax0.set_ylabel('Probability')
+ax0.set_xlabel('Target')
+ax0.set_title('Target distribution')
+
+ax1.hist(y_trans, bins='auto', normed=True)
+ax1.set_ylabel('Probability')
+ax1.set_xlabel('Target')
+ax1.set_title('Transformed target distribution')
+
+f.suptitle("Synthetic data", y=0.035)
+f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95])
+
+X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+
+###############################################################################
+# At first, a linear model will be applied on the original targets. Due to the
+# non-linearity, the trained model will not be accurate at prediction time.
+# Subsequently, a logarithmic function is used to linearize the targets,
+# allowing better prediction even with a similar linear model, as reported by
+# the median absolute error (MAE).
+
+f, (ax0, ax1) = plt.subplots(1, 2, sharey=True)
+
+regr = RidgeCV()
+regr.fit(X_train, y_train)
+y_pred = regr.predict(X_test)
+
+ax0.scatter(y_test, y_pred)
+ax0.plot([0, 2000], [0, 2000], '--k')
+ax0.set_ylabel('Target predicted')
+ax0.set_xlabel('True Target')
+ax0.set_title('Ridge regression \n without target transformation')
+ax0.text(100, 1750, r'$R^2$=%.2f, MAE=%.2f' % (
+    r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)))
+ax0.set_xlim([0, 2000])
+ax0.set_ylim([0, 2000])
+
+regr_trans = TransformedTargetRegressor(regressor=RidgeCV(),
+                                        func=log_transform,
+                                        inverse_func=exp_transform)
+regr_trans.fit(X_train, y_train)
+y_pred = regr_trans.predict(X_test)
+
+ax1.scatter(y_test, y_pred)
+ax1.plot([0, 2000], [0, 2000], '--k')
+ax1.set_ylabel('Target predicted')
+ax1.set_xlabel('True Target')
+ax1.set_title('Ridge regression \n with target transformation')
+ax1.text(100, 1750, r'$R^2$=%.2f, MAE=%.2f' % (
+    r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)))
+ax1.set_xlim([0, 2000])
+ax1.set_ylim([0, 2000])
+
+f.suptitle("Synthetic data", y=0.035)
+f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95])
+
+###############################################################################
+# Real-world data set
+###############################################################################
+
+###############################################################################
+# In a similar manner, the Boston housing data set is used to show the impact
+# of transforming the targets before learning a model. In this example, the
+# target to be predicted corresponds to the weighted distances to the five
+# Boston employment centers.
+
+from sklearn.datasets import load_boston
+from sklearn.preprocessing import QuantileTransformer, quantile_transform
+
+dataset = load_boston()
+target = np.array(dataset.feature_names) == "DIS"
+X = dataset.data[:, np.logical_not(target)]
+y = dataset.data[:, target].squeeze()
+y_trans = quantile_transform(dataset.data[:, target],
+                             output_distribution='normal').squeeze()
+
+###############################################################################
+# A :class:`sklearn.preprocessing.QuantileTransformer` is used such that the
+# targets follow a normal distribution before applying a
+# :class:`sklearn.linear_model.RidgeCV` model.
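+
+###############################################################################
+# As an illustrative sanity check (an addition to the original example), the
+# transformed targets should be approximately standard normal, i.e. have
+# close to zero mean and unit standard deviation:
+
+print("y_trans mean: {:.2f}, std: {:.2f}".format(y_trans.mean(),
+                                                 y_trans.std()))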
+
+f, (ax0, ax1) = plt.subplots(1, 2)
+
+ax0.hist(y, bins='auto', normed=True)
+ax0.set_ylabel('Probability')
+ax0.set_xlabel('Target')
+ax0.set_title('Target distribution')
+
+ax1.hist(y_trans, bins='auto', normed=True)
+ax1.set_ylabel('Probability')
+ax1.set_xlabel('Target')
+ax1.set_title('Transformed target distribution')
+
+f.suptitle("Boston housing data: distance to employment centers", y=0.035)
+f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95])
+
+X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)
+
+###############################################################################
+# The effect of the transformer is weaker than on the synthetic data. However,
+# the transformation results in a decrease of the MAE.
+
+f, (ax0, ax1) = plt.subplots(1, 2, sharey=True)
+
+regr = RidgeCV()
+regr.fit(X_train, y_train)
+y_pred = regr.predict(X_test)
+
+ax0.scatter(y_test, y_pred)
+ax0.plot([0, 10], [0, 10], '--k')
+ax0.set_ylabel('Target predicted')
+ax0.set_xlabel('True Target')
+ax0.set_title('Ridge regression \n without target transformation')
+ax0.text(1, 9, r'$R^2$=%.2f, MAE=%.2f' % (
+    r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)))
+ax0.set_xlim([0, 10])
+ax0.set_ylim([0, 10])
+
+regr_trans = TransformedTargetRegressor(
+    regressor=RidgeCV(),
+    transformer=QuantileTransformer(output_distribution='normal'))
+regr_trans.fit(X_train, y_train)
+y_pred = regr_trans.predict(X_test)
+
+ax1.scatter(y_test, y_pred)
+ax1.plot([0, 10], [0, 10], '--k')
+ax1.set_ylabel('Target predicted')
+ax1.set_xlabel('True Target')
+ax1.set_title('Ridge regression \n with target transformation')
+ax1.text(1, 9, r'$R^2$=%.2f, MAE=%.2f' % (
+    r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)))
+ax1.set_xlim([0, 10])
+ax1.set_ylim([0, 10])
+
+f.suptitle("Boston housing data: distance to employment centers", y=0.035)
+f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95])
+
+plt.show()
diff --git a/sklearn/preprocessing/__init__.py b/sklearn/preprocessing/__init__.py
index acb70c93c83c5..0563dd018881f 100644
--- a/sklearn/preprocessing/__init__.py
+++ b/sklearn/preprocessing/__init__.py
@@ -25,7 +25,6 @@
 from .data import OneHotEncoder
 from .data import PowerTransformer
 from .data import CategoricalEncoder
-
 from .data import PolynomialFeatures
 
 from .label import label_binarize
@@ -33,6 +32,7 @@
 from .label import LabelEncoder
 from .label import MultiLabelBinarizer
 
+from ._target import TransformedTargetRegressor
 
 from .imputation import Imputer
 
@@ -53,6 +53,7 @@
     'PowerTransformer',
     'RobustScaler',
     'StandardScaler',
+    'TransformedTargetRegressor',
     'add_dummy_feature',
     'PolynomialFeatures',
     'binarize',
diff --git a/sklearn/preprocessing/_target.py b/sklearn/preprocessing/_target.py
new file mode 100644
index 0000000000000..14c14b1396846
--- /dev/null
+++ b/sklearn/preprocessing/_target.py
@@ -0,0 +1,223 @@
+# Authors: Andreas Mueller 
+#          Guillaume Lemaitre 
+# License: BSD 3 clause
+
+import warnings
+
+import numpy as np
+
+from ..base import BaseEstimator, RegressorMixin, clone
+from ..utils.validation import check_is_fitted
+from ..utils import check_array, safe_indexing
+from ._function_transformer import FunctionTransformer
+
+__all__ = ['TransformedTargetRegressor']
+
+
+class TransformedTargetRegressor(BaseEstimator, RegressorMixin):
+    """Meta-estimator to regress on a transformed target.
+
+    Useful for applying a non-linear transformation in regression
+    problems.
+    This transformation can be given as a Transformer such as the
+    QuantileTransformer or as a function and its inverse such as ``log`` and
+    ``exp``.
+
+    The computation during ``fit`` is::
+
+        regressor.fit(X, func(y))
+
+    or::
+
+        regressor.fit(X, transformer.transform(y))
+
+    The computation during ``predict`` is::
+
+        inverse_func(regressor.predict(X))
+
+    or::
+
+        transformer.inverse_transform(regressor.predict(X))
+
+    Read more in the :ref:`User Guide `.
+
+    Parameters
+    ----------
+    regressor : object, default=LinearRegression()
+        Regressor object such as derived from ``RegressorMixin``. This
+        regressor will automatically be cloned each time prior to fitting.
+
+    transformer : object, default=None
+        Estimator object such as derived from ``TransformerMixin``. Cannot be
+        set at the same time as ``func`` and ``inverse_func``. If
+        ``transformer`` is ``None`` as well as ``func`` and ``inverse_func``,
+        the transformer will be an identity transformer. Note that the
+        transformer will be cloned during fitting. Also, the transformer
+        restricts ``y`` to be a numpy array.
+
+    func : function, optional
+        Function to apply to ``y`` before passing to ``fit``. Cannot be set
+        at the same time as ``transformer``. The function needs to return a
+        2-dimensional array. If ``func`` is ``None``, the function used will
+        be the identity function.
+
+    inverse_func : function, optional
+        Function to apply to the prediction of the regressor. Cannot be set
+        at the same time as ``transformer``. The function needs to return a
+        2-dimensional array. The inverse function is used to return
+        predictions to the same space as the original training labels.
+
+    check_inverse : bool, default=True
+        Whether to check that ``transform`` followed by ``inverse_transform``
+        or ``func`` followed by ``inverse_func`` leads to the original
+        targets.
+
+    Attributes
+    ----------
+    regressor_ : object
+        Fitted regressor.
+
+    transformer_ : object
+        Transformer used in ``fit`` and ``predict``.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.linear_model import LinearRegression
+    >>> from sklearn.preprocessing import TransformedTargetRegressor
+    >>> tt = TransformedTargetRegressor(regressor=LinearRegression(),
+    ...                                 func=np.log, inverse_func=np.exp)
+    >>> X = np.arange(4).reshape(-1, 1)
+    >>> y = np.exp(2 * X).ravel()
+    >>> tt.fit(X, y)  # doctest: +ELLIPSIS
+    TransformedTargetRegressor(...)
+    >>> tt.score(X, y)
+    1.0
+    >>> tt.regressor_.coef_
+    array([ 2.])
+
+    Notes
+    -----
+    Internally, the target ``y`` is always converted into a 2-dimensional
+    array to be used by scikit-learn transformers. At the time of prediction,
+    the output will be reshaped to have the same number of dimensions as
+    ``y``.
+
+    See :ref:`examples/preprocessing/plot_transformed_target.py
+    `.
+ + """ + def __init__(self, regressor=None, transformer=None, + func=None, inverse_func=None, check_inverse=True): + self.regressor = regressor + self.transformer = transformer + self.func = func + self.inverse_func = inverse_func + self.check_inverse = check_inverse + + def _fit_transformer(self, y): + if (self.transformer is not None and + (self.func is not None or self.inverse_func is not None)): + raise ValueError("'transformer' and functions 'func'/" + "'inverse_func' cannot both be set.") + elif self.transformer is not None: + self.transformer_ = clone(self.transformer) + else: + if self.func is not None and self.inverse_func is None: + raise ValueError("When 'func' is provided, 'inverse_func' must" + " also be provided") + self.transformer_ = FunctionTransformer( + func=self.func, inverse_func=self.inverse_func, validate=True, + check_inverse=self.check_inverse) + # XXX: sample_weight is not currently passed to the + # transformer. However, if transformer starts using sample_weight, the + # code should be modified accordingly. At the time to consider the + # sample_prop feature, it is also a good use case to be considered. + self.transformer_.fit(y) + if self.check_inverse: + idx_selected = slice(None, None, max(1, y.shape[0] // 10)) + y_sel = safe_indexing(y, idx_selected) + y_sel_t = self.transformer_.transform(y_sel) + if not np.allclose(y_sel, + self.transformer_.inverse_transform(y_sel_t)): + warnings.warn("The provided functions or transformer are" + " not strictly inverse of each other. If" + " you are sure you want to proceed regardless" + ", set 'check_inverse=False'", UserWarning) + + def fit(self, X, y, sample_weight=None): + """Fit the model according to the given training data. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples, n_features) + Training vector, where n_samples is the number of samples and + n_features is the number of features. + + y : array-like, shape (n_samples,) + Target values. + + sample_weight : array-like, shape (n_samples,) optional + Array of weights that are assigned to individual samples. + If not provided, then each sample is given unit weight. + + Returns + ------- + self : object + Returns self. + """ + y = check_array(y, accept_sparse=False, force_all_finite=True, + ensure_2d=False, dtype='numeric') + + # store the number of dimension of the target to predict an array of + # similar shape at predict + self._training_dim = y.ndim + + # transformers are designed to modify X which is 2d dimensional, we + # need to modify y accordingly. + if y.ndim == 1: + y_2d = y.reshape(-1, 1) + else: + y_2d = y + self._fit_transformer(y_2d) + + if self.regressor is None: + from ..linear_model import LinearRegression + self.regressor_ = LinearRegression() + else: + self.regressor_ = clone(self.regressor) + + # transform y and convert back to 1d array if needed + y_trans = self.transformer_.fit_transform(y_2d) + # FIXME: a FunctionTransformer can return a 1D array even when validate + # is set to True. Therefore, we need to check the number of dimension + # first. + if y_trans.ndim == 2 and y_trans.shape[1] == 1: + y_trans = y_trans.squeeze(axis=1) + if sample_weight is None: + self.regressor_.fit(X, y_trans) + else: + self.regressor_.fit(X, y_trans, sample_weight=sample_weight) + + return self + + def predict(self, X): + """Predict using the base regressor, applying inverse. + + The regressor is used to predict and the ``inverse_func`` or + ``inverse_transform`` is applied before returning the prediction. 
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape = (n_samples, n_features)
+            Samples.
+
+        Returns
+        -------
+        y_hat : array, shape = (n_samples,)
+            Predicted values.
+
+        """
+        check_is_fitted(self, "regressor_")
+        pred = self.regressor_.predict(X)
+        if pred.ndim == 1:
+            pred_trans = self.transformer_.inverse_transform(
+                pred.reshape(-1, 1))
+        else:
+            pred_trans = self.transformer_.inverse_transform(pred)
+        if (self._training_dim == 1 and
+                pred_trans.ndim == 2 and pred_trans.shape[1] == 1):
+            pred_trans = pred_trans.squeeze(axis=1)
+
+        return pred_trans
diff --git a/sklearn/preprocessing/tests/test_target.py b/sklearn/preprocessing/tests/test_target.py
new file mode 100644
index 0000000000000..a385f9e13b37a
--- /dev/null
+++ b/sklearn/preprocessing/tests/test_target.py
@@ -0,0 +1,266 @@
+import numpy as np
+import pytest
+
+from sklearn.base import clone
+from sklearn.base import BaseEstimator
+from sklearn.base import TransformerMixin
+
+from sklearn.dummy import DummyRegressor
+
+from sklearn.utils.testing import assert_raises
+from sklearn.utils.testing import assert_raises_regex
+from sklearn.utils.testing import assert_allclose
+from sklearn.utils.testing import assert_warns_message
+from sklearn.utils.testing import assert_no_warnings
+
+from sklearn.preprocessing import FunctionTransformer
+from sklearn.preprocessing import TransformedTargetRegressor
+from sklearn.preprocessing import StandardScaler
+
+from sklearn.linear_model import LinearRegression, Lasso
+
+from sklearn import datasets
+
+friedman = datasets.make_friedman1(random_state=0)
+
+
+def test_transform_target_regressor_error():
+    X, y = friedman
+    # provide a transformer and functions at the same time
+    regr = TransformedTargetRegressor(regressor=LinearRegression(),
+                                      transformer=StandardScaler(),
+                                      func=np.exp, inverse_func=np.log)
+    assert_raises_regex(ValueError, "'transformer' and functions"
+                        " 'func'/'inverse_func' cannot both be set.",
+                        regr.fit, X, y)
+    # fit with sample_weight with a regressor which does not support it
+    sample_weight = np.ones((y.shape[0],))
+    regr = TransformedTargetRegressor(regressor=Lasso(),
+                                      transformer=StandardScaler())
+    assert_raises_regex(TypeError,
+                        r"fit\(\) got an unexpected keyword argument"
+                        " 'sample_weight'", regr.fit, X, y,
+                        sample_weight=sample_weight)
+    # func is given but inverse_func is not
+    regr = TransformedTargetRegressor(func=np.exp)
+    assert_raises_regex(ValueError, "When 'func' is provided, 'inverse_func'"
+                        " must also be provided", regr.fit, X, y)
+
+
+def test_transform_target_regressor_invertible():
+    X, y = friedman
+    regr = TransformedTargetRegressor(regressor=LinearRegression(),
+                                      func=np.sqrt, inverse_func=np.log,
+                                      check_inverse=True)
+    assert_warns_message(UserWarning, "The provided functions or transformer"
+                         " are not strictly inverse of each other.",
+                         regr.fit, X, y)
+    regr = TransformedTargetRegressor(regressor=LinearRegression(),
+                                      func=np.sqrt, inverse_func=np.log)
+    regr.set_params(check_inverse=False)
+    assert_no_warnings(regr.fit, X, y)
+
+
+def _check_standard_scaled(y, y_pred):
+    y_mean = np.mean(y, axis=0)
+    y_std = np.std(y, axis=0)
+    assert_allclose((y - y_mean) / y_std, y_pred)
+
+
+def _check_shifted_by_one(y, y_pred):
+    assert_allclose(y + 1, y_pred)
+
+
+def test_transform_target_regressor_functions():
+    X, y = friedman
+    regr = TransformedTargetRegressor(regressor=LinearRegression(),
+                                      func=np.log, inverse_func=np.exp)
+    y_pred = regr.fit(X, y).predict(X)
+    # check the transformer output
+    y_tran = regr.transformer_.transform(y.reshape(-1, 1)).squeeze()
+    assert_allclose(np.log(y), y_tran)
+    assert_allclose(y, regr.transformer_.inverse_transform(
+        y_tran.reshape(-1, 1)).squeeze())
+    assert y.shape == y_pred.shape
+    assert_allclose(y_pred, regr.inverse_func(regr.regressor_.predict(X)))
+    # check the regressor output
+    lr = LinearRegression().fit(X, regr.func(y))
+    assert_allclose(regr.regressor_.coef_.ravel(), lr.coef_.ravel())
+
+
+def test_transform_target_regressor_functions_multioutput():
+    X = friedman[0]
+    y = np.vstack((friedman[1], friedman[1] ** 2 + 1)).T
+    regr = TransformedTargetRegressor(regressor=LinearRegression(),
+                                      func=np.log, inverse_func=np.exp)
+    y_pred = regr.fit(X, y).predict(X)
+    # check the transformer output
+    y_tran = regr.transformer_.transform(y)
+    assert_allclose(np.log(y), y_tran)
+    assert_allclose(y, regr.transformer_.inverse_transform(y_tran))
+    assert y.shape == y_pred.shape
+    assert_allclose(y_pred, regr.inverse_func(regr.regressor_.predict(X)))
+    # check the regressor output
+    lr = LinearRegression().fit(X, regr.func(y))
+    assert_allclose(regr.regressor_.coef_.ravel(), lr.coef_.ravel())
+
+
+@pytest.mark.parametrize("X,y", [friedman,
+                                 (friedman[0],
+                                  np.vstack((friedman[1],
+                                             friedman[1] ** 2 + 1)).T)])
+def test_transform_target_regressor_1d_transformer(X, y):
+    # All transformers in scikit-learn expect 2D data. FunctionTransformer
+    # with validate=False lifts this constraint without checking that the
+    # input is a 2D array. We check the consistency of the data shape using
+    # 1D and 2D y arrays.
+    transformer = FunctionTransformer(func=lambda x: x + 1,
+                                      inverse_func=lambda x: x - 1,
+                                      validate=False)
+    regr = TransformedTargetRegressor(regressor=LinearRegression(),
+                                      transformer=transformer)
+    y_pred = regr.fit(X, y).predict(X)
+    assert y.shape == y_pred.shape
+    # consistency of the forward transform
+    y_tran = regr.transformer_.transform(y)
+    _check_shifted_by_one(y, y_tran)
+    assert y.shape == y_pred.shape
+    # consistency of the inverse transform
+    assert_allclose(y, regr.transformer_.inverse_transform(
+        y_tran).squeeze())
+    # consistency of the regressor
+    lr = LinearRegression()
+    transformer2 = clone(transformer)
+    lr.fit(X, transformer2.fit_transform(y))
+    y_lr_pred = lr.predict(X)
+    assert_allclose(y_pred, transformer2.inverse_transform(y_lr_pred))
+    assert_allclose(regr.regressor_.coef_, lr.coef_)
+
+
+@pytest.mark.parametrize("X,y", [friedman,
+                                 (friedman[0],
+                                  np.vstack((friedman[1],
+                                             friedman[1] ** 2 + 1)).T)])
+def test_transform_target_regressor_2d_transformer(X, y):
+    # Check consistency with a transformer accepting only 2D arrays and a
+    # 1D/2D y array.
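+    # Note that StandardScaler itself only accepts 2D input, so the
+    # meta-estimator is expected to reshape a 1D ``y`` to 2D internally
+    # before fitting the transformer (see ``TransformedTargetRegressor.fit``).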
+    transformer = StandardScaler()
+    regr = TransformedTargetRegressor(regressor=LinearRegression(),
+                                      transformer=transformer)
+    y_pred = regr.fit(X, y).predict(X)
+    assert y.shape == y_pred.shape
+    # consistency of the forward transform
+    if y.ndim == 1:  # create a 2D array and squeeze results
+        y_tran = regr.transformer_.transform(y.reshape(-1, 1)).squeeze()
+    else:
+        y_tran = regr.transformer_.transform(y)
+    _check_standard_scaled(y, y_tran)
+    assert y.shape == y_pred.shape
+    # consistency of the inverse transform
+    assert_allclose(y, regr.transformer_.inverse_transform(
+        y_tran).squeeze())
+    # consistency of the regressor
+    lr = LinearRegression()
+    transformer2 = clone(transformer)
+    if y.ndim == 1:  # create a 2D array and squeeze results
+        lr.fit(X, transformer2.fit_transform(y.reshape(-1, 1)).squeeze())
+    else:
+        lr.fit(X, transformer2.fit_transform(y))
+    y_lr_pred = lr.predict(X)
+    assert_allclose(y_pred, transformer2.inverse_transform(y_lr_pred))
+    assert_allclose(regr.regressor_.coef_, lr.coef_)
+
+
+def test_transform_target_regressor_2d_transformer_multioutput():
+    # Check consistency with a transformer accepting only 2D arrays and a 2D
+    # y array.
+    X = friedman[0]
+    y = np.vstack((friedman[1], friedman[1] ** 2 + 1)).T
+    transformer = StandardScaler()
+    regr = TransformedTargetRegressor(regressor=LinearRegression(),
+                                      transformer=transformer)
+    y_pred = regr.fit(X, y).predict(X)
+    assert y.shape == y_pred.shape
+    # consistency of the forward transform
+    y_tran = regr.transformer_.transform(y)
+    _check_standard_scaled(y, y_tran)
+    assert y.shape == y_pred.shape
+    # consistency of the inverse transform
+    assert_allclose(y, regr.transformer_.inverse_transform(
+        y_tran).squeeze())
+    # consistency of the regressor
+    lr = LinearRegression()
+    transformer2 = clone(transformer)
+    lr.fit(X, transformer2.fit_transform(y))
+    y_lr_pred = lr.predict(X)
+    assert_allclose(y_pred, transformer2.inverse_transform(y_lr_pred))
+    assert_allclose(regr.regressor_.coef_, lr.coef_)
+
+
+def test_transform_target_regressor_multi_to_single():
+    X = friedman[0]
+    y = np.transpose([friedman[1], (friedman[1] ** 2 + 1)])
+
+    def func(y):
+        out = np.sqrt(y[:, 0] ** 2 + y[:, 1] ** 2)
+        return out[:, np.newaxis]
+
+    def inverse_func(y):
+        return y
+
+    tt = TransformedTargetRegressor(func=func, inverse_func=inverse_func,
+                                    check_inverse=False)
+    tt.fit(X, y)
+    y_pred_2d_func = tt.predict(X)
+    assert y_pred_2d_func.shape == (100, 1)
+
+    # force the function to return a 1D array
+    def func(y):
+        return np.sqrt(y[:, 0] ** 2 + y[:, 1] ** 2)
+
+    tt = TransformedTargetRegressor(func=func, inverse_func=inverse_func,
+                                    check_inverse=False)
+    tt.fit(X, y)
+    y_pred_1d_func = tt.predict(X)
+    assert y_pred_1d_func.shape == (100, 1)
+
+    assert_allclose(y_pred_1d_func, y_pred_2d_func)
+
+
+class DummyCheckerArrayTransformer(BaseEstimator, TransformerMixin):
+
+    def fit(self, X, y=None):
+        assert isinstance(X, np.ndarray)
+        return self
+
+    def transform(self, X):
+        assert isinstance(X, np.ndarray)
+        return X
+
+    def inverse_transform(self, X):
+        assert isinstance(X, np.ndarray)
+        return X
+
+
+class DummyCheckerListRegressor(DummyRegressor):
+
+    def fit(self, X, y, sample_weight=None):
+        assert isinstance(X, list)
+        return super(DummyCheckerListRegressor, self).fit(X, y, sample_weight)
+
+    def predict(self, X):
+        assert isinstance(X, list)
+        return super(DummyCheckerListRegressor, self).predict(X)
+
+
+def test_transform_target_regressor_ensure_y_array():
+    # check that the target ``y`` passed to the transformer will always be a
+    # numpy array.
+    # Similarly, if ``X`` is passed as a list, we check that the predictor
+    # receives it as is.
+    X, y = friedman
+    tt = TransformedTargetRegressor(transformer=DummyCheckerArrayTransformer(),
+                                    regressor=DummyCheckerListRegressor(),
+                                    check_inverse=False)
+    tt.fit(X.tolist(), y.tolist())
+    tt.predict(X.tolist())
+    assert_raises(AssertionError, tt.fit, X, y.tolist())
+    assert_raises(AssertionError, tt.predict, X)
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 281ddc3e708d5..398c12cbddb42 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -63,7 +63,7 @@
 CROSS_DECOMPOSITION = ['PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD']
 MULTI_OUTPUT = ['CCA', 'DecisionTreeRegressor', 'ElasticNet',
                 'ExtraTreeRegressor', 'ExtraTreesRegressor', 'GaussianProcess',
-                'GaussianProcessRegressor',
+                'GaussianProcessRegressor', 'TransformedTargetRegressor',
                 'KNeighborsRegressor', 'KernelRidge', 'Lars', 'Lasso',
                 'LassoLars', 'LinearRegression', 'MultiTaskElasticNet',
                 'MultiTaskElasticNetCV', 'MultiTaskLasso', 'MultiTaskLassoCV',