Bug in AdaBoostRegressor with randomstate #7408
Comments
I'm not sure that I understand correctly. Are you saying that the performance of your estimator is lower with a fixed random_state? If so, be aware that (in certain methods) the results of training are stochastic and depend on random_state. The purpose of random_state is to make results reproducible across runs. If my interpretation of your question was incorrect, please clarify.
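For illustration, a minimal sketch of what random_state buys you; the toy data and estimator here are placeholders, not taken from the issue:

import numpy as np
from sklearn.datasets import make_regression
from sklearn.neural_network import MLPRegressor

X, y = make_regression(n_samples=200, n_features=10, random_state=0)

# With a fixed random_state, two separate fits are reproducible:
pred_a = MLPRegressor(random_state=42, max_iter=500).fit(X, y).predict(X)
pred_b = MLPRegressor(random_state=42, max_iter=500).fit(X, y).predict(X)
assert np.allclose(pred_a, pred_b)

# With random_state=None (the default), weight initialisation differs on
# every fit, so repeated runs generally give different predictions.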
I just tested with random_state=int(time.time()); the 30-round average shows it is worse than not setting it at all (0.732 MAE vs 0.621). It seems to happen whenever any value is assigned, not only for a fixed random_state.
I can imagine this sort of thing coming about from that random state value (rather than a shared generator, or values generated from it) being passed to multiple sub-estimators, so that they all have the same randomisation.
Indeed, that's what's happening: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/ensemble/weight_boosting.py#L1001
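A minimal sketch of the failure mode, using clone and set_params to stand in for what the ensemble does internally; the names and structure here are illustrative, not the actual weight_boosting.py code:

import numpy as np
from sklearn.base import clone
from sklearn.neural_network import MLPRegressor

base = MLPRegressor(max_iter=200)

# Problematic pattern: the same integer seed is forwarded to every copy,
# so all sub-estimators share one and the same randomisation.
same_seed = [clone(base).set_params(random_state=2016) for _ in range(3)]

# Independent pattern: draw a fresh seed per copy from one shared generator,
# keeping the run reproducible while the copies still randomise differently.
rng = np.random.RandomState(2016)
fresh_seeds = [clone(base).set_params(random_state=rng.randint(np.iinfo(np.int32).max))
               for _ in range(3)]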
OK, here is a quick test; set random_state=int(time.time()) to see the difference (about 16± vs 49±).

from sklearn.datasets import make_regression
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import mean_absolute_error
from sklearn.neural_network import MLPRegressor
from sklearn.ensemble import AdaBoostRegressor
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import time

trainx, trainy = make_regression(n_samples=2000, n_features=100, random_state=0,
                                 noise=4.0, bias=2.0)
print(trainx[0][0:5])
print(trainy[0:5])

def evalTrainData(trainDatax, trainV, random_state=2016, eid=''):
    # 3 rounds of 10-fold cross-validation; report mean MAE and Pearson correlation.
    fold = 10
    scores = []
    pccs = []
    cvcount = 0
    assert len(trainDatax) == len(trainV)
    for roundindex in range(0, 3):
        skf = KFold(fold, shuffle=True, random_state=random_state + roundindex)
        for trainIndex, evalIndex in skf.split(trainDatax):
            t1 = time.time()
            cvTrainx, cvTrainy = trainDatax[trainIndex], trainV[trainIndex]
            cvEvalx, cvEvaly = trainDatax[evalIndex], trainV[evalIndex]
            # Standardise the target on the training fold only.
            scaler = StandardScaler()
            cvTrainy = scaler.fit_transform(cvTrainy.reshape(-1, 1)).ravel()
            lsvr = getxlf(random_state=random_state)
            lsvr.fit(cvTrainx, cvTrainy)
            predict = lsvr.predict(cvEvalx)
            predict = scaler.inverse_transform(predict.reshape(-1, 1)).ravel()
            score = mean_absolute_error(cvEvaly, predict)
            pcc = np.corrcoef(cvEvaly, predict)[0, 1]
            print(cvcount, 'MAE', score, 'PCC', pcc, time.time() - t1,
                  time.asctime(time.localtime(time.time())),
                  'Train shape:', cvTrainx.shape, 'eval shape:', cvEvalx.shape)
            scores.append(score)
            pccs.append(pcc)
            cvcount += 1
    print('###', eid, 'MAE', np.mean(scores), 'PCC', np.mean(pccs))

pca_n_components = 100

def getxlf(random_state=2016):
    # PCA followed by AdaBoost over an MLP base estimator; uncommenting
    # random_state on the AdaBoostRegressor reproduces the problem.
    xlf1 = Pipeline([
        ('svd', PCA(n_components=pca_n_components)),
        ('regressor', AdaBoostRegressor(  # random_state=int(time.time()),
            base_estimator=MLPRegressor(random_state=random_state,
                                        early_stopping=True, max_iter=2000),
            n_estimators=30, learning_rate=0.01)),
    ])
    return xlf1

evalTrainData(trainx, trainy)
I have a half-completed patch. Will post in a few hours.
(fixes scikit-learn#7408)
ENH add utility to set nested random_state
FIX ensure nested random_state is set in ensembles
See #7411. For good or bad, that PR goes beyond the scope of just fixing this issue, so it might take a little time to review and merge.
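The idea behind the fix, sketched below under the assumption that the actual helper in #7411 may differ in name and detail: walk the estimator's (possibly nested) parameters and seed every random_state from one shared generator.

import numpy as np
from sklearn.utils import check_random_state

def set_nested_random_states(estimator, random_state=None):
    # Hypothetical helper sketching the approach; not the PR's actual code.
    # Seeds every random_state parameter of the estimator, including nested
    # sub-estimators (e.g. pipeline steps), from one shared generator.
    rng = check_random_state(random_state)
    to_set = {}
    for key in sorted(estimator.get_params(deep=True)):
        if key == 'random_state' or key.endswith('__random_state'):
            to_set[key] = rng.randint(np.iinfo(np.int32).max)
    if to_set:
        estimator.set_params(**to_set)

An ensemble can then call such a helper on each fresh clone of base_estimator before fitting, so every boosting round (and every nested component, e.g. an MLPRegressor inside a Pipeline) gets its own seed while the overall run stays reproducible.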
Cool, thanks.
We usually close after the bug is fixed.
OK, thanks.
* FIX adaboost estimators not randomising correctly (fixes #7408); ensure nested random_state is set in ensembles
* DOC add what's new
* Only affect *__random_state, not *_random_state for now
* TST More informative assertions for ensemble tests
* More specific testing of different random_states
Description
Consider the following regressor:
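Reconstructed from the test script posted in the comments above (the original code block is not shown here), the regressor is roughly:

from sklearn.ensemble import AdaBoostRegressor
from sklearn.neural_network import MLPRegressor

regressor = AdaBoostRegressor(
    base_estimator=MLPRegressor(random_state=2016, early_stopping=True,
                                max_iter=2000),
    n_estimators=30, learning_rate=0.01,
    random_state=2016)  # setting this is what degrades performance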
If random_state is set to some value, the performance is worse than when it is simply left unset.
I created a project for this problem here.
By the way, there is not much difference when LinearSVR is used as the base_estimator.
Expected Results
Actual Results
Versions
Linux-4.4.0-31-generic-x86_64-with-Ubuntu-16.04-xenial
Python: 2.7.12 (default, Jul 1 2016, 15:12:24) [GCC 5.4.0 20160609]
NumPy: 1.11.1
SciPy: 0.17.0
Scikit-Learn: 0.18.dev0