
[MRG+1] Make cross-validators data independent + Reorganize grid_search, cross_validation and learning_curve into model_selection #4294


Merged: 8 commits, Oct 23, 2015
10 changes: 10 additions & 0 deletions doc/whats_new.rst
@@ -25,6 +25,16 @@ New features
Enhancements
............

- The cross-validation iterators have been replaced by cross-validation
  splitters, which expose a ``split`` method that takes the data and returns
  a generator over the different splits. This change makes it possible to do
  nested cross-validation with ease, as sketched below.
  (`#4294 <https://github.com/scikit-learn/scikit-learn/pull/4294>`_) by `Raghav R V`_.

- The :mod:`cross_validation`, :mod:`grid_search` and :mod:`learning_curve`
  modules have been deprecated, and their classes and functions have been
  reorganized into the new :mod:`model_selection` module.
  (`#4294 <https://github.com/scikit-learn/scikit-learn/pull/4294>`_) by `Raghav R V`_.
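
To make the two entries above concrete, here is a minimal sketch of nested cross-validation with the new splitters, written against the released 0.18 interface (the iris data, ``SVC``, the parameter grid and the ``n_splits`` keyword are illustrative choices, not part of this changelog):

    from sklearn import datasets
    from sklearn.model_selection import GridSearchCV, KFold, cross_val_score
    from sklearn.svm import SVC

    iris = datasets.load_iris()
    X, y = iris.data, iris.target

    # Splitters are data-independent: the data enters only through split().
    inner_cv = KFold(n_splits=3)
    outer_cv = KFold(n_splits=5)

    # Nested CV: the inner grid search re-splits each outer training fold,
    # which the old data-bound iterators could not express cleanly.
    clf = GridSearchCV(SVC(), param_grid={'C': [0.1, 1, 10]}, cv=inner_cv)
    scores = cross_val_score(clf, X, y, cv=outer_cv)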


Bug fixes
.........

4 changes: 2 additions & 2 deletions sklearn/__init__.py
@@ -62,8 +62,8 @@
'ensemble', 'exceptions', 'externals', 'feature_extraction',
'feature_selection', 'gaussian_process', 'grid_search',
'isotonic', 'kernel_approximation', 'kernel_ridge',
'lda', 'learning_curve',
'linear_model', 'manifold', 'metrics', 'mixture', 'multiclass',
'lda', 'learning_curve', 'linear_model', 'manifold', 'metrics',
'mixture', 'model_selection', 'multiclass',
'naive_bayes', 'neighbors', 'neural_network', 'pipeline',
'preprocessing', 'qda', 'random_projection', 'semi_supervised',
'svm', 'tree', 'discriminant_analysis',
6 changes: 3 additions & 3 deletions sklearn/calibration.py
@@ -22,7 +22,7 @@
from .utils.fixes import signature
from .isotonic import IsotonicRegression
from .svm import LinearSVC
from .cross_validation import check_cv
from .model_selection import check_cv
from .metrics.classification import _check_binary_probabilistic_predictions


@@ -152,7 +152,7 @@ def fit(self, X, y, sample_weight=None):
calibrated_classifier.fit(X, y)
self.calibrated_classifiers_.append(calibrated_classifier)
else:
cv = check_cv(self.cv, X, y, classifier=True)
cv = check_cv(self.cv, y, classifier=True)
fit_parameters = signature(base_estimator.fit).parameters
estimator_name = type(base_estimator).__name__
if (sample_weight is not None
@@ -163,7 +163,7 @@
base_estimator_sample_weight = None
else:
base_estimator_sample_weight = sample_weight
for train, test in cv:
for train, test in cv.split(X, y):
this_estimator = clone(base_estimator)
if base_estimator_sample_weight is not None:
this_estimator.fit(
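The calibration change above shows the pattern repeated throughout this PR: ``check_cv`` no longer receives ``X``, and the splitter it returns is asked for folds via ``split(X, y)`` instead of being iterated directly. A rough sketch of the new style (the toy arrays are arbitrary; keyword names follow the released ``model_selection.check_cv`` signature):

    import numpy as np
    from sklearn.model_selection import check_cv

    X = np.arange(20).reshape(10, 2)
    y = np.array([0, 1] * 5)

    # y is kept so classification targets can be stratified; X is only
    # needed once the folds are actually generated.
    cv = check_cv(cv=3, y=y, classifier=True)
    for train, test in cv.split(X, y):
        pass  # train and test are integer index arrays into X and y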
2 changes: 1 addition & 1 deletion sklearn/cluster/tests/test_bicluster.py
@@ -3,7 +3,7 @@
import numpy as np
from scipy.sparse import csr_matrix, issparse

from sklearn.grid_search import ParameterGrid
from sklearn.model_selection import ParameterGrid

from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_almost_equal
19 changes: 9 additions & 10 deletions sklearn/covariance/graph_lasso_.py
@@ -21,7 +21,7 @@
from ..utils.validation import check_random_state, check_array
from ..linear_model import lars_path
from ..linear_model import cd_fast
from ..cross_validation import check_cv, cross_val_score
from ..model_selection import check_cv, cross_val_score
from ..externals.joblib import Parallel, delayed
import collections

@@ -580,7 +580,7 @@ def fit(self, X, y=None):
emp_cov = empirical_covariance(
X, assume_centered=self.assume_centered)

cv = check_cv(self.cv, X, y, classifier=False)
cv = check_cv(self.cv, y, classifier=False)

# List of (alpha, scores, covs)
path = list()
@@ -612,14 +612,13 @@ def fit(self, X, y=None):
this_path = Parallel(
n_jobs=self.n_jobs,
verbose=self.verbose
)(
delayed(graph_lasso_path)(
X[train], alphas=alphas,
X_test=X[test], mode=self.mode,
tol=self.tol, enet_tol=self.enet_tol,
max_iter=int(.1 * self.max_iter),
verbose=inner_verbose)
for train, test in cv)
)(delayed(graph_lasso_path)(X[train], alphas=alphas,
X_test=X[test], mode=self.mode,
tol=self.tol,
enet_tol=self.enet_tol,
max_iter=int(.1 * self.max_iter),
verbose=inner_verbose)
for train, test in cv.split(X, y))

# Little dance to transform the list into what we need
covs, _, scores = zip(*this_path)
10 changes: 9 additions & 1 deletion sklearn/cross_validation.py
@@ -34,6 +34,14 @@
from .gaussian_process.kernels import Kernel as GPKernel
from .exceptions import FitFailedWarning


warnings.warn("This module has been deprecated in favor of the "
"model_selection module into which all the refactored classes "
"and functions are moved. Also note that the interface of the "
"new CV iterators are different from that of this module. "
"This module will be removed in 0.19.", DeprecationWarning)


__all__ = ['KFold',
'LabelKFold',
'LeaveOneLabelOut',
@@ -304,7 +312,7 @@ class KFold(_BaseKFold):

See also
--------
StratifiedKFold: take label information into account to avoid building
StratifiedKFold take label information into account to avoid building
Member: No, it's not the colon that messes up the formatting here. Sphinx takes the first word in each line and auto-links it. Use the numpydoc syntax: keep the colon, but indent the lines of the description. This should do it.

Member: Not addressed yet.

Member (Author): IIRC, for a multi-line description the colon needs to be removed, I think...

folds with imbalanced class distributions (for binary or multiclass
classification tasks).

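To make the reviewer's suggestion in the thread above concrete, the numpydoc form would look roughly like this (a sketch, not the final text of the PR): keep the colon, indent the description.

    See also
    --------
    StratifiedKFold : take label information into account to avoid building
        folds with imbalanced class distributions (for binary or multiclass
        classification tasks).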
2 changes: 1 addition & 1 deletion sklearn/decomposition/tests/test_kernel_pca.py
@@ -9,7 +9,7 @@
from sklearn.datasets import make_circles
from sklearn.linear_model import Perceptron
from sklearn.pipeline import Pipeline
from sklearn.grid_search import GridSearchCV
from sklearn.model_selection import GridSearchCV
from sklearn.metrics.pairwise import rbf_kernel


4 changes: 2 additions & 2 deletions sklearn/ensemble/tests/test_bagging.py
@@ -21,15 +21,15 @@
from sklearn.utils.testing import assert_warns_message

from sklearn.dummy import DummyClassifier, DummyRegressor
from sklearn.grid_search import GridSearchCV, ParameterGrid
from sklearn.model_selection import GridSearchCV, ParameterGrid
from sklearn.ensemble import BaggingClassifier, BaggingRegressor
from sklearn.linear_model import Perceptron, LogisticRegression
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.svm import SVC, SVR
from sklearn.pipeline import make_pipeline
from sklearn.feature_selection import SelectKBest
from sklearn.cross_validation import train_test_split
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_boston, load_iris, make_hastie_10_2
from sklearn.utils import check_random_state

2 changes: 1 addition & 1 deletion sklearn/ensemble/tests/test_forest.py
@@ -38,7 +38,7 @@
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomTreesEmbedding
from sklearn.grid_search import GridSearchCV
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC
from sklearn.utils.fixes import bincount
from sklearn.utils.validation import check_random_state
16 changes: 4 additions & 12 deletions sklearn/ensemble/tests/test_voting_classifier.py
@@ -7,9 +7,9 @@
from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import VotingClassifier
from sklearn.grid_search import GridSearchCV
from sklearn.model_selection import GridSearchCV
from sklearn import datasets
from sklearn import cross_validation
from sklearn.model_selection import cross_val_score
from sklearn.datasets import make_multilabel_classification
from sklearn.svm import SVC
from sklearn.multiclass import OneVsRestClassifier
@@ -27,11 +27,7 @@ def test_majority_label_iris():
eclf = VotingClassifier(estimators=[
('lr', clf1), ('rf', clf2), ('gnb', clf3)],
voting='hard')
scores = cross_validation.cross_val_score(eclf,
X,
y,
cv=5,
scoring='accuracy')
scores = cross_val_score(eclf, X, y, cv=5, scoring='accuracy')
assert_almost_equal(scores.mean(), 0.95, decimal=2)


@@ -55,11 +51,7 @@ def test_weights_iris():
('lr', clf1), ('rf', clf2), ('gnb', clf3)],
voting='soft',
weights=[1, 2, 10])
scores = cross_validation.cross_val_score(eclf,
X,
y,
cv=5,
scoring='accuracy')
scores = cross_val_score(eclf, X, y, cv=5, scoring='accuracy')
assert_almost_equal(scores.mean(), 0.93, decimal=2)


4 changes: 2 additions & 2 deletions sklearn/ensemble/tests/test_weight_boosting.py
@@ -7,8 +7,8 @@
from sklearn.utils.testing import assert_raises, assert_raises_regexp

from sklearn.base import BaseEstimator
from sklearn.cross_validation import train_test_split
from sklearn.grid_search import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import AdaBoostClassifier
from sklearn.ensemble import AdaBoostRegressor
from sklearn.ensemble import weight_boosting
2 changes: 1 addition & 1 deletion sklearn/exceptions.py
@@ -85,7 +85,7 @@ class FitFailedWarning(RuntimeWarning):

Examples
--------
>>> from sklearn.grid_search import GridSearchCV
>>> from sklearn.model_selection import GridSearchCV
>>> from sklearn.svm import LinearSVC
>>> from sklearn.exceptions import FitFailedWarning
>>> import warnings
6 changes: 3 additions & 3 deletions sklearn/feature_extraction/tests/test_text.py
@@ -12,9 +12,9 @@

from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS

from sklearn.cross_validation import train_test_split
from sklearn.cross_validation import cross_val_score
from sklearn.grid_search import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

14 changes: 7 additions & 7 deletions sklearn/feature_selection/rfe.py
@@ -14,8 +14,8 @@
from ..base import MetaEstimatorMixin
from ..base import clone
from ..base import is_classifier
from ..cross_validation import check_cv
from ..cross_validation import _safe_split, _score
from ..model_selection import check_cv
from ..model_selection._validation import _safe_split, _score
from ..metrics.scorer import check_scoring
from .base import SelectorMixin

@@ -373,7 +373,7 @@ def fit(self, X, y):
X, y = check_X_y(X, y, "csr")

# Initialization
cv = check_cv(self.cv, X, y, is_classifier(self.estimator))
cv = check_cv(self.cv, y, is_classifier(self.estimator))
scorer = check_scoring(self.estimator, scoring=self.scoring)
n_features = X.shape[1]
n_features_to_select = 1
@@ -382,7 +382,7 @@ def fit(self, X, y):
scores = []

# Cross-validation
for n, (train, test) in enumerate(cv):
for n, (train, test) in enumerate(cv.split(X, y)):
X_train, y_train = _safe_split(self.estimator, X, y, train)
X_test, y_test = _safe_split(self.estimator, X, y, test, train)

@@ -414,7 +414,7 @@ def fit(self, X, y):
self.estimator_ = clone(self.estimator)
self.estimator_.fit(self.transform(X), y)

# Fixing a normalization error, n is equal to len(cv) - 1
# here, the scores are normalized by len(cv)
self.grid_scores_ = scores / len(cv)
# Fixing a normalization error, n is equal to get_n_splits(X, y) - 1
# here, the scores are normalized by get_n_splits(X, y)
self.grid_scores_ = scores / cv.get_n_splits(X, y)
return self
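
The normalization fix above is needed because the new splitter objects do not support ``len()``; the fold count is now an explicit query. A small illustration (``KFold`` and the array shapes are arbitrary; ``get_n_splits`` accepts ``X`` and ``y`` for API uniformity even when it ignores them):

    import numpy as np
    from sklearn.model_selection import KFold

    X, y = np.ones((12, 3)), np.arange(12) % 2
    cv = KFold(n_splits=4)
    assert cv.get_n_splits(X, y) == 4  # replaces the old len(cv)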
2 changes: 1 addition & 1 deletion sklearn/feature_selection/tests/test_rfe.py
@@ -12,7 +12,7 @@
from sklearn.metrics import zero_one_loss
from sklearn.svm import SVC, SVR
from sklearn.ensemble import RandomForestClassifier
from sklearn.cross_validation import cross_val_score
from sklearn.model_selection import cross_val_score

from sklearn.utils import check_random_state
from sklearn.utils.testing import ignore_warnings
6 changes: 6 additions & 0 deletions sklearn/grid_search.py
@@ -37,6 +37,12 @@
'ParameterSampler', 'RandomizedSearchCV']


warnings.warn("This module has been deprecated in favor of the "
"model_selection module into which all the refactored classes "
"and functions are moved. This module will be removed in 0.19.",
DeprecationWarning)


class ParameterGrid(object):
"""Grid of parameters with a discrete number of values for each.

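The warning added above fires at import time. A quick way to observe it, assuming a fresh interpreter (Python caches modules, so a second import in the same process stays silent):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        import sklearn.grid_search  # module-level DeprecationWarning fires here

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)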
6 changes: 6 additions & 0 deletions sklearn/learning_curve.py
@@ -17,6 +17,12 @@
from .utils.fixes import astype


warnings.warn("This module has been deprecated in favor of the "
"model_selection module into which all the functions are moved."
" This module will be removed in 0.19",
DeprecationWarning)


__all__ = ['learning_curve', 'validation_curve']


6 changes: 3 additions & 3 deletions sklearn/linear_model/coordinate_descent.py
@@ -17,7 +17,7 @@
from .base import center_data, sparse_center_data
from ..utils import check_array, check_X_y, deprecated
from ..utils.validation import check_random_state
from ..cross_validation import check_cv
from ..model_selection import check_cv
from ..externals.joblib import Parallel, delayed
from ..externals import six
from ..externals.six.moves import xrange
@@ -1120,10 +1120,10 @@ def fit(self, X, y):
path_params['copy_X'] = False

# init cross-validation generator
cv = check_cv(self.cv, X)
cv = check_cv(self.cv)

# Compute path for all folds and compute MSE to get the best alpha
folds = list(cv)
folds = list(cv.split(X))
best_mse = np.inf

# We do a double for loop folded in one, in order to be able to
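For regressors such as the coordinate-descent models above, ``check_cv`` is now called without ``y``, so it falls back to a plain, unstratified ``KFold``, and the folds are materialized up front. A minimal sketch (random data; released 0.18 keyword names):

    import numpy as np
    from sklearn.model_selection import check_cv

    X = np.random.RandomState(0).rand(10, 4)
    cv = check_cv(cv=3)        # no y, classifier=False: plain KFold
    folds = list(cv.split(X))  # materialized (train, test) index pairs
    assert cv.get_n_splits() == len(folds) == 3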
6 changes: 3 additions & 3 deletions sklearn/linear_model/least_angle.py
@@ -22,7 +22,7 @@
from .base import LinearModel
from ..base import RegressorMixin
from ..utils import arrayfuncs, as_float_array, check_X_y
from ..cross_validation import check_cv
from ..model_selection import check_cv
from ..exceptions import ConvergenceWarning
from ..externals.joblib import Parallel, delayed
from ..externals.six.moves import xrange
@@ -1079,7 +1079,7 @@ def fit(self, X, y):
y = as_float_array(y, copy=self.copy_X)

# init cross-validation generator
cv = check_cv(self.cv, X, y, classifier=False)
cv = check_cv(self.cv, classifier=False)

Gram = 'auto' if self.precompute else None

@@ -1089,7 +1089,7 @@
method=self.method, verbose=max(0, self.verbose - 1),
normalize=self.normalize, fit_intercept=self.fit_intercept,
max_iter=self.max_iter, eps=self.eps, positive=self.positive)
for train, test in cv)
for train, test in cv.split(X, y))
all_alphas = np.concatenate(list(zip(*cv_paths))[0])
# Unique also sorts
all_alphas = np.unique(all_alphas)
9 changes: 5 additions & 4 deletions sklearn/linear_model/logistic.py
@@ -1,3 +1,4 @@

"""
Logistic Regression
"""
@@ -32,7 +33,7 @@
from ..utils.fixes import expit
from ..utils.multiclass import check_classification_targets
from ..externals.joblib import Parallel, delayed
from ..cross_validation import check_cv
from ..model_selection import check_cv
from ..externals import six
from ..metrics import SCORERS

@@ -1309,7 +1310,7 @@ class LogisticRegressionCV(LogisticRegression, BaseEstimator,
cv : integer or cross-validation generator
The default cross-validation generator used is Stratified K-Folds.
If an integer is provided, then it is the number of folds used.
See the module :mod:`sklearn.cross_validation` module for the
See the :mod:`sklearn.model_selection` module for the
list of possible cross-validation objects.

penalty : str, 'l1' or 'l2'
@@ -1506,8 +1507,8 @@ def fit(self, X, y, sample_weight=None):
check_consistent_length(X, y)

# init cross-validation generator
cv = check_cv(self.cv, X, y, classifier=True)
folds = list(cv)
cv = check_cv(self.cv, y, classifier=True)
Member: Do we need y? I guess to be future-proof in case we implement multi-label LR?

Member (Author): Goes into #5053?

Member: Not really, that was for backward-incompatible changes, not minor refactoring.

Member: Well, y can't simply be removed without more changes (if y=None, then stratification is not used in check_cv). The easiest would be to leave it.

Member (Author): @amueller, do I leave it?

Member: Yes.

folds = list(cv.split(X, y))

self._enc = LabelEncoder()
self._enc.fit(y)
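The point of the thread above, that dropping ``y`` would silently lose stratification, can be seen directly from how ``check_cv`` dispatches (a sketch against the released 0.18 behavior):

    import numpy as np
    from sklearn.model_selection import check_cv

    y = np.array([0, 0, 0, 1, 1, 1])
    print(type(check_cv(3, y, classifier=True)).__name__)  # StratifiedKFold
    print(type(check_cv(3, classifier=True)).__name__)     # KFold: no y, no stratification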