From aeaed79e0d23908612fab5cce91d4633de3c9d4b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 25 Sep 2018 10:48:34 +0200 Subject: [PATCH 1/3] BUG: check equality instead of identity in check_cv --- sklearn/model_selection/_split.py | 2 +- sklearn/model_selection/tests/test_validation.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index 75c8e5d239d08..0aefbeb7f04e2 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -1939,7 +1939,7 @@ def check_cv(cv='warn', y=None, classifier=False): The return value is a cross-validator which generates the train/test splits via the ``split`` method. """ - if cv is None or cv is 'warn': + if cv is None or cv == 'warn': warnings.warn(CV_WARNING, FutureWarning) cv = 3 diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 0d7a05f39d714..9f50ffe67ff3e 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -281,6 +281,17 @@ def test_cross_val_score(): error_score='raise') +@pytest.mark.filterwarnings('ignore:You should specify a value for') # 0.22 +def test_cross_validate_many_jobs(): + # regression test for #12154: cv='warn' with n_jobs=-1 trigger a copy of + # the parameters leading to a failure in check_cv due to cv is 'warn' + # instead of cv == 'warn'. + X, y = load_iris(return_X_y=True) + clf = SVC(gamma='auto') + grid = GridSearchCV(clf, param_grid={'C': [1, 10]}) + cross_validate(grid, X, y, n_jobs=-1) + + @pytest.mark.filterwarnings('ignore: You should specify a value') # 0.22 def test_cross_validate_invalid_scoring_param(): X, y = make_classification(random_state=0) From 34780366ada14dc86a80f499444631e794df4f68 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 25 Sep 2018 11:23:09 +0200 Subject: [PATCH 2/3] Update test_validation.py --- sklearn/model_selection/tests/test_validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 9f50ffe67ff3e..4d83db99d64c9 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -283,13 +283,13 @@ def test_cross_val_score(): @pytest.mark.filterwarnings('ignore:You should specify a value for') # 0.22 def test_cross_validate_many_jobs(): - # regression test for #12154: cv='warn' with n_jobs=-1 trigger a copy of + # regression test for #12154: cv='warn' with n_jobs>1 trigger a copy of # the parameters leading to a failure in check_cv due to cv is 'warn' # instead of cv == 'warn'. X, y = load_iris(return_X_y=True) clf = SVC(gamma='auto') grid = GridSearchCV(clf, param_grid={'C': [1, 10]}) - cross_validate(grid, X, y, n_jobs=-1) + cross_validate(grid, X, y, n_jobs=2) @pytest.mark.filterwarnings('ignore: You should specify a value') # 0.22 From 16290cdb8f00162e97fe2416be75643b19cd243e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 25 Sep 2018 11:35:43 +0200 Subject: [PATCH 3/3] FIX: changes other occurences --- sklearn/model_selection/_split.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index 0aefbeb7f04e2..954a6c2bd443e 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -422,7 +422,7 @@ class KFold(_BaseKFold): def __init__(self, n_splits='warn', shuffle=False, random_state=None): - if n_splits is 'warn': + if n_splits == 'warn': warnings.warn(NSPLIT_WARNING, FutureWarning) n_splits = 3 super(KFold, self).__init__(n_splits, shuffle, random_state) @@ -493,7 +493,7 @@ class GroupKFold(_BaseKFold): stratification of the dataset. """ def __init__(self, n_splits='warn'): - if n_splits is 'warn': + if n_splits == 'warn': warnings.warn(NSPLIT_WARNING, FutureWarning) n_splits = 3 super(GroupKFold, self).__init__(n_splits, shuffle=False, @@ -594,7 +594,7 @@ class StratifiedKFold(_BaseKFold): """ def __init__(self, n_splits='warn', shuffle=False, random_state=None): - if n_splits is 'warn': + if n_splits == 'warn': warnings.warn(NSPLIT_WARNING, FutureWarning) n_splits = 3 super(StratifiedKFold, self).__init__(n_splits, shuffle, random_state) @@ -748,7 +748,7 @@ class TimeSeriesSplit(_BaseKFold): where ``n_samples`` is the number of samples. """ def __init__(self, n_splits='warn', max_train_size=None): - if n_splits is 'warn': + if n_splits == 'warn': warnings.warn(NSPLIT_WARNING, FutureWarning) n_splits = 3 super(TimeSeriesSplit, self).__init__(n_splits,