Skip to content

Commit 646e47c

Browse files
committed
Merge pull request #4294 from rvraghav93/model_selection
[MRG+1] Reorganize grid_search, cross_validation and learning_curve into model_selection
2 parents 91753dc + bdd94e9 commit 646e47c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+6175
-103
lines changed

doc/whats_new.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@ New features
2525
Enhancements
2626
............
2727

28+
- The cross-validation iterators have been replaced by cross-validation splitters
29+
which expose a ``split`` method that takes in the data and yields a generator
30+
for the different splits. This change makes it possible to do nested cross-validation
31+
with ease. (`#4294 <https://github.com/scikit-learn/scikit-learn/pull/4294>`_) by `Raghav R V`_.
32+
33+
- The :mod:`cross_validation`, :mod:`grid_search` and :mod:`learning_curve` modules
34+
have been deprecated and the classes and functions have been reorganized into
35+
the :mod:`model_selection` module. (`#4294 <https://github.com/scikit-learn/scikit-learn/pull/4294>`_) by `Raghav R V`_.
36+
37+
2838
Bug fixes
2939
.........
3040

sklearn/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@
6262
'ensemble', 'exceptions', 'externals', 'feature_extraction',
6363
'feature_selection', 'gaussian_process', 'grid_search',
6464
'isotonic', 'kernel_approximation', 'kernel_ridge',
65-
'lda', 'learning_curve',
66-
'linear_model', 'manifold', 'metrics', 'mixture', 'multiclass',
65+
'lda', 'learning_curve', 'linear_model', 'manifold', 'metrics',
66+
'mixture', 'model_selection', 'multiclass',
6767
'naive_bayes', 'neighbors', 'neural_network', 'pipeline',
6868
'preprocessing', 'qda', 'random_projection', 'semi_supervised',
6969
'svm', 'tree', 'discriminant_analysis',

sklearn/calibration.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from .utils.fixes import signature
2323
from .isotonic import IsotonicRegression
2424
from .svm import LinearSVC
25-
from .cross_validation import check_cv
25+
from .model_selection import check_cv
2626
from .metrics.classification import _check_binary_probabilistic_predictions
2727

2828

