Bug in AdaBoostRegressor with randomstate #7408

Closed
StevenLOL opened this issue Sep 13, 2016 · 10 comments · Fixed by #7411
StevenLOL commented Sep 13, 2016

Description

Consider the following regressor:

xlf1 = Pipeline([('svd', PCA(n_components=pca_n_components)),
                 ('regressor', AdaBoostRegressor(
                     #random_state=random_state,
                     base_estimator=MLPRegressor(random_state=random_state,
                                                 early_stopping=True,
                                                 max_iter=2000),
                     n_estimators=30,
                     learning_rate=0.01)),
                 ])

If random_state is set to some value, the performance is worse than when it is left unset.
I created a project for this problem here.

By the way, there is not much difference when LinearSVR is used as the base_estimator.

Expected Results

Actual Results

Versions

Linux-4.4.0-31-generic-x86_64-with-Ubuntu-16.04-xenial
('Python', '2.7.12 (default, Jul 1 2016, 15:12:24) \n[GCC 5.4.0 20160609]')
('NumPy', '1.11.1')
('SciPy', '0.17.0')
('Scikit-Learn', '0.18.dev0')

@nelson-liu
Contributor

nelson-liu commented Sep 13, 2016

I'm not sure that I understand correctly. Are you saying that the performance of your estimator is lower with a fixed random_state than without one?

If so, be aware that (in certain methods) the results of training are stochastic and depend on random_state. Thus, in your specific case, setting a random state just happened to (one could say randomly) lead to a slightly worse result than seeding the training routine with the global numpy random state.

The purpose of random_state is to enhance reproducibility, and minor differences are to be expected across different values of random_state (try setting the seed to different values and observe that the performance will increase or decrease).

If my interpretation of your question was incorrect, please clarify.
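
The point about reproducibility can be illustrated with a small sketch. It uses Python's standard `random` module rather than scikit-learn, and `noisy_score` is a hypothetical stand-in for a stochastic training run, so the numbers it produces are made up and only show the pattern: one seed reproduces exactly, while different seeds scatter around a mean.

```python
import random
import statistics


def noisy_score(seed=None):
    """Hypothetical stand-in for a stochastic training run.

    Returns a made-up "error" that depends only on the random draws.
    """
    rng = random.Random(seed)  # seed=None seeds from the OS or the clock
    return 0.6 + 0.05 * rng.gauss(0, 1)


# The same seed always reproduces the same result.
assert noisy_score(seed=2016) == noisy_score(seed=2016)

# Different seeds give slightly different results, so a single seed can
# land above or below the average purely by chance.
scores = [noisy_score(seed=s) for s in range(30)]
print("mean over 30 seeds:", statistics.mean(scores))
print("spread (stdev):", statistics.stdev(scores))
```

If a gap between two configurations is much larger than this kind of seed-to-seed spread, chance alone is an unlikely explanation.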

@StevenLOL
Author

I just tested with random_state=int(time.time()).

The average over 30 rounds is worse than when random_state is not set at all (0.732 MAE vs 0.621).

It seems this happens as soon as any value is assigned, not only for a fixed random_state.

@jnothman
Member

I can imagine this sort of thing coming about if that random state value (rather than a shared generator, or values drawn from one) is passed to multiple sub-estimators, so that they all have the same randomisation.
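
A minimal sketch of the failure mode described above, using Python's standard `random` module in place of the estimators' internal generators. `bootstrap_indices` and the two `members_with_*` helpers are hypothetical names written for illustration, not scikit-learn API. The sketch contrasts giving every ensemble member the same integer seed with sharing one generator across members.

```python
import random


def bootstrap_indices(rng, n_samples):
    """Draw a bootstrap sample of row indices using the given generator."""
    return [rng.randrange(n_samples) for _ in range(n_samples)]


def members_with_same_seed(seed, n_members, n_samples):
    # Every member builds a fresh generator from the same integer,
    # so every member sees exactly the same "random" draws.
    return [bootstrap_indices(random.Random(seed), n_samples)
            for _ in range(n_members)]


def members_with_shared_generator(seed, n_members, n_samples):
    # One generator is created once and advanced by every member,
    # so the run is still reproducible but the members differ.
    rng = random.Random(seed)
    return [bootstrap_indices(rng, n_samples) for _ in range(n_members)]


same = members_with_same_seed(seed=2016, n_members=3, n_samples=10)
shared = members_with_shared_generator(seed=2016, n_members=3, n_samples=10)

# With a repeated integer seed, all members get the identical sample.
print("same seed, members identical:", same[0] == same[1] == same[2])
# With a shared generator, each member gets its own sample.
print("shared generator, members identical:", shared[0] == shared[1])
```

In the first case the ensemble loses the diversity that resampling is supposed to provide, which would be consistent with worse scores whenever any value is assigned.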

@jnothman jnothman added the Bug label Sep 13, 2016
@StevenLOL
Author

StevenLOL commented Sep 13, 2016

OK, here is a quick test.

Set random_state=int(time.time()) to see the difference: about 16 vs about 49.

from __future__ import print_function

import time

import numpy as np
from sklearn.datasets import make_regression
from sklearn.decomposition import PCA
from sklearn.ensemble import AdaBoostRegressor
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import KFold
from sklearn.neural_network import MLPRegressor
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

trainx, trainy = make_regression(n_samples=2000, n_features=100, random_state=0,
                                 noise=4.0, bias=2.0)

print(trainx[0][0:5])
print(trainy[0:5])


def evalTrainData(trainDatax, trainV, random_state=2016, eid=''):
    fold = 10
    scores = []
    pccs = []
    cvcount = 0
    assert len(trainDatax) == len(trainV)

    # Three rounds of 10-fold cross-validation, each with a different shuffle.
    for roundindex in range(0, 3):
        skf = KFold(fold, shuffle=True, random_state=random_state + roundindex)
        for trainIndex, evalIndex in skf.split(trainDatax):
            t1 = time.time()
            cvTrainx, cvTrainy = trainDatax[trainIndex], trainV[trainIndex]
            cvEvalx, cvEvaly = trainDatax[evalIndex], trainV[evalIndex]

            # Standardize the target for training; undo it before scoring.
            scaler = StandardScaler()
            cvTrainy = scaler.fit_transform(cvTrainy.reshape(-1, 1)).ravel()
            lsvr = getxlf(random_state=random_state)
            lsvr.fit(cvTrainx, cvTrainy)
            predict = lsvr.predict(cvEvalx)
            predict = scaler.inverse_transform(predict.reshape(-1, 1)).ravel()
            score = mean_absolute_error(cvEvaly, predict)
            pcc = np.corrcoef(cvEvaly, predict)[0, 1]
            print(cvcount, 'MAE', score, 'PCC', pcc, time.time() - t1,
                  time.asctime(time.localtime(time.time())),
                  'Train shape:', cvTrainx.shape, 'eval shape:', cvEvalx.shape)
            scores.append(score)
            pccs.append(pcc)
            cvcount += 1

    print('###', eid, 'MAE', np.mean(scores), 'PCC', np.mean(pccs))


pca_n_components = 100


def getxlf(random_state=2016):
    xlf1 = Pipeline([
        ('svd', PCA(n_components=pca_n_components)),
        ('regressor', AdaBoostRegressor(
            # Uncomment the next line to compare against the unset case.
            # random_state=int(time.time()),
            base_estimator=MLPRegressor(random_state=random_state,
                                        early_stopping=True,
                                        max_iter=2000),
            n_estimators=30,
            learning_rate=0.01)),
    ])
    return xlf1


evalTrainData(trainx, trainy)

@jnothman
Member

I have a half-completed patch. Will post in a few hours.

jnothman added a commit to jnothman/scikit-learn that referenced this issue Sep 13, 2016
(fixes scikit-learn#7408)

ENH add utility to set nested random_state

FIX ensure nested random_state is set in ensembles
@jnothman
Member

See #7411. For good or bad, that PR goes beyond the scope of just fixing this issue, so it might take a little time to review and merge.

@StevenLOL
Author

Cool, thanks.

@jnothman
Member

We usually close after the bug is fixed.

@jnothman jnothman reopened this Sep 14, 2016
@StevenLOL
Author

OK, thanks.

jnothman added a commit to jnothman/scikit-learn that referenced this issue Sep 14, 2016
(fixes scikit-learn#7408)

FIX ensure nested random_state is set in ensembles
jnothman added a commit to jnothman/scikit-learn that referenced this issue Sep 21, 2016
(fixes scikit-learn#7408)

FIX ensure nested random_state is set in ensembles
jnothman added a commit to jnothman/scikit-learn that referenced this issue Sep 22, 2016
(fixes scikit-learn#7408)

FIX ensure nested random_state is set in ensembles
jnothman added a commit that referenced this issue Sep 23, 2016
* FIX adaboost estimators not randomising correctly

(fixes #7408)

FIX ensure nested random_state is set in ensembles

* DOC add what's new

* Only affect *__random_state, not *_random_state for now

* TST More informative assertions for ensemble tests

* More specific testing of different random_states
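
The commit notes above mention setting nested random_state parameters, and matching `*__random_state` but not `*_random_state`. A rough sketch of that idea follows. It operates on a plain dict of parameter names, similar in shape to what `get_params(deep=True)` returns in scikit-learn. `derive_nested_seeds` is a hypothetical helper written for illustration, and the matching rule is an assumption based on the commit message, not a copy of the merged code.

```python
import random


def derive_nested_seeds(params, rng):
    """Return new integer seeds for every random_state-like parameter.

    Matches the key "random_state" itself and nested keys ending in
    "__random_state" (double underscore), but not keys that merely end
    in "_random_state" (single underscore).
    """
    updates = {}
    for key in sorted(params):  # sorted for a deterministic draw order
        if key == "random_state" or key.endswith("__random_state"):
            updates[key] = rng.randrange(2**31 - 1)
    return updates


# Parameter names as a nested estimator might expose them (illustrative).
params = {
    "random_state": None,
    "regressor__random_state": 2016,
    "regressor__n_estimators": 30,
    "my_random_state": 7,  # single underscore: left untouched
}

rng = random.Random(2016)  # one shared generator for the whole ensemble
first = derive_nested_seeds(params, rng)
second = derive_nested_seeds(params, rng)

print(sorted(first))  # ['random_state', 'regressor__random_state']
print("members get different seeds:", first != second)
```

Because the seeds are drawn from one shared generator, the whole ensemble stays reproducible from a single top-level seed while each member still receives its own.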
amueller pushed a commit that referenced this issue Sep 25, 2016
* FIX adaboost estimators not randomising correctly

(fixes #7408)

FIX ensure nested random_state is set in ensembles

* DOC add what's new

* Only affect *__random_state, not *_random_state for now

* TST More informative assertions for ensemble tests

* More specific testing of different random_states
TomDLT pushed a commit to TomDLT/scikit-learn that referenced this issue Oct 3, 2016
…rn#7411)

* FIX adaboost estimators not randomising correctly

(fixes scikit-learn#7408)

FIX ensure nested random_state is set in ensembles

* DOC add what's new

* Only affect *__random_state, not *_random_state for now

* TST More informative assertions for ensemble tests

* More specific testing of different random_states
Sundrique pushed a commit to Sundrique/scikit-learn that referenced this issue Jun 14, 2017
…rn#7411)

* FIX adaboost estimators not randomising correctly

(fixes scikit-learn#7408)

FIX ensure nested random_state is set in ensembles

* DOC add what's new

* Only affect *__random_state, not *_random_state for now

* TST More informative assertions for ensemble tests

* More specific testing of different random_states
paulha pushed a commit to paulha/scikit-learn that referenced this issue Aug 19, 2017
…rn#7411)

* FIX adaboost estimators not randomising correctly

(fixes scikit-learn#7408)

FIX ensure nested random_state is set in ensembles

* DOC add what's new

* Only affect *__random_state, not *_random_state for now

* TST More informative assertions for ensemble tests

* More specific testing of different random_states