-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG+2] TransformedTargetRegressor #9041
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
83 commits
Select commit
Hold shift + click to select a range
21a2ff3
implement target transformer
amueller 7b1f7e8
make example use log and ex
amueller b306fa5
some docstrings
amueller 7d9badf
EHN/TST advance TTR
glemaitre 97da7a3
FIX call fit of the transformer at validation time
glemaitre 61a543a
EHN/TST ravel y when needed
glemaitre de8dbb4
FIX address comment Andy
glemaitre 254fac2
EHN add LinearRegression
glemaitre 53c7c81
Merge remote-tracking branch 'origin/master' into targettransformer
glemaitre 693de84
EHN move to target file
glemaitre 3dafc8f
FIX/EXA fix example in the docstring
glemaitre 27f1c43
Merge remote-tracking branch 'origin/master' into targettransformer
glemaitre 73bbcaf
ENH address comments
glemaitre e6a4e7d
Merge remote-tracking branch 'origin/master' into targettransformer
glemaitre 503a985
DOC narrative doc for ttr
glemaitre 63dbe9a
DOC update whats new and docstring
glemaitre 9feafda
Update whats new
glemaitre 32a85a6
Remove useless changes
glemaitre af51cf8
Update whats new
glemaitre dcae366
address comments
glemaitre d8310ad
Merge branch 'targettransformer' of github.com:glemaitre/scikit-learn…
glemaitre 4c3ab11
DOC change to bostong dataset
glemaitre 49ea3c4
Remove score
glemaitre f1a7289
add the estimator to multioutput
glemaitre ffe6892
Rename the class
glemaitre 2a868ee
gael comments
glemaitre 18c66c6
revert example
glemaitre 85a8865
FIX docstring and commont test
glemaitre 7a10796
FIX solve issue circular import
glemaitre 086fba0
FIX circular import
glemaitre 44ea999
Merge remote-tracking branch 'origin/master' into targettransformer
glemaitre 6c4734e
DOC/FIX/TST test 1d/2d y support transformer and vlad comments
glemaitre 3ecde9f
TST apply change of manoj
glemaitre 5e7d6c9
TST apply change of manoj
glemaitre db4bf57
TST factorize single- multi-output test
glemaitre 0fe1622
FIX ensure at least 1d array
glemaitre 8b94056
TST add test for support of sample weight
glemaitre 01d94e2
EXA add example
glemaitre a0b84c4
DOC fix
glemaitre 437dfaa
DOC revert author
glemaitre 36968ba
FIX minor fix in example and doc
glemaitre d253fcd
DOC fixes
glemaitre 451dfd3
FIX remove useless import
glemaitre 19a6f94
Remove absolute tolerance
glemaitre 9e07197
Merge branch 'master' into targettransformer
glemaitre 0ddfee0
TST single to multi and regressor checking
glemaitre 8392cc5
pass sample_weight directly
glemaitre a0bf0b0
PEP8
glemaitre 18bcec0
use is_regressor instead of tag
glemaitre f3e151f
Merge branch 'master' into targettransformer
glemaitre 85cc14c
TST split tests
glemaitre 075bf92
Merge remote-tracking branch 'glemaitre/targettransformer' into targe…
glemaitre 51583c2
TST fix multi to single test
glemaitre 9853552
solve the issue if the function are not 2d
glemaitre a1998fa
DOC update docstring
glemaitre 7c8c0ca
DOC fix docstring
glemaitre 97330b0
Merge remote-tracking branch 'origin/master' into targettransformer
glemaitre 7f13b9a
TST check compatibility 1D 2D fuction even if supposidely not supported
glemaitre ae973f8
TST relax equality in prediction
glemaitre 129373d
TST remove single to multi case
glemaitre 9064f24
Address olivier and johel comments
glemaitre 3d80728
not enforcing regressor
glemaitre 5f9db73
Renamed to TransformedTargetRegressor
glemaitre 58c5506
DOC reformat plot titles
ogrisel d0f83fa
change naming functions
glemaitre 4e61395
DOC fixing title
glemaitre 687703b
Merge remote-tracking branch 'origin/master' into targettransformer
glemaitre 500a77c
DOC fix merge git mess
glemaitre 35cb75d
TST/EHN only squeeze when ndim == 1
glemaitre 9a939f3
TST forgot to call fit
glemaitre 00e6d78
Merge remote-tracking branch 'origin/master' into targettransformer
glemaitre 04dc4a7
FIX pass check_inverse to FunctionTransformer
glemaitre f757c10
DOC remove blank lines
glemaitre 214fde6
Add comments and lift constraint upon X
glemaitre 68c5b7e
avoid type conversion since this is in check_array
glemaitre 9976ace
Merge branch 'master' into targettransformer
glemaitre 3c99cde
TST check that y is always converted to array before transformer call
glemaitre 0b364f6
reverse right of plot_ols
glemaitre 64f5d52
address joels comments
glemaitre 5929f81
Merge branch 'master' into targettransformer
glemaitre d637038
MAINT rename module name
glemaitre 790c86a
DOC fix indent
glemaitre bbee2be
FIX add the new module
glemaitre File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
====================================================== | ||
Effect of transforming the targets in regression model | ||
====================================================== | ||
|
||
In this example, we give an overview of the | ||
:class:`sklearn.preprocessing.TransformedTargetRegressor`. Two examples | ||
illustrate the benefit of transforming the targets before learning a linear | ||
regression model. The first example uses synthetic data while the second | ||
example is based on the Boston housing data set. | ||
|
||
""" | ||
|
||
# Author: Guillaume Lemaitre <guillaume.lemaitre@inria.fr> | ||
# License: BSD 3 clause | ||
|
||
from __future__ import print_function, division | ||
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
print(__doc__) | ||
|
||
############################################################################### | ||
# Synthetic example | ||
############################################################################### | ||
|
||
from sklearn.datasets import make_regression | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.linear_model import RidgeCV | ||
from sklearn.preprocessing import TransformedTargetRegressor | ||
from sklearn.metrics import median_absolute_error, r2_score | ||
|
||
############################################################################### | ||
# A synthetic random regression problem is generated. The targets ``y`` are | ||
# modified by: (i) translating all targets such that all entries are | ||
# non-negative and (ii) applying an exponential function to obtain non-linear | ||
# targets which cannot be fitted using a simple linear model. | ||
# | ||
# Therefore, a logarithmic and an exponential function will be used to | ||
# transform the targets before training a linear regression model and using it | ||
# for prediction. | ||
|
||
|
||
def log_transform(x): | ||
return np.log(x + 1) | ||
|
||
|
||
def exp_transform(x): | ||
return np.exp(x) - 1 | ||
|
||
|
||
X, y = make_regression(n_samples=10000, noise=100, random_state=0) | ||
y = np.exp((y + abs(y.min())) / 200) | ||
y_trans = log_transform(y) | ||
|
||
############################################################################### | ||
# The following illustrate the probability density functions of the target | ||
# before and after applying the logarithmic functions. | ||
|
||
f, (ax0, ax1) = plt.subplots(1, 2) | ||
|
||
ax0.hist(y, bins='auto', normed=True) | ||
ax0.set_xlim([0, 2000]) | ||
ax0.set_ylabel('Probability') | ||
ax0.set_xlabel('Target') | ||
ax0.set_title('Target distribution') | ||
|
||
ax1.hist(y_trans, bins='auto', normed=True) | ||
ax1.set_ylabel('Probability') | ||
ax1.set_xlabel('Target') | ||
ax1.set_title('Transformed target distribution') | ||
|
||
f.suptitle("Synthetic data", y=0.035) | ||
f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95]) | ||
|
||
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) | ||
|
||
############################################################################### | ||
# At first, a linear model will be applied on the original targets. Due to the | ||
# non-linearity, the model trained will not be precise during the | ||
# prediction. Subsequently, a logarithmic function is used to linearize the | ||
# targets, allowing better prediction even with a similar linear model as | ||
# reported by the median absolute error (MAE). | ||
|
||
f, (ax0, ax1) = plt.subplots(1, 2, sharey=True) | ||
|
||
regr = RidgeCV() | ||
regr.fit(X_train, y_train) | ||
y_pred = regr.predict(X_test) | ||
|
||
ax0.scatter(y_test, y_pred) | ||
ax0.plot([0, 2000], [0, 2000], '--k') | ||
ax0.set_ylabel('Target predicted') | ||
ax0.set_xlabel('True Target') | ||
ax0.set_title('Ridge regression \n without target transformation') | ||
ax0.text(100, 1750, r'$R^2$=%.2f, MAE=%.2f' % ( | ||
r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred))) | ||
ax0.set_xlim([0, 2000]) | ||
ax0.set_ylim([0, 2000]) | ||
|
||
regr_trans = TransformedTargetRegressor(regressor=RidgeCV(), | ||
func=log_transform, | ||
inverse_func=exp_transform) | ||
regr_trans.fit(X_train, y_train) | ||
y_pred = regr_trans.predict(X_test) | ||
|
||
ax1.scatter(y_test, y_pred) | ||
ax1.plot([0, 2000], [0, 2000], '--k') | ||
ax1.set_ylabel('Target predicted') | ||
ax1.set_xlabel('True Target') | ||
ax1.set_title('Ridge regression \n with target transformation') | ||
ax1.text(100, 1750, r'$R^2$=%.2f, MAE=%.2f' % ( | ||
r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred))) | ||
ax1.set_xlim([0, 2000]) | ||
ax1.set_ylim([0, 2000]) | ||
|
||
f.suptitle("Synthetic data", y=0.035) | ||
f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95]) | ||
|
||
############################################################################### | ||
# Real-world data set | ||
############################################################################### | ||
|
||
############################################################################### | ||
# In a similar manner, the boston housing data set is used to show the impact | ||
# of transforming the targets before learning a model. In this example, the | ||
# targets to be predicted corresponds to the weighted distances to the five | ||
# Boston employment centers. | ||
|
||
from sklearn.datasets import load_boston | ||
from sklearn.preprocessing import QuantileTransformer, quantile_transform | ||
|
||
dataset = load_boston() | ||
target = np.array(dataset.feature_names) == "DIS" | ||
X = dataset.data[:, np.logical_not(target)] | ||
y = dataset.data[:, target].squeeze() | ||
y_trans = quantile_transform(dataset.data[:, target], | ||
output_distribution='normal').squeeze() | ||
|
||
############################################################################### | ||
# A :class:`sklearn.preprocessing.QuantileTransformer` is used such that the | ||
# targets follows a normal distribution before applying a | ||
# :class:`sklearn.linear_model.RidgeCV` model. | ||
|
||
f, (ax0, ax1) = plt.subplots(1, 2) | ||
|
||
ax0.hist(y, bins='auto', normed=True) | ||
ax0.set_ylabel('Probability') | ||
ax0.set_xlabel('Target') | ||
ax0.set_title('Target distribution') | ||
|
||
ax1.hist(y_trans, bins='auto', normed=True) | ||
ax1.set_ylabel('Probability') | ||
ax1.set_xlabel('Target') | ||
ax1.set_title('Transformed target distribution') | ||
|
||
f.suptitle("Boston housing data: distance to employment centers", y=0.035) | ||
f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95]) | ||
|
||
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1) | ||
|
||
############################################################################### | ||
# The effect of the transformer is weaker than on the synthetic data. However, | ||
# the transform induces a decrease of the MAE. | ||
|
||
f, (ax0, ax1) = plt.subplots(1, 2, sharey=True) | ||
|
||
regr = RidgeCV() | ||
regr.fit(X_train, y_train) | ||
y_pred = regr.predict(X_test) | ||
|
||
ax0.scatter(y_test, y_pred) | ||
ax0.plot([0, 10], [0, 10], '--k') | ||
ax0.set_ylabel('Target predicted') | ||
ax0.set_xlabel('True Target') | ||
ax0.set_title('Ridge regression \n without target transformation') | ||
ax0.text(1, 9, r'$R^2$=%.2f, MAE=%.2f' % ( | ||
r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred))) | ||
ax0.set_xlim([0, 10]) | ||
ax0.set_ylim([0, 10]) | ||
|
||
regr_trans = TransformedTargetRegressor( | ||
regressor=RidgeCV(), | ||
transformer=QuantileTransformer(output_distribution='normal')) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
regr_trans.fit(X_train, y_train) | ||
y_pred = regr_trans.predict(X_test) | ||
|
||
ax1.scatter(y_test, y_pred) | ||
ax1.plot([0, 10], [0, 10], '--k') | ||
ax1.set_ylabel('Target predicted') | ||
ax1.set_xlabel('True Target') | ||
ax1.set_title('Ridge regression \n with target transformation') | ||
ax1.text(1, 9, r'$R^2$=%.2f, MAE=%.2f' % ( | ||
r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred))) | ||
ax1.set_xlim([0, 10]) | ||
ax1.set_ylim([0, 10]) | ||
|
||
f.suptitle("Boston housing data: distance to employment centers", y=0.035) | ||
f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95]) | ||
|
||
plt.show() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For some reason this had disappeared from what's new and I've just reinserted it :\