Description
GridSearchCV fails when its cv parameter is given the generator returned by the new-style StratifiedKFold().split(X, y).
Steps/Code to Reproduce
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import GridSearchCV
X = np.random.randn(100, 100)
y = np.random.randn(100) > 0
clf = GridSearchCV(
    estimator=RandomForestClassifier(),
    param_grid={'n_estimators': [10, 20]},
    cv=StratifiedKFold().split(X, y),
)
clf.fit(X, y)
Expected Results
No error is thrown. According to the docs, the value of cv can be a generator of train/test splits, which StratifiedKFold().split(X, y) is.
Actual Results
1489 Returns the number of splitting iterations in the cross-validator.
1490 """
-> 1491 return len(self.cv) # Both iterables and old-cv objects support len
1492
1493 def split(self, X=None, y=None, groups=None):
TypeError: object of type 'generator' has no len()
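The failing line calls len(self.cv), and Python generators do not support len(). The pure-Python sketch below reproduces that mechanism without scikit-learn. fake_split and count_splits are hypothetical stand-ins, not scikit-learn functions: fake_split is a plain contiguous K-fold generator (not stratified) that, like split() here, yields (train, test) index lists, and count_splits mimics the len() call from the traceback.

```python
# Hypothetical stand-in for StratifiedKFold().split(X, y): a generator
# function yielding (train, test) index lists. It uses plain contiguous
# folds and does NOT stratify; it only illustrates the generator type.
def fake_split(n_samples, n_splits):
    indices = list(range(n_samples))
    fold_size = n_samples // n_splits
    for k in range(n_splits):
        test = indices[k * fold_size:(k + 1) * fold_size]
        train = [i for i in indices if i not in test]
        yield train, test


def count_splits(cv):
    # Hypothetical helper mirroring the failing line: len(self.cv)
    return len(cv)


gen = fake_split(9, 3)
try:
    count_splits(gen)
except TypeError as exc:
    error_message = str(exc)  # "object of type 'generator' has no len()"

# Materializing the generator into a list gives it a length, and the
# splits can then be iterated over more than once.
splits = list(fake_split(9, 3))
n_splits = count_splits(splits)  # 3
```

Under that reading, passing the splitter object itself (cv=StratifiedKFold()) or a materialized list (cv=list(StratifiedKFold().split(X, y))) should avoid calling len() on a generator. A list also sidesteps the fact that a generator can be consumed only once, while a search over several parameter settings presumably needs to iterate the same splits repeatedly.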
Versions
Linux-4.4.0-38-generic-x86_64-with-Ubuntu-16.04-xenial
('Python', '2.7.12 (default, Jul 1 2016, 15:12:24) \n[GCC 5.4.0 20160609]')
('NumPy', '1.11.1')
('SciPy', '0.18.1')
('Scikit-Learn', '0.18')