@@ -152,7 +152,7 @@ def fit(self, X, y, sample_weight=None):
152152
calibrated_classifier.fit(X, y)
153153
self.calibrated_classifiers_.append(calibrated_classifier)
154154
else:
155-
cv = check_cv(self.cv, X, y, classifier=True)
155+
cv = check_cv(self.cv, y, classifier=True)
156156
fit_parameters = signature(base_estimator.fit).parameters
157157
estimator_name = type(base_estimator).__name__
158158
if (sample_weight is not None
@@ -163,7 +163,7 @@ def fit(self, X, y, sample_weight=None):
163163
base_estimator_sample_weight = None
164164
else:
165165
base_estimator_sample_weight = sample_weight
166-
for train, test in cv:
166+
for train, test in cv.split(X, y):
167167
this_estimator = clone(base_estimator)
168168
if base_estimator_sample_weight is not None:
169169
this_estimator.fit(

sklearn/cluster/tests/test_bicluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from scipy.sparse import csr_matrix, issparse
55

6-
from sklearn.grid_search import ParameterGrid
6+
from sklearn.model_selection import ParameterGrid
77

88
from sklearn.utils.testing import assert_equal
99
from sklearn.utils.testing import assert_almost_equal

sklearn/covariance/graph_lasso_.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ..utils.validation import check_random_state, check_array
2222
from ..linear_model import lars_path
2323
from ..linear_model import cd_fast
24-
from ..cross_validation import check_cv, cross_val_score
24+
from ..model_selection import check_cv, cross_val_score
2525
from ..externals.joblib import Parallel, delayed
2626
import collections
2727

@@ -580,7 +580,7 @@ def fit(self, X, y=None):
580580
emp_cov = empirical_covariance(
581581
X, assume_centered=self.assume_centered)
582582

583-
cv = check_cv(self.cv, X, y, classifier=False)
583+
cv = check_cv(self.cv, y, classifier=False)
584584

585585
# List of (alpha, scores, covs)
586586
path = list()
@@ -612,14 +612,13 @@ def fit(self, X, y=None):
612612
this_path = Parallel(
613613
n_jobs=self.n_jobs,
614614
verbose=self.verbose
615-
)(
616-
delayed(graph_lasso_path)(
617-
X[train], alphas=alphas,
618-
X_test=X[test], mode=self.mode,
619-
tol=self.tol, enet_tol=self.enet_tol,
620-
max_iter=int(.1 * self.max_iter),
621-
verbose=inner_verbose)
622-
for train, test in cv)
615+
)(delayed(graph_lasso_path)(X[train], alphas=alphas,
616+
X_test=X[test], mode=self.mode,
617+
tol=self.tol,
618+
enet_tol=self.enet_tol,
619+
max_iter=int(.1 * self.max_iter),
620+
verbose=inner_verbose)
621+
for train, test in cv.split(X, y))
623622

624623
# Little danse to transform the list in what we need
625624
covs, _, scores = zip(*this_path)

sklearn/cross_validation.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@
3434
from .gaussian_process.kernels import Kernel as GPKernel
3535
from .exceptions import FitFailedWarning
3636

37+
38+
warnings.warn("This module has been deprecated in favor of the "
39+
"model_selection module into which all the refactored classes "
40+
"and functions are moved. Also note that the interface of the "
41+
"new CV iterators are different from that of this module. "
42+
"This module will be removed in 0.19.", DeprecationWarning)
43+
44+
3745
__all__ = ['KFold',
3846
'LabelKFold',
3947
'LeaveOneLabelOut',
@@ -304,7 +312,7 @@ class KFold(_BaseKFold):
304312
305313
See also
306314
--------
307-
StratifiedKFold: take label information into account to avoid building
315+
StratifiedKFold take label information into account to avoid building
308316
folds with imbalanced class distributions (for binary or multiclass
309317
classification tasks).
310318

sklearn/decomposition/tests/test_kernel_pca.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sklearn.datasets import make_circles
1010
from sklearn.linear_model import Perceptron
1111
from sklearn.pipeline import Pipeline
12-
from sklearn.grid_search import GridSearchCV
12+
from sklearn.model_selection import GridSearchCV
1313
from sklearn.metrics.pairwise import rbf_kernel
1414

1515

sklearn/ensemble/tests/test_bagging.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@
2121
from sklearn.utils.testing import assert_warns_message
2222

2323
from sklearn.dummy import DummyClassifier, DummyRegressor
24-
from sklearn.grid_search import GridSearchCV, ParameterGrid
24+
from sklearn.model_selection import GridSearchCV, ParameterGrid
2525
from sklearn.ensemble import BaggingClassifier, BaggingRegressor
2626
from sklearn.linear_model import Perceptron, LogisticRegression
2727
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
2828
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
2929
from sklearn.svm import SVC, SVR
3030
from sklearn.pipeline import make_pipeline
3131
from sklearn.feature_selection import SelectKBest
32-
from sklearn.cross_validation import train_test_split
32+
from sklearn.model_selection import train_test_split
3333
from sklearn.datasets import load_boston, load_iris, make_hastie_10_2
3434
from sklearn.utils import check_random_state
3535

sklearn/ensemble/tests/test_forest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from sklearn.ensemble import RandomForestClassifier
3939
from sklearn.ensemble import RandomForestRegressor
4040
from sklearn.ensemble import RandomTreesEmbedding
41-
from sklearn.grid_search import GridSearchCV
41+
from sklearn.model_selection import GridSearchCV
4242
from sklearn.svm import LinearSVC
4343
from sklearn.utils.fixes import bincount
4444
from sklearn.utils.validation import check_random_state

sklearn/ensemble/tests/test_voting_classifier.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from sklearn.naive_bayes import GaussianNB
88
from sklearn.ensemble import RandomForestClassifier
99
from sklearn.ensemble import VotingClassifier
10-
from sklearn.grid_search import GridSearchCV
10+
from sklearn.model_selection import GridSearchCV
1111
from sklearn import datasets
12-
from sklearn import cross_validation
12+
from sklearn.model_selection import cross_val_score
1313
from sklearn.datasets import make_multilabel_classification
1414
from sklearn.svm import SVC
1515
from sklearn.multiclass import OneVsRestClassifier
@@ -27,11 +27,7 @@ def test_majority_label_iris():
2727
eclf = VotingClassifier(estimators=[
2828
('lr', clf1), ('rf', clf2), ('gnb', clf3)],
2929
voting='hard')
30-
scores = cross_validation.cross_val_score(eclf,
31-
X,
32-
y,
33-
cv=5,
34-
scoring='accuracy')
30+
scores = cross_val_score(eclf, X, y, cv=5, scoring='accuracy')
3531
assert_almost_equal(scores.mean(), 0.95, decimal=2)
3632

3733

@@ -55,11 +51,7 @@ def test_weights_iris():
5551
('lr', clf1), ('rf', clf2), ('gnb', clf3)],
5652
voting='soft',
5753
weights=[1, 2, 10])
58-
scores = cross_validation.cross_val_score(eclf,
59-
X,
60-
y,
61-
cv=5,
62-
scoring='accuracy')
54+
scores = cross_val_score(eclf, X, y, cv=5, scoring='accuracy')
6355
assert_almost_equal(scores.mean(), 0.93, decimal=2)
6456

6557

0 commit comments

Comments
 (0)