
GridSearchCV does not work with StratifiedKFold (fails to get_n_splits)  #7808

@psarka

Description

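GridSearchCV does not work with the new-style StratifiedKFold from sklearn.model_selection when cv is given the generator returned by StratifiedKFold().split(X, y).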

Steps/Code to Reproduce

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import GridSearchCV

X = np.random.randn(100, 100)
y = np.random.randn(100) > 0

clf = GridSearchCV(
    estimator=RandomForestClassifier(),
    param_grid={'n_estimators': [10, 20]},
    cv=StratifiedKFold().split(X, y)  # a generator of (train, test) index pairs
)

clf.fit(X, y)

Expected Results

No error is thrown. According to the docs, cv can be an iterable yielding (train, test) splits, which is exactly what StratifiedKFold().split(X, y) returns.
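The docs' description of cv as an iterable yielding splits covers both generators and lists, but the two behave differently: a generator has no len() (the call that fails in the traceback under Actual Results) and is exhausted after one pass, while a list built from it supports both len() and repeated iteration. The sketch below illustrates this in plain Python; toy_split is a hypothetical stand-in for StratifiedKFold().split(X, y) and does not call scikit-learn.

```python
# Hypothetical stand-in for StratifiedKFold().split(X, y): yields
# (train_indices, test_indices) pairs for 3 contiguous folds over 6 samples.
def toy_split(n_samples=6, n_folds=3):
    indices = list(range(n_samples))
    fold_size = n_samples // n_folds
    for k in range(n_folds):
        test = indices[k * fold_size:(k + 1) * fold_size]
        train = [i for i in indices if i not in test]
        yield train, test


gen = toy_split()
try:
    len(gen)  # generators do not implement __len__
except TypeError as exc:
    print(exc)  # object of type 'generator' has no len()

splits = list(toy_split())  # materialize the splits once
print(len(splits))  # 3
print(len(list(gen)), len(list(gen)))  # 3 0: exhausted after the first pass
```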

Actual Results

   1489             Returns the number of splitting iterations in the cross-validator.
   1490         """
-> 1491         return len(self.cv)  # Both iterables and old-cv objects support len
   1492 
   1493     def split(self, X=None, y=None, groups=None):

TypeError: object of type 'generator' has no len()
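The failing line calls len(self.cv) on whatever object was passed as cv. A minimal sketch of one way a wrapper could avoid this, assuming it is acceptable to hold every split in memory, is to convert the iterable to a list once in the constructor. The class name IterableCVWrapper and the generator three_folds are hypothetical; the class only mirrors the get_n_splits and split methods visible in the traceback and is not scikit-learn's actual code.

```python
class IterableCVWrapper(object):
    """Hypothetical wrapper around an iterable of (train, test) splits."""

    def __init__(self, cv):
        # Materialize once: a list supports len() and can be iterated
        # any number of times, unlike a generator.
        self.cv = list(cv)

    def get_n_splits(self, X=None, y=None, groups=None):
        """Return the number of splitting iterations."""
        return len(self.cv)

    def split(self, X=None, y=None, groups=None):
        """Yield the stored (train, test) pairs."""
        for train, test in self.cv:
            yield train, test


def three_folds():
    # Hypothetical generator standing in for StratifiedKFold().split(X, y).
    yield [2, 3, 4, 5], [0, 1]
    yield [0, 1, 4, 5], [2, 3]
    yield [0, 1, 2, 3], [4, 5]


wrapper = IterableCVWrapper(three_folds())
print(wrapper.get_n_splits())  # 3
print(len(list(wrapper.split())))  # 3, on this and every later call
```

On the user side, passing the splitter object itself (cv=StratifiedKFold()) or an already materialized list (cv=list(StratifiedKFold().split(X, y))) should sidestep the len() call on a generator, at the cost of keeping all index arrays in memory in the list case.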

Versions

Linux-4.4.0-38-generic-x86_64-with-Ubuntu-16.04-xenial
('Python', '2.7.12 (default, Jul  1 2016, 15:12:24) \n[GCC 5.4.0 20160609]')
('NumPy', '1.11.1')
('SciPy', '0.18.1')
('Scikit-Learn', '0.18')
