[MRG + 2] FIX Be robust to non re-entrant/ non deterministic cv.split calls #7660
File: sklearn/model_selection/_split.py
@@ -1467,7 +1467,7 @@ def get_n_splits(self, X=None, y=None, groups=None):

 class _CVIterableWrapper(BaseCrossValidator):
     """Wrapper class for old style cv objects and iterables."""
     def __init__(self, cv):
-        self.cv = cv
+        self.cv = list(cv)
Comment: Do we want this change? This would mean, for instance:

import numpy as np
from sklearn.model_selection import KFold
from sklearn.model_selection import check_cv
from sklearn.datasets import make_classification
from sklearn.utils.testing import assert_array_equal

X, y = make_classification(random_state=0)

# At master - this test will pass -------------------------------------------
kf_iter = KFold(n_splits=5).split(X, y)
kf_iter_wrapped = check_cv(kf_iter)
# The first split call will work
list(kf_iter_wrapped.split(X, y))
# But not subsequent ones
assert_array_equal(list(kf_iter_wrapped.split(X, y)), [])

# At this PR - this test will pass ------------------------------------------
kf_iter = KFold(n_splits=5).split(X, y)
kf_iter_wrapped = check_cv(kf_iter)
# Since the wrapped iterable is converted to a list and stored,
# split can be called any number of times and produce
# consistent results.
assert_array_equal(
    list(kf_iter_wrapped.split(X, y)),
    list(kf_iter_wrapped.split(X, y)))

Reply: I think this change is fine. This is only used internally, right? All sklearn functions worked with a single-run iterator before, and they will do so again after this PR. So there are no behavior changes, right?

Reply: Ok, thanks for the comment...

Reply: (Just to note that I've added the 2nd test just to ensure this change has an effect and to fail if it is reverted later...)

Reply: I suppose infinite CV strategies are unlikely. Expensive ones are a bit more likely, but I think this approach is okay.
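To illustrate the trade-off raised in that last comment (a hedged sketch; bootstrap_splits is a hypothetical strategy, not a scikit-learn API): list(cv) consumes the entire iterable at wrap time, so an unbounded split generator would never return, and an expensive one pays its full cost up front.

import itertools
import numpy as np

def bootstrap_splits(n_samples=10):
    # Hypothetical unbounded CV strategy: yields random splits forever.
    rng = np.random.RandomState(0)
    while True:
        test = rng.choice(n_samples, size=2, replace=False)
        train = np.setdiff1d(np.arange(n_samples), test)
        yield train, test

# list(bootstrap_splits()) would hang; a bounded prefix terminates fine:
splits = list(itertools.islice(bootstrap_splits(), 5))
print(len(splits))  # 5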

     def get_n_splits(self, X=None, y=None, groups=None):
         """Returns the number of splitting iterations in the cross-validator
File: sklearn/model_selection/_validation.py
@@ -1,4 +1,3 @@
 """
 The :mod:`sklearn.model_selection._validation` module includes classes and
 functions to validate the model.
@@ -129,6 +128,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
     X, y, groups = indexable(X, y, groups)

     cv = check_cv(cv, y, classifier=is_classifier(estimator))
+    cv_iter = list(cv.split(X, y, groups))
     scorer = check_scoring(estimator, scoring=scoring)
     # We clone the estimator to make sure that all the folds are
     # independent, and that it is pickle-able.

@@ -137,7 +137,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
     scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer,
                                               train, test, verbose, None,
                                               fit_params)
-                      for train, test in cv.split(X, y, groups))
+                      for train, test in cv_iter)
     return np.array(scores)[:, 0]
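The pattern applied throughout this file, shown as a standalone sketch (the toy one_shot_splits iterator is illustrative, not part of scikit-learn): capture the train/test indices in a list once, then reuse that list wherever the folds are needed, so even a one-shot iterator is consumed exactly one time.

# A one-shot iterator of (train, test) index pairs, standing in for cv.split().
one_shot_splits = iter([([0, 1, 2], [3]), ([0, 1], [2, 3])])

cv_iter = list(one_shot_splits)  # materialize once

# Both passes see all the folds; iterating the raw iterator twice would not.
train_sizes = [len(train) for train, test in cv_iter]
test_sizes = [len(test) for train, test in cv_iter]
print(train_sizes, test_sizes)  # [3, 2] [1, 2]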
@@ -384,6 +384,7 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
     X, y, groups = indexable(X, y, groups)

     cv = check_cv(cv, y, classifier=is_classifier(estimator))
+    cv_iter = list(cv.split(X, y, groups))

     # Ensure the estimator has implemented the passed decision function
     if not callable(getattr(estimator, method)):

@@ -396,7 +397,7 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
         pre_dispatch=pre_dispatch)
     prediction_blocks = parallel(delayed(_fit_and_predict)(
         clone(estimator), X, y, train, test, verbose, fit_params, method)
-        for train, test in cv.split(X, y, groups))
+        for train, test in cv_iter)

     # Concatenate the predictions
     predictions = [pred_block_i for pred_block_i, _ in prediction_blocks]
@@ -750,9 +751,8 @@ def learning_curve(estimator, X, y, groups=None,
     X, y, groups = indexable(X, y, groups)

     cv = check_cv(cv, y, classifier=is_classifier(estimator))
-    cv_iter = cv.split(X, y, groups)
-    # Make a list since we will be iterating multiple times over the folds
-    cv_iter = list(cv_iter)
+    cv_iter = list(cv.split(X, y, groups))
     scorer = check_scoring(estimator, scoring=scoring)

     n_max_training_samples = len(cv_iter[0][0])

@@ -775,9 +775,8 @@ def learning_curve(estimator, X, y, groups=None,
     if exploit_incremental_learning:
         classes = np.unique(y) if is_classifier(estimator) else None
         out = parallel(delayed(_incremental_fit_estimator)(
-            clone(estimator), X, y, classes, train,
-            test, train_sizes_abs, scorer, verbose)
-            for train, test in cv_iter)
+            clone(estimator), X, y, classes, train, test, train_sizes_abs,
+            scorer, verbose) for train, test in cv_iter)
     else:
         train_test_proportions = []
         for train, test in cv_iter:
@@ -961,6 +960,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
     X, y, groups = indexable(X, y, groups)

     cv = check_cv(cv, y, classifier=is_classifier(estimator))
+    cv_iter = list(cv.split(X, y, groups))
Comment: Here too: wouldn't a copy of the cv object have done the trick?

Reply: I think the list can safely be removed here without side effects, as the iteration happens only once and is based on the splits...
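On the copy suggestion: for _CVIterableWrapper the underlying object can be a bare generator, and generator objects cannot be copied, which is presumably part of why materializing with list() was chosen. A minimal demonstration:

import copy

gen = (i for i in range(3))
try:
    copy.copy(gen)
except TypeError as exc:
    # CPython refuses to copy generator objects, with a message along the
    # lines of "cannot pickle 'generator' object"
    print(exc)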

     scorer = check_scoring(estimator, scoring=scoring)
@@ -969,7 +969,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
     out = parallel(delayed(_fit_and_score)(
         estimator, X, y, scorer, train, test, verbose,
         parameters={param_name: v}, fit_params=None, return_train_score=True)
-        for train, test in cv.split(X, y, groups) for v in param_range)
+        for train, test in cv_iter for v in param_range)

     out = np.asarray(out)
     n_params = len(param_range)
File: sklearn/model_selection/tests/common.py (new file)

@@ -0,0 +1,23 @@
""" | ||
Common utilities for testing model selection. | ||
""" | ||
|
||
import numpy as np | ||
|
||
from sklearn.model_selection import KFold | ||
|
||
|
||
class OneTimeSplitter: | ||
"""A wrapper to make KFold single entry cv iterator""" | ||
def __init__(self, n_splits=4, n_samples=99): | ||
self.n_splits = n_splits | ||
self.n_samples = n_samples | ||
self.indices = iter(KFold(n_splits=n_splits).split(np.ones(n_samples))) | ||
|
||
def split(self, X=None, y=None, groups=None): | ||
"""Split can be called only once""" | ||
for index in self.indices: | ||
yield index | ||
|
||
def get_n_splits(self, X=None, y=None, groups=None): | ||
return self.n_splits |
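A quick usage sketch of the helper above (import path as used in the tests below): the first pass over split yields all the folds, the second yields none, which is exactly the non-re-entrant behavior this PR guards against.

from sklearn.model_selection.tests.common import OneTimeSplitter

cv = OneTimeSplitter(n_splits=4, n_samples=20)
print(len(list(cv.split())))  # 4 -- first pass consumes the wrapped iterator
print(len(list(cv.split())))  # 0 -- nothing left on the second pass
print(cv.get_n_splits())      # 4 -- still reports the nominal fold count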
File: sklearn/model_selection/tests/test_search.py

@@ -60,6 +60,8 @@
 from sklearn.pipeline import Pipeline
 from sklearn.linear_model import SGDClassifier

+from sklearn.model_selection.tests.common import OneTimeSplitter


 # Neither of the following two estimators inherit from BaseEstimator,
 # to test hyperparameter search on user-defined classifiers.

@@ -1154,3 +1156,58 @@ def test_search_train_scores_set_to_false():
     gs = GridSearchCV(clf, param_grid={'C': [0.1, 0.2]},
                       return_train_score=False)
     gs.fit(X, y)
+def test_grid_search_cv_splits_consistency():
+    # Check if a one-time iterable is accepted as a cv parameter.
+    n_samples = 100
+    n_splits = 5
+    X, y = make_classification(n_samples=n_samples, random_state=0)
+
+    gs = GridSearchCV(LinearSVC(random_state=0),
+                      param_grid={'C': [0.1, 0.2, 0.3]},
+                      cv=OneTimeSplitter(n_splits=n_splits,
+                                         n_samples=n_samples))
+    gs.fit(X, y)
+
+    gs2 = GridSearchCV(LinearSVC(random_state=0),
+                       param_grid={'C': [0.1, 0.2, 0.3]},
+                       cv=KFold(n_splits=n_splits))
+    gs2.fit(X, y)
+
+    def _pop_time_keys(cv_results):
+        for key in ('mean_fit_time', 'std_fit_time',
+                    'mean_score_time', 'std_score_time'):
+            cv_results.pop(key)
+        return cv_results
+
+    # OneTimeSplitter is a non-re-entrant cv where split can be called only
+    # once. If ``cv.split`` were called once per param setting in
+    # GridSearchCV.fit, the 2nd and 3rd parameters would not be evaluated,
+    # as no train/test indices would be generated for the 2nd and subsequent
+    # cv.split calls. This is a check to make sure cv.split is not called
+    # once per param setting.
+    np.testing.assert_equal(_pop_time_keys(gs.cv_results_),
+                            _pop_time_keys(gs2.cv_results_))
+
+    # Check consistency of folds across the parameters
+    gs = GridSearchCV(LinearSVC(random_state=0),
Comment: Did this test fail before? Maybe also add a one-line comment explaining what's happening.

Reply: Yes, it fails on master...
+                      param_grid={'C': [0.1, 0.1, 0.2, 0.2]},
+                      cv=KFold(n_splits=n_splits, shuffle=True))
+    gs.fit(X, y)
+
+    # As the first two param settings (C=0.1) and the next two param
+    # settings (C=0.2) are the same, the train and test scores must also be
+    # the same, as long as the same train/test indices are generated for
+    # all the cv splits, for both param settings.
+    for score_type in ('train', 'test'):
+        per_param_scores = {}
+        for param_i in range(4):
+            per_param_scores[param_i] = list(
+                gs.cv_results_['split%d_%s_score' % (s, score_type)][param_i]
+                for s in range(5))
+
+        assert_array_almost_equal(per_param_scores[0],
+                                  per_param_scores[1])
+        assert_array_almost_equal(per_param_scores[2],
+                                  per_param_scores[3])
File: sklearn/model_selection/tests/test_validation.py

@@ -59,73 +59,9 @@
 X = np.ones(10)
 y = np.arange(10) // 2
 P_sparse = coo_matrix(np.eye(5))
 digits = load_digits()


-class MockClassifier(object):
-    """Dummy classifier to test the cross-validation"""
-
-    def __init__(self, a=0, allow_nd=False):
-        self.a = a
-        self.allow_nd = allow_nd
-
-    def fit(self, X, Y=None, sample_weight=None, class_prior=None,
-            sparse_sample_weight=None, sparse_param=None, dummy_int=None,
-            dummy_str=None, dummy_obj=None, callback=None):
-        """The dummy arguments are to test that this fit function can
-        accept non-array arguments through cross-validation, such as:
-            - int
-            - str (this is actually array-like)
-            - object
-            - function
-        """
-        self.dummy_int = dummy_int
-        self.dummy_str = dummy_str
-        self.dummy_obj = dummy_obj
-        if callback is not None:
-            callback(self)
-
-        if self.allow_nd:
-            X = X.reshape(len(X), -1)
-        if X.ndim >= 3 and not self.allow_nd:
-            raise ValueError('X cannot be d')
-        if sample_weight is not None:
-            assert_true(sample_weight.shape[0] == X.shape[0],
-                        'MockClassifier extra fit_param sample_weight.shape[0]'
-                        ' is {0}, should be {1}'.format(sample_weight.shape[0],
-                                                        X.shape[0]))
-        if class_prior is not None:
-            assert_true(class_prior.shape[0] == len(np.unique(y)),
-                        'MockClassifier extra fit_param class_prior.shape[0]'
-                        ' is {0}, should be {1}'.format(class_prior.shape[0],
-                                                        len(np.unique(y))))
-        if sparse_sample_weight is not None:
-            fmt = ('MockClassifier extra fit_param sparse_sample_weight'
-                   '.shape[0] is {0}, should be {1}')
-            assert_true(sparse_sample_weight.shape[0] == X.shape[0],
-                        fmt.format(sparse_sample_weight.shape[0], X.shape[0]))
-        if sparse_param is not None:
-            fmt = ('MockClassifier extra fit_param sparse_param.shape '
-                   'is ({0}, {1}), should be ({2}, {3})')
-            assert_true(sparse_param.shape == P_sparse.shape,
-                        fmt.format(sparse_param.shape[0],
-                                   sparse_param.shape[1],
-                                   P_sparse.shape[0], P_sparse.shape[1]))
-        return self
-
-    def predict(self, T):
-        if self.allow_nd:
-            T = T.reshape(len(T), -1)
-        return T[:, 0]
-
-    def score(self, X=None, Y=None):
-        return 1. / (1 + np.abs(self.a))
-
-    def get_params(self, deep=False):
-        return {'a': self.a, 'allow_nd': self.allow_nd}


 @ignore_warnings
 def test_cross_validator_with_default_params():
     n_samples = 4
@@ -908,6 +844,22 @@ def test_cv_iterable_wrapper():
     # Check if get_n_splits works correctly
     assert_equal(len(cv), wrapped_old_skf.get_n_splits())

+    kf_iter = KFold(n_splits=5).split(X, y)
+    kf_iter_wrapped = check_cv(kf_iter)
+    # Since the wrapped iterable is converted to a list and stored,
+    # split can be called any number of times and produce
+    # consistent results.
+    assert_array_equal(list(kf_iter_wrapped.split(X, y)),
+                       list(kf_iter_wrapped.split(X, y)))
+    # If the splits are randomized, successive calls to split yield
+    # different results.
+    kf_randomized_iter = KFold(n_splits=5, shuffle=True).split(X, y)
+    kf_randomized_iter_wrapped = check_cv(kf_randomized_iter)
+    assert_array_equal(list(kf_randomized_iter_wrapped.split(X, y)),
+                       list(kf_randomized_iter_wrapped.split(X, y)))
+    assert_true(np.any(np.array(list(kf_iter_wrapped.split(X, y))) !=
+                       np.array(list(kf_randomized_iter_wrapped.split(X, y)))))
Comment: do you need ...

Reply: Yea

Comment: Well, generators can be iterated upon several times without being consumed. It's iterators that cannot (a generator is an object that creates an iterator upon iteration. The "yield" keyword makes generators).

Reply: I think you mean that generator functions can be called many times??
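A small sketch of the distinction being debated: a generator object is itself a one-shot iterator and is consumed by iteration; it is the generator function that can be called repeatedly to mint fresh iterators (likewise, KFold(...) is re-entrant, while the KFold(...).split(X, y) object is not).

def make_gen():
    yield 1
    yield 2

g = make_gen()           # generator object: a one-shot iterator
print(list(g))           # [1, 2]
print(list(g))           # [] -- already consumed

print(list(make_gen()))  # [1, 2] -- a new call creates a fresh iterator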


 def test_group_kfold():
     rng = np.random.RandomState(0)
Comment: Would a copy of the cv object have done the trick?