diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 5e78d84c51956..420e2609adef1 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -25,6 +25,16 @@ New features Enhancements ............ + - The cross-validation iterators have been replaced by cross-validation splitters + which expose a ``split`` method that takes in the data and yields a generator + over the different splits. This change makes nested cross-validation + straightforward. (`#4294 <https://github.com/scikit-learn/scikit-learn/pull/4294>`_) by `Raghav R V`_. + + - The :mod:`cross_validation`, :mod:`grid_search` and :mod:`learning_curve` modules + have been deprecated, and their classes and functions have been reorganized into + the :mod:`model_selection` module. (`#4294 <https://github.com/scikit-learn/scikit-learn/pull/4294>`_) by `Raghav R V`_. + + Bug fixes ......... diff --git a/sklearn/__init__.py b/sklearn/__init__.py index 421d503cf7f41..9eafddfe7ca99 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -62,8 +62,8 @@ 'ensemble', 'exceptions', 'externals', 'feature_extraction', 'feature_selection', 'gaussian_process', 'grid_search', 'isotonic', 'kernel_approximation', 'kernel_ridge', - 'lda', 'learning_curve', - 'linear_model', 'manifold', 'metrics', 'mixture', 'multiclass', + 'lda', 'learning_curve', 'linear_model', 'manifold', 'metrics', + 'mixture', 'model_selection', 'multiclass', 'naive_bayes', 'neighbors', 'neural_network', 'pipeline', 'preprocessing', 'qda', 'random_projection', 'semi_supervised', 'svm', 'tree', 'discriminant_analysis', diff --git a/sklearn/calibration.py b/sklearn/calibration.py index bc2e5b0c28274..f974977162f1c 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -22,7 +22,7 @@ from .utils.fixes import signature from .isotonic import IsotonicRegression from .svm import LinearSVC -from .cross_validation import check_cv +from .model_selection import check_cv from .metrics.classification import _check_binary_probabilistic_predictions @@ -152,7 +152,7 @@ def fit(self, X, y, sample_weight=None): calibrated_classifier.fit(X, y) self.calibrated_classifiers_.append(calibrated_classifier) else: - cv = check_cv(self.cv, X, y, classifier=True) + cv = check_cv(self.cv, y, classifier=True) fit_parameters = signature(base_estimator.fit).parameters estimator_name = type(base_estimator).__name__ if (sample_weight is not None @@ -163,7 +163,7 @@ def fit(self, X, y, sample_weight=None): base_estimator_sample_weight = None else: base_estimator_sample_weight = sample_weight - for train, test in cv: + for train, test in cv.split(X, y): this_estimator = clone(base_estimator) if base_estimator_sample_weight is not None: this_estimator.fit( diff --git a/sklearn/cluster/tests/test_bicluster.py b/sklearn/cluster/tests/test_bicluster.py index 9afcca5ff0eec..eacc208d4ef08 100644 --- a/sklearn/cluster/tests/test_bicluster.py +++ b/sklearn/cluster/tests/test_bicluster.py @@ -3,7 +3,7 @@ import numpy as np from scipy.sparse import csr_matrix, issparse -from sklearn.grid_search import ParameterGrid +from sklearn.model_selection import ParameterGrid from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_almost_equal diff --git a/sklearn/covariance/graph_lasso_.py b/sklearn/covariance/graph_lasso_.py index dd08ed4206120..7d9568823322d 100644 --- a/sklearn/covariance/graph_lasso_.py +++ b/sklearn/covariance/graph_lasso_.py @@ -21,7 +21,7 @@ from ..utils.validation import check_random_state, check_array from ..linear_model import lars_path from ..linear_model
import cd_fast -from ..cross_validation import check_cv, cross_val_score +from ..model_selection import check_cv, cross_val_score from ..externals.joblib import Parallel, delayed import collections @@ -580,7 +580,7 @@ def fit(self, X, y=None): emp_cov = empirical_covariance( X, assume_centered=self.assume_centered) - cv = check_cv(self.cv, X, y, classifier=False) + cv = check_cv(self.cv, y, classifier=False) # List of (alpha, scores, covs) path = list() @@ -612,14 +612,13 @@ def fit(self, X, y=None): this_path = Parallel( n_jobs=self.n_jobs, verbose=self.verbose - )( - delayed(graph_lasso_path)( - X[train], alphas=alphas, - X_test=X[test], mode=self.mode, - tol=self.tol, enet_tol=self.enet_tol, - max_iter=int(.1 * self.max_iter), - verbose=inner_verbose) - for train, test in cv) + )(delayed(graph_lasso_path)(X[train], alphas=alphas, + X_test=X[test], mode=self.mode, + tol=self.tol, + enet_tol=self.enet_tol, + max_iter=int(.1 * self.max_iter), + verbose=inner_verbose) + for train, test in cv.split(X, y)) # Little danse to transform the list in what we need covs, _, scores = zip(*this_path) diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py index 84a8813346357..6ee9e02d9d984 100644 --- a/sklearn/cross_validation.py +++ b/sklearn/cross_validation.py @@ -34,6 +34,14 @@ from .gaussian_process.kernels import Kernel as GPKernel from .exceptions import FitFailedWarning + +warnings.warn("This module has been deprecated in favor of the " + "model_selection module into which all the refactored classes " + "and functions are moved. Also note that the interface of the " + "new CV iterators are different from that of this module. " + "This module will be removed in 0.19.", DeprecationWarning) + + __all__ = ['KFold', 'LabelKFold', 'LeaveOneLabelOut', @@ -304,7 +312,7 @@ class KFold(_BaseKFold): See also -------- - StratifiedKFold: take label information into account to avoid building + StratifiedKFold take label information into account to avoid building folds with imbalanced class distributions (for binary or multiclass classification tasks). 
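To make the interface change described in the deprecation warning above concrete, the following sketch contrasts the old iterator-style usage with the new splitter-style usage, and ends with the nested cross-validation pattern the changelog entry advertises. It relies only on names that appear in the hunks above (``KFold``, ``check_cv``, ``split``, ``get_n_splits``, ``GridSearchCV``, ``cross_val_score``); the iris data, the ``SVC`` estimator and the particular fold counts are illustrative assumptions, not part of this patch.

# Old interface (sklearn.cross_validation): the iterator was bound to the data
# at construction time and was itself iterable, e.g.
#     cv = cross_validation.KFold(n_samples, 5)
#     for train, test in cv: ...
# New interface (sklearn.model_selection): the splitter is data-independent and
# indices are generated lazily by split(), which is why callers such as
# calibration.py above now loop over cv.split(X, y) instead of cv.
from sklearn.datasets import load_iris
from sklearn.model_selection import (KFold, GridSearchCV, check_cv,
                                     cross_val_score)
from sklearn.svm import SVC

iris = load_iris()
X, y = iris.data, iris.target

cv = KFold(5)
for train, test in cv.split(X, y):
    pass  # fit on X[train], y[train]; evaluate on X[test], y[test]

# check_cv now takes y (not X) and returns a splitter object;
# cv.get_n_splits(X, y) replaces the old len(cv).
cv = check_cv(5, y, classifier=True)
n_splits = cv.get_n_splits(X, y)

# Because the splitter no longer stores the data, the same strategy can be
# reused at different levels, e.g. nested cross-validation (for a real
# evaluation you would typically shuffle or stratify the folds):
inner_cv = KFold(4)   # used by the hyper-parameter search
outer_cv = KFold(3)   # used to estimate the score of the tuned model
search = GridSearchCV(SVC(), param_grid={'C': [1, 10]}, cv=inner_cv)
nested_scores = cross_val_score(search, X, y, cv=outer_cv)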
diff --git a/sklearn/decomposition/tests/test_kernel_pca.py b/sklearn/decomposition/tests/test_kernel_pca.py index 3e7af87aa68c8..3899d91ac33d2 100644 --- a/sklearn/decomposition/tests/test_kernel_pca.py +++ b/sklearn/decomposition/tests/test_kernel_pca.py @@ -9,7 +9,7 @@ from sklearn.datasets import make_circles from sklearn.linear_model import Perceptron from sklearn.pipeline import Pipeline -from sklearn.grid_search import GridSearchCV +from sklearn.model_selection import GridSearchCV from sklearn.metrics.pairwise import rbf_kernel diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py index 12852d7f09b1d..85bacf5a14786 100644 --- a/sklearn/ensemble/tests/test_bagging.py +++ b/sklearn/ensemble/tests/test_bagging.py @@ -21,7 +21,7 @@ from sklearn.utils.testing import assert_warns_message from sklearn.dummy import DummyClassifier, DummyRegressor -from sklearn.grid_search import GridSearchCV, ParameterGrid +from sklearn.model_selection import GridSearchCV, ParameterGrid from sklearn.ensemble import BaggingClassifier, BaggingRegressor from sklearn.linear_model import Perceptron, LogisticRegression from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor @@ -29,7 +29,7 @@ from sklearn.svm import SVC, SVR from sklearn.pipeline import make_pipeline from sklearn.feature_selection import SelectKBest -from sklearn.cross_validation import train_test_split +from sklearn.model_selection import train_test_split from sklearn.datasets import load_boston, load_iris, make_hastie_10_2 from sklearn.utils import check_random_state diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 205829ba3d004..b4a5197185e9b 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -38,7 +38,7 @@ from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestRegressor from sklearn.ensemble import RandomTreesEmbedding -from sklearn.grid_search import GridSearchCV +from sklearn.model_selection import GridSearchCV from sklearn.svm import LinearSVC from sklearn.utils.fixes import bincount from sklearn.utils.validation import check_random_state diff --git a/sklearn/ensemble/tests/test_voting_classifier.py b/sklearn/ensemble/tests/test_voting_classifier.py index fb86d2ec46ea2..de680e49e3e0b 100644 --- a/sklearn/ensemble/tests/test_voting_classifier.py +++ b/sklearn/ensemble/tests/test_voting_classifier.py @@ -7,9 +7,9 @@ from sklearn.naive_bayes import GaussianNB from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import VotingClassifier -from sklearn.grid_search import GridSearchCV +from sklearn.model_selection import GridSearchCV from sklearn import datasets -from sklearn import cross_validation +from sklearn.model_selection import cross_val_score from sklearn.datasets import make_multilabel_classification from sklearn.svm import SVC from sklearn.multiclass import OneVsRestClassifier @@ -27,11 +27,7 @@ def test_majority_label_iris(): eclf = VotingClassifier(estimators=[ ('lr', clf1), ('rf', clf2), ('gnb', clf3)], voting='hard') - scores = cross_validation.cross_val_score(eclf, - X, - y, - cv=5, - scoring='accuracy') + scores = cross_val_score(eclf, X, y, cv=5, scoring='accuracy') assert_almost_equal(scores.mean(), 0.95, decimal=2) @@ -55,11 +51,7 @@ def test_weights_iris(): ('lr', clf1), ('rf', clf2), ('gnb', clf3)], voting='soft', weights=[1, 2, 10]) - scores = cross_validation.cross_val_score(eclf, - X, - y, - cv=5, - scoring='accuracy') 
+ scores = cross_val_score(eclf, X, y, cv=5, scoring='accuracy') assert_almost_equal(scores.mean(), 0.93, decimal=2) diff --git a/sklearn/ensemble/tests/test_weight_boosting.py b/sklearn/ensemble/tests/test_weight_boosting.py index f2cb4b0ba8006..83a5e819c13f7 100755 --- a/sklearn/ensemble/tests/test_weight_boosting.py +++ b/sklearn/ensemble/tests/test_weight_boosting.py @@ -7,8 +7,8 @@ from sklearn.utils.testing import assert_raises, assert_raises_regexp from sklearn.base import BaseEstimator -from sklearn.cross_validation import train_test_split -from sklearn.grid_search import GridSearchCV +from sklearn.model_selection import train_test_split +from sklearn.model_selection import GridSearchCV from sklearn.ensemble import AdaBoostClassifier from sklearn.ensemble import AdaBoostRegressor from sklearn.ensemble import weight_boosting diff --git a/sklearn/exceptions.py b/sklearn/exceptions.py index 2a99d67a4703e..472e81e2fdab5 100644 --- a/sklearn/exceptions.py +++ b/sklearn/exceptions.py @@ -85,7 +85,7 @@ class FitFailedWarning(RuntimeWarning): Examples -------- - >>> from sklearn.grid_search import GridSearchCV + >>> from sklearn.model_selection import GridSearchCV >>> from sklearn.svm import LinearSVC >>> from sklearn.exceptions import FitFailedWarning >>> import warnings diff --git a/sklearn/feature_extraction/tests/test_text.py b/sklearn/feature_extraction/tests/test_text.py index 396c5f6d2112c..8fe7e9fe891ab 100644 --- a/sklearn/feature_extraction/tests/test_text.py +++ b/sklearn/feature_extraction/tests/test_text.py @@ -12,9 +12,9 @@ from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS -from sklearn.cross_validation import train_test_split -from sklearn.cross_validation import cross_val_score -from sklearn.grid_search import GridSearchCV +from sklearn.model_selection import train_test_split +from sklearn.model_selection import cross_val_score +from sklearn.model_selection import GridSearchCV from sklearn.pipeline import Pipeline from sklearn.svm import LinearSVC diff --git a/sklearn/feature_selection/rfe.py b/sklearn/feature_selection/rfe.py index 81d464ec6c009..84e51dc9b8663 100644 --- a/sklearn/feature_selection/rfe.py +++ b/sklearn/feature_selection/rfe.py @@ -14,8 +14,8 @@ from ..base import MetaEstimatorMixin from ..base import clone from ..base import is_classifier -from ..cross_validation import check_cv -from ..cross_validation import _safe_split, _score +from ..model_selection import check_cv +from ..model_selection._validation import _safe_split, _score from ..metrics.scorer import check_scoring from .base import SelectorMixin @@ -373,7 +373,7 @@ def fit(self, X, y): X, y = check_X_y(X, y, "csr") # Initialization - cv = check_cv(self.cv, X, y, is_classifier(self.estimator)) + cv = check_cv(self.cv, y, is_classifier(self.estimator)) scorer = check_scoring(self.estimator, scoring=self.scoring) n_features = X.shape[1] n_features_to_select = 1 @@ -382,7 +382,7 @@ def fit(self, X, y): scores = [] # Cross-validation - for n, (train, test) in enumerate(cv): + for n, (train, test) in enumerate(cv.split(X, y)): X_train, y_train = _safe_split(self.estimator, X, y, train) X_test, y_test = _safe_split(self.estimator, X, y, test, train) @@ -414,7 +414,7 @@ def fit(self, X, y): self.estimator_ = clone(self.estimator) self.estimator_.fit(self.transform(X), y) - # Fixing a normalization error, n is equal to len(cv) - 1 - # here, the scores are normalized by len(cv) - self.grid_scores_ = scores / len(cv) + # Fixing a normalization error, n is equal to get_n_splits(X, y) - 1 + # 
here, the scores are normalized by get_n_splits(X, y) + self.grid_scores_ = scores / cv.get_n_splits(X, y) return self diff --git a/sklearn/feature_selection/tests/test_rfe.py b/sklearn/feature_selection/tests/test_rfe.py index 7a0862e6ad10b..14bb4ba2dea7e 100644 --- a/sklearn/feature_selection/tests/test_rfe.py +++ b/sklearn/feature_selection/tests/test_rfe.py @@ -12,7 +12,7 @@ from sklearn.metrics import zero_one_loss from sklearn.svm import SVC, SVR from sklearn.ensemble import RandomForestClassifier -from sklearn.cross_validation import cross_val_score +from sklearn.model_selection import cross_val_score from sklearn.utils import check_random_state from sklearn.utils.testing import ignore_warnings diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py index d6fe1e0f5dcaf..a26e95ea67d11 100644 --- a/sklearn/grid_search.py +++ b/sklearn/grid_search.py @@ -37,6 +37,12 @@ 'ParameterSampler', 'RandomizedSearchCV'] +warnings.warn("This module has been deprecated in favor of the " + "model_selection module into which all the refactored classes " + "and functions are moved. This module will be removed in 0.19.", + DeprecationWarning) + + class ParameterGrid(object): """Grid of parameters with a discrete number of values for each. diff --git a/sklearn/learning_curve.py b/sklearn/learning_curve.py index ae5601483c8fc..918102a10aa25 100644 --- a/sklearn/learning_curve.py +++ b/sklearn/learning_curve.py @@ -17,6 +17,12 @@ from .utils.fixes import astype +warnings.warn("This module has been deprecated in favor of the " + "model_selection module into which all the functions are moved." + " This module will be removed in 0.19", + DeprecationWarning) + + __all__ = ['learning_curve', 'validation_curve'] diff --git a/sklearn/linear_model/coordinate_descent.py b/sklearn/linear_model/coordinate_descent.py index b5d26cb2e7350..5f0e55b4fcfed 100644 --- a/sklearn/linear_model/coordinate_descent.py +++ b/sklearn/linear_model/coordinate_descent.py @@ -17,7 +17,7 @@ from .base import center_data, sparse_center_data from ..utils import check_array, check_X_y, deprecated from ..utils.validation import check_random_state -from ..cross_validation import check_cv +from ..model_selection import check_cv from ..externals.joblib import Parallel, delayed from ..externals import six from ..externals.six.moves import xrange @@ -1120,10 +1120,10 @@ def fit(self, X, y): path_params['copy_X'] = False # init cross-validation generator - cv = check_cv(self.cv, X) + cv = check_cv(self.cv) # Compute path for all folds and compute MSE to get the best alpha - folds = list(cv) + folds = list(cv.split(X)) best_mse = np.inf # We do a double for loop folded in one, in order to be able to diff --git a/sklearn/linear_model/least_angle.py b/sklearn/linear_model/least_angle.py index 8fa0a021bcc32..699ce2b315188 100644 --- a/sklearn/linear_model/least_angle.py +++ b/sklearn/linear_model/least_angle.py @@ -22,7 +22,7 @@ from .base import LinearModel from ..base import RegressorMixin from ..utils import arrayfuncs, as_float_array, check_X_y -from ..cross_validation import check_cv +from ..model_selection import check_cv from ..exceptions import ConvergenceWarning from ..externals.joblib import Parallel, delayed from ..externals.six.moves import xrange @@ -1079,7 +1079,7 @@ def fit(self, X, y): y = as_float_array(y, copy=self.copy_X) # init cross-validation generator - cv = check_cv(self.cv, X, y, classifier=False) + cv = check_cv(self.cv, classifier=False) Gram = 'auto' if self.precompute else None @@ -1089,7 +1089,7 @@ def fit(self, 
X, y): method=self.method, verbose=max(0, self.verbose - 1), normalize=self.normalize, fit_intercept=self.fit_intercept, max_iter=self.max_iter, eps=self.eps, positive=self.positive) - for train, test in cv) + for train, test in cv.split(X, y)) all_alphas = np.concatenate(list(zip(*cv_paths))[0]) # Unique also sorts all_alphas = np.unique(all_alphas) diff --git a/sklearn/linear_model/logistic.py b/sklearn/linear_model/logistic.py index 75740c4cc7305..c24db357746a7 100644 --- a/sklearn/linear_model/logistic.py +++ b/sklearn/linear_model/logistic.py @@ -1,3 +1,4 @@ + """ Logistic Regression """ @@ -32,7 +33,7 @@ from ..utils.fixes import expit from ..utils.multiclass import check_classification_targets from ..externals.joblib import Parallel, delayed -from ..cross_validation import check_cv +from ..model_selection import check_cv from ..externals import six from ..metrics import SCORERS @@ -1309,7 +1310,7 @@ class LogisticRegressionCV(LogisticRegression, BaseEstimator, cv : integer or cross-validation generator The default cross-validation generator used is Stratified K-Folds. If an integer is provided, then it is the number of folds used. - See the module :mod:`sklearn.cross_validation` module for the + See the module :mod:`sklearn.model_selection` module for the list of possible cross-validation objects. penalty : str, 'l1' or 'l2' @@ -1506,8 +1507,8 @@ def fit(self, X, y, sample_weight=None): check_consistent_length(X, y) # init cross-validation generator - cv = check_cv(self.cv, X, y, classifier=True) - folds = list(cv) + cv = check_cv(self.cv, y, classifier=True) + folds = list(cv.split(X, y)) self._enc = LabelEncoder() self._enc.fit(y) diff --git a/sklearn/linear_model/omp.py b/sklearn/linear_model/omp.py index 3b87b9cf6c410..92f6cb69b238d 100644 --- a/sklearn/linear_model/omp.py +++ b/sklearn/linear_model/omp.py @@ -15,7 +15,7 @@ from .base import LinearModel, _pre_fit from ..base import RegressorMixin from ..utils import as_float_array, check_array, check_X_y -from ..cross_validation import check_cv +from ..model_selection import check_cv from ..externals.joblib import Parallel, delayed import scipy @@ -835,7 +835,7 @@ def fit(self, X, y): X, y = check_X_y(X, y, y_numeric=True, ensure_min_features=2, estimator=self) X = as_float_array(X, copy=False, force_all_finite=False) - cv = check_cv(self.cv, X, y, classifier=False) + cv = check_cv(self.cv, classifier=False) max_iter = (min(max(int(0.1 * X.shape[1]), 5), X.shape[1]) if not self.max_iter else self.max_iter) @@ -843,7 +843,7 @@ def fit(self, X, y): delayed(_omp_path_residues)( X[train], y[train], X[test], y[test], self.copy, self.fit_intercept, self.normalize, max_iter) - for train, test in cv) + for train, test in cv.split(X)) min_early_stop = min(fold.shape[0] for fold in cv_paths) mse_folds = np.array([(fold[:min_early_stop] ** 2).mean(axis=1) diff --git a/sklearn/linear_model/ridge.py b/sklearn/linear_model/ridge.py index 5f3cbc10f57ef..ed70d8ca9572d 100644 --- a/sklearn/linear_model/ridge.py +++ b/sklearn/linear_model/ridge.py @@ -28,7 +28,7 @@ from ..utils import compute_sample_weight from ..utils import column_or_1d from ..preprocessing import LabelBinarizer -from ..grid_search import GridSearchCV +from ..model_selection import GridSearchCV from ..externals import six from ..metrics.scorer import check_scoring diff --git a/sklearn/linear_model/tests/test_least_angle.py b/sklearn/linear_model/tests/test_least_angle.py index 9692edaaa2b2d..f9e8a6cdd5918 100644 --- a/sklearn/linear_model/tests/test_least_angle.py +++ 
b/sklearn/linear_model/tests/test_least_angle.py @@ -3,7 +3,7 @@ import numpy as np from scipy import linalg -from sklearn.cross_validation import train_test_split +from sklearn.model_selection import train_test_split from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_true from sklearn.utils.testing import assert_less diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py index 31175f3035be6..13e55278f5522 100644 --- a/sklearn/linear_model/tests/test_logistic.py +++ b/sklearn/linear_model/tests/test_logistic.py @@ -24,7 +24,7 @@ _logistic_loss_and_grad, _logistic_grad_hess, _multinomial_grad_hess, _logistic_loss, ) -from sklearn.cross_validation import StratifiedKFold +from sklearn.model_selection import StratifiedKFold from sklearn.datasets import load_iris, make_classification from sklearn.metrics import log_loss @@ -454,16 +454,24 @@ def test_ovr_multinomial_iris(): train, target = iris.data, iris.target n_samples, n_features = train.shape - # Use pre-defined fold as folds generated for different y - cv = StratifiedKFold(target, 3) - clf = LogisticRegressionCV(cv=cv) + # The cv indices from stratified kfold (where stratification is done based + # on the fine-grained iris classes, i.e, before the classes 0 and 1 are + # conflated) is used for both clf and clf1 + cv = StratifiedKFold(3) + precomputed_folds = list(cv.split(train, target)) + + # Train clf on the original dataset where classes 0 and 1 are separated + clf = LogisticRegressionCV(cv=precomputed_folds) clf.fit(train, target) - clf1 = LogisticRegressionCV(cv=cv) + # Conflate classes 0 and 1 and train clf1 on this modifed dataset + clf1 = LogisticRegressionCV(cv=precomputed_folds) target_copy = target.copy() target_copy[target_copy == 0] = 1 clf1.fit(train, target_copy) + # Ensure that what OvR learns for class2 is same regardless of whether + # classes 0 and 1 are separated or not assert_array_almost_equal(clf.scores_[2], clf1.scores_[2]) assert_array_almost_equal(clf.intercept_[2:], clf1.intercept_) assert_array_almost_equal(clf.coef_[2][np.newaxis, :], clf1.coef_) diff --git a/sklearn/linear_model/tests/test_ridge.py b/sklearn/linear_model/tests/test_ridge.py index 6f0b00b1dc54f..0f66badaf26b5 100644 --- a/sklearn/linear_model/tests/test_ridge.py +++ b/sklearn/linear_model/tests/test_ridge.py @@ -29,9 +29,8 @@ from sklearn.linear_model.ridge import _solve_cholesky_kernel from sklearn.datasets import make_regression -from sklearn.grid_search import GridSearchCV - -from sklearn.cross_validation import KFold +from sklearn.model_selection import GridSearchCV +from sklearn.model_selection import KFold diabetes = datasets.load_diabetes() @@ -358,8 +357,6 @@ def _test_ridge_loo(filter_): def _test_ridge_cv(filter_): - n_samples = X_diabetes.shape[0] - ridge_cv = RidgeCV() ridge_cv.fit(filter_(X_diabetes), y_diabetes) ridge_cv.predict(filter_(X_diabetes)) @@ -367,7 +364,7 @@ def _test_ridge_cv(filter_): assert_equal(len(ridge_cv.coef_.shape), 1) assert_equal(type(ridge_cv.intercept_), np.float64) - cv = KFold(n_samples, 5) + cv = KFold(5) ridge_cv.set_params(cv=cv) ridge_cv.fit(filter_(X_diabetes), y_diabetes) ridge_cv.predict(filter_(X_diabetes)) @@ -406,8 +403,7 @@ def _test_ridge_classifiers(filter_): y_pred = clf.predict(filter_(X_iris)) assert_greater(np.mean(y_iris == y_pred), .79) - n_samples = X_iris.shape[0] - cv = KFold(n_samples, 5) + cv = KFold(5) clf = RidgeClassifierCV(cv=cv) clf.fit(filter_(X_iris), y_iris) y_pred = 
clf.predict(filter_(X_iris)) @@ -571,7 +567,7 @@ def test_ridgecv_sample_weight(): X = rng.randn(n_samples, n_features) sample_weight = 1 + rng.rand(n_samples) - cv = KFold(n_samples, 5) + cv = KFold(5) ridgecv = RidgeCV(alphas=alphas, cv=cv) ridgecv.fit(X, y, sample_weight=sample_weight) diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py index 6628fb40d4e75..3ab43a2d10355 100644 --- a/sklearn/metrics/scorer.py +++ b/sklearn/metrics/scorer.py @@ -4,9 +4,9 @@ arbitrary score functions. A scorer object is a callable that can be passed to -:class:`sklearn.grid_search.GridSearchCV` or -:func:`sklearn.cross_validation.cross_val_score` as the ``scoring`` parameter, -to specify how a model should be evaluated. +:class:`sklearn.model_selection.GridSearchCV` or +:func:`sklearn.model_selection.cross_val_score` as the ``scoring`` +parameter, to specify how a model should be evaluated. The signature of the call is ``(estimator, X, y)`` where ``estimator`` is the model to be evaluated, ``X`` is the test data and ``y`` is the @@ -294,7 +294,7 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False, >>> ftwo_scorer = make_scorer(fbeta_score, beta=2) >>> ftwo_scorer make_scorer(fbeta_score, beta=2) - >>> from sklearn.grid_search import GridSearchCV + >>> from sklearn.model_selection import GridSearchCV >>> from sklearn.svm import LinearSVC >>> grid = GridSearchCV(LinearSVC(), param_grid={'C': [1, 10]}, ... scoring=ftwo_scorer) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 091c0592570a6..9dcdd2f6c415b 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -27,8 +27,8 @@ from sklearn.datasets import make_classification from sklearn.datasets import make_multilabel_classification from sklearn.datasets import load_diabetes -from sklearn.cross_validation import train_test_split, cross_val_score -from sklearn.grid_search import GridSearchCV +from sklearn.model_selection import train_test_split, cross_val_score +from sklearn.model_selection import GridSearchCV from sklearn.multiclass import OneVsRestClassifier diff --git a/sklearn/model_selection/__init__.py b/sklearn/model_selection/__init__.py new file mode 100644 index 0000000000000..caf0fe70ff493 --- /dev/null +++ b/sklearn/model_selection/__init__.py @@ -0,0 +1,51 @@ +from ._split import BaseCrossValidator +from ._split import KFold +from ._split import LabelKFold +from ._split import StratifiedKFold +from ._split import LeaveOneLabelOut +from ._split import LeaveOneOut +from ._split import LeavePLabelOut +from ._split import LeavePOut +from ._split import ShuffleSplit +from ._split import LabelShuffleSplit +from ._split import StratifiedShuffleSplit +from ._split import PredefinedSplit +from ._split import train_test_split +from ._split import check_cv + +from ._validation import cross_val_score +from ._validation import cross_val_predict +from ._validation import learning_curve +from ._validation import permutation_test_score +from ._validation import validation_curve + +from ._search import GridSearchCV +from ._search import RandomizedSearchCV +from ._search import ParameterGrid +from ._search import ParameterSampler +from ._search import fit_grid_point + +__all__ = ('BaseCrossValidator', + 'GridSearchCV', + 'KFold', + 'LabelKFold', + 'LabelShuffleSplit', + 'LeaveOneLabelOut', + 'LeaveOneOut', + 'LeavePLabelOut', + 'LeavePOut', + 'ParameterGrid', + 'ParameterSampler', + 'PredefinedSplit', + 
'RandomizedSearchCV', + 'ShuffleSplit', + 'StratifiedKFold', + 'StratifiedShuffleSplit', + 'check_cv', + 'cross_val_predict', + 'cross_val_score', + 'fit_grid_point', + 'learning_curve', + 'permutation_test_score', + 'train_test_split', + 'validation_curve') diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py new file mode 100644 index 0000000000000..d124f2f5cb348 --- /dev/null +++ b/sklearn/model_selection/_search.py @@ -0,0 +1,996 @@ +""" +The :mod:`sklearn.model_selection._search` includes utilities to fine-tune the +parameters of an estimator. +""" +from __future__ import print_function + +# Author: Alexandre Gramfort , +# Gael Varoquaux +# Andreas Mueller +# Olivier Grisel +# License: BSD 3 clause + +from abc import ABCMeta, abstractmethod +from collections import Mapping, namedtuple, Sized +from functools import partial, reduce +from itertools import product +import operator +import warnings + +import numpy as np + +from ..base import BaseEstimator, is_classifier, clone +from ..base import MetaEstimatorMixin, ChangedBehaviorWarning +from ._split import check_cv +from ._validation import _fit_and_score +from ..externals.joblib import Parallel, delayed +from ..externals import six +from ..utils import check_random_state +from ..utils.random import sample_without_replacement +from ..utils.validation import _num_samples, indexable +from ..utils.metaestimators import if_delegate_has_method +from ..metrics.scorer import check_scoring + + +__all__ = ['GridSearchCV', 'ParameterGrid', 'fit_grid_point', + 'ParameterSampler', 'RandomizedSearchCV'] + + +class ParameterGrid(object): + """Grid of parameters with a discrete number of values for each. + + Can be used to iterate over parameter value combinations with the + Python built-in function iter. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + param_grid : dict of string to sequence, or sequence of such + The parameter grid to explore, as a dictionary mapping estimator + parameters to sequences of allowed values. + + An empty dict signifies default parameters. + + A sequence of dicts signifies a sequence of grids to search, and is + useful to avoid exploring parameter combinations that make no sense + or have no effect. See the examples below. + + Examples + -------- + >>> from sklearn.model_selection import ParameterGrid + >>> param_grid = {'a': [1, 2], 'b': [True, False]} + >>> list(ParameterGrid(param_grid)) == ( + ... [{'a': 1, 'b': True}, {'a': 1, 'b': False}, + ... {'a': 2, 'b': True}, {'a': 2, 'b': False}]) + True + + >>> grid = [{'kernel': ['linear']}, {'kernel': ['rbf'], 'gamma': [1, 10]}] + >>> list(ParameterGrid(grid)) == [{'kernel': 'linear'}, + ... {'kernel': 'rbf', 'gamma': 1}, + ... {'kernel': 'rbf', 'gamma': 10}] + True + >>> ParameterGrid(grid)[1] == {'kernel': 'rbf', 'gamma': 1} + True + + See also + -------- + :class:`GridSearchCV`: + Uses :class:`ParameterGrid` to perform a full parallelized parameter + search. + """ + + def __init__(self, param_grid): + if isinstance(param_grid, Mapping): + # wrap dictionary in a singleton list to support either dict + # or list of dicts + param_grid = [param_grid] + self.param_grid = param_grid + + def __iter__(self): + """Iterate over the points in the grid. + + Returns + ------- + params : iterator over dict of string to any + Yields dictionaries mapping each estimator parameter to one of its + allowed values. 
+ """ + for p in self.param_grid: + # Always sort the keys of a dictionary, for reproducibility + items = sorted(p.items()) + if not items: + yield {} + else: + keys, values = zip(*items) + for v in product(*values): + params = dict(zip(keys, v)) + yield params + + def __len__(self): + """Number of points on the grid.""" + # Product function that can handle iterables (np.product can't). + product = partial(reduce, operator.mul) + return sum(product(len(v) for v in p.values()) if p else 1 + for p in self.param_grid) + + def __getitem__(self, ind): + """Get the parameters that would be ``ind``th in iteration + + Parameters + ---------- + ind : int + The iteration index + + Returns + ------- + params : dict of string to any + Equal to list(self)[ind] + """ + # This is used to make discrete sampling without replacement memory + # efficient. + for sub_grid in self.param_grid: + # XXX: could memoize information used here + if not sub_grid: + if ind == 0: + return {} + else: + ind -= 1 + continue + + # Reverse so most frequent cycling parameter comes first + keys, values_lists = zip(*sorted(sub_grid.items())[::-1]) + sizes = [len(v_list) for v_list in values_lists] + total = np.product(sizes) + + if ind >= total: + # Try the next grid + ind -= total + else: + out = {} + for key, v_list, n in zip(keys, values_lists, sizes): + ind, offset = divmod(ind, n) + out[key] = v_list[offset] + return out + + raise IndexError('ParameterGrid index out of range') + + +class ParameterSampler(object): + """Generator on parameters sampled from given distributions. + + Non-deterministic iterable over random candidate combinations for hyper- + parameter search. If all parameters are presented as a list, + sampling without replacement is performed. If at least one parameter + is given as a distribution, sampling with replacement is used. + It is highly recommended to use continuous distributions for continuous + parameters. + + Note that as of SciPy 0.12, the ``scipy.stats.distributions`` do not accept + a custom RNG instance and always use the singleton RNG from + ``numpy.random``. Hence setting ``random_state`` will not guarantee a + deterministic iteration whenever ``scipy.stats`` distributions are used to + define the parameter search space. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + param_distributions : dict + Dictionary where the keys are parameters and values + are distributions from which a parameter is to be sampled. + Distributions either have to provide a ``rvs`` function + to sample from them, or can be given as a list of values, + where a uniform distribution is assumed. + + n_iter : integer + Number of parameter settings that are produced. + + random_state : int or RandomState + Pseudo random number generator state used for random uniform sampling + from lists of possible values instead of scipy.stats distributions. + + Returns + ------- + params : dict of string to any + **Yields** dictionaries mapping each estimator parameter to + as sampled value. + + Examples + -------- + >>> from sklearn.model_selection import ParameterSampler + >>> from scipy.stats.distributions import expon + >>> import numpy as np + >>> np.random.seed(0) + >>> param_grid = {'a':[1, 2], 'b': expon()} + >>> param_list = list(ParameterSampler(param_grid, n_iter=4)) + >>> rounded_list = [dict((k, round(v, 6)) for (k, v) in d.items()) + ... for d in param_list] + >>> rounded_list == [{'b': 0.89856, 'a': 1}, + ... {'b': 0.923223, 'a': 1}, + ... {'b': 1.878964, 'a': 2}, + ... 
{'b': 1.038159, 'a': 2}] + True + """ + def __init__(self, param_distributions, n_iter, random_state=None): + self.param_distributions = param_distributions + self.n_iter = n_iter + self.random_state = random_state + + def __iter__(self): + # check if all distributions are given as lists + # in this case we want to sample without replacement + all_lists = np.all([not hasattr(v, "rvs") + for v in self.param_distributions.values()]) + rnd = check_random_state(self.random_state) + + if all_lists: + # look up sampled parameter settings in parameter grid + param_grid = ParameterGrid(self.param_distributions) + grid_size = len(param_grid) + + if grid_size < self.n_iter: + raise ValueError( + "The total space of parameters %d is smaller " + "than n_iter=%d." % (grid_size, self.n_iter) + + " For exhaustive searches, use GridSearchCV.") + for i in sample_without_replacement(grid_size, self.n_iter, + random_state=rnd): + yield param_grid[i] + + else: + # Always sort the keys of a dictionary, for reproducibility + items = sorted(self.param_distributions.items()) + for _ in six.moves.range(self.n_iter): + params = dict() + for k, v in items: + if hasattr(v, "rvs"): + params[k] = v.rvs() + else: + params[k] = v[rnd.randint(len(v))] + yield params + + def __len__(self): + """Number of points that will be sampled.""" + return self.n_iter + + +def fit_grid_point(X, y, estimator, parameters, train, test, scorer, + verbose, error_score='raise', **fit_params): + """Run fit on one set of parameters. + + Parameters + ---------- + X : array-like, sparse matrix or list + Input data. + + y : array-like or None + Targets for input data. + + estimator : estimator object + A object of that type is instantiated for each grid point. + This is assumed to implement the scikit-learn estimator interface. + Either estimator needs to provide a ``score`` function, + or ``scoring`` must be passed. + + parameters : dict + Parameters to be set on estimator for this grid point. + + train : ndarray, dtype int or bool + Boolean mask or indices for training set. + + test : ndarray, dtype int or bool + Boolean mask or indices for test set. + + scorer : callable or None. + If provided must be a scorer callable object / function with signature + ``scorer(estimator, X, y)``. + + verbose : int + Verbosity level. + + **fit_params : kwargs + Additional parameter passed to the fit function of the estimator. + + error_score : 'raise' (default) or numeric + Value to assign to the score if an error occurs in estimator fitting. + If set to 'raise', the error is raised. If a numeric value is given, + FitFailedWarning is raised. This parameter does not affect the refit + step, which will always raise the error. + + Returns + ------- + score : float + Score of this parameter setting on given training / test split. + + parameters : dict + The parameters that have been evaluated. + + n_samples_test : int + Number of test samples in this split. 
+ """ + score, n_samples_test, _ = _fit_and_score(estimator, X, y, scorer, train, + test, verbose, parameters, + fit_params, error_score) + return score, parameters, n_samples_test + + +def _check_param_grid(param_grid): + if hasattr(param_grid, 'items'): + param_grid = [param_grid] + + for p in param_grid: + for v in p.values(): + if isinstance(v, np.ndarray) and v.ndim > 1: + raise ValueError("Parameter array should be one-dimensional.") + + check = [isinstance(v, k) for k in (list, tuple, np.ndarray)] + if True not in check: + raise ValueError("Parameter values should be a list.") + + if len(v) == 0: + raise ValueError("Parameter values should be a non-empty " + "list.") + + +class _CVScoreTuple (namedtuple('_CVScoreTuple', + ('parameters', + 'mean_validation_score', + 'cv_validation_scores'))): + # A raw namedtuple is very memory efficient as it packs the attributes + # in a struct to get rid of the __dict__ of attributes in particular it + # does not copy the string for the keys on each instance. + # By deriving a namedtuple class just to introduce the __repr__ method we + # would also reintroduce the __dict__ on the instance. By telling the + # Python interpreter that this subclass uses static __slots__ instead of + # dynamic attributes. Furthermore we don't need any additional slot in the + # subclass so we set __slots__ to the empty tuple. + __slots__ = () + + def __repr__(self): + """Simple custom repr to summarize the main info""" + return "mean: {0:.5f}, std: {1:.5f}, params: {2}".format( + self.mean_validation_score, + np.std(self.cv_validation_scores), + self.parameters) + + +class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, + MetaEstimatorMixin)): + """Base class for hyper parameter search with cross-validation.""" + + @abstractmethod + def __init__(self, estimator, scoring=None, + fit_params=None, n_jobs=1, iid=True, + refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', + error_score='raise'): + + self.scoring = scoring + self.estimator = estimator + self.n_jobs = n_jobs + self.fit_params = fit_params if fit_params is not None else {} + self.iid = iid + self.refit = refit + self.cv = cv + self.verbose = verbose + self.pre_dispatch = pre_dispatch + self.error_score = error_score + + @property + def _estimator_type(self): + return self.estimator._estimator_type + + def score(self, X, y=None): + """Returns the score on the given data, if the estimator has been refit. + + This uses the score defined by ``scoring`` where provided, and the + ``best_estimator_.score`` method otherwise. + + Parameters + ---------- + X : array-like, shape = [n_samples, n_features] + Input data, where n_samples is the number of samples and + n_features is the number of features. + + y : array-like, shape = [n_samples] or [n_samples, n_output], optional + Target relative to X for classification or regression; + None for unsupervised learning. + + Returns + ------- + score : float + + Notes + ----- + * The long-standing behavior of this method changed in version 0.16. + * It no longer uses the metric provided by ``estimator.score`` if the + ``scoring`` parameter was set when fitting. + + """ + if self.scorer_ is None: + raise ValueError("No score function explicitly defined, " + "and the estimator doesn't provide one %s" + % self.best_estimator_) + if self.scoring is not None and hasattr(self.best_estimator_, 'score'): + warnings.warn("The long-standing behavior to use the estimator's " + "score function in {0}.score has changed. The " + "scoring parameter is now used." 
+ "".format(self.__class__.__name__), + ChangedBehaviorWarning) + return self.scorer_(self.best_estimator_, X, y) + + @if_delegate_has_method(delegate='estimator') + def predict(self, X): + """Call predict on the estimator with the best found parameters. + + Only available if ``refit=True`` and the underlying estimator supports + ``predict``. + + Parameters + ----------- + X : indexable, length n_samples + Must fulfill the input assumptions of the + underlying estimator. + + """ + return self.best_estimator_.predict(X) + + @if_delegate_has_method(delegate='estimator') + def predict_proba(self, X): + """Call predict_proba on the estimator with the best found parameters. + + Only available if ``refit=True`` and the underlying estimator supports + ``predict_proba``. + + Parameters + ----------- + X : indexable, length n_samples + Must fulfill the input assumptions of the + underlying estimator. + + """ + return self.best_estimator_.predict_proba(X) + + @if_delegate_has_method(delegate='estimator') + def predict_log_proba(self, X): + """Call predict_log_proba on the estimator with the best found parameters. + + Only available if ``refit=True`` and the underlying estimator supports + ``predict_log_proba``. + + Parameters + ----------- + X : indexable, length n_samples + Must fulfill the input assumptions of the + underlying estimator. + + """ + return self.best_estimator_.predict_log_proba(X) + + @if_delegate_has_method(delegate='estimator') + def decision_function(self, X): + """Call decision_function on the estimator with the best found parameters. + + Only available if ``refit=True`` and the underlying estimator supports + ``decision_function``. + + Parameters + ----------- + X : indexable, length n_samples + Must fulfill the input assumptions of the + underlying estimator. + + """ + return self.best_estimator_.decision_function(X) + + @if_delegate_has_method(delegate='estimator') + def transform(self, X): + """Call transform on the estimator with the best found parameters. + + Only available if the underlying estimator supports ``transform`` and + ``refit=True``. + + Parameters + ----------- + X : indexable, length n_samples + Must fulfill the input assumptions of the + underlying estimator. + + """ + return self.best_estimator_.transform(X) + + @if_delegate_has_method(delegate='estimator') + def inverse_transform(self, Xt): + """Call inverse_transform on the estimator with the best found parameters. + + Only available if the underlying estimator implements + ``inverse_transform`` and ``refit=True``. + + Parameters + ----------- + Xt : indexable, length n_samples + Must fulfill the input assumptions of the + underlying estimator. 
+ + """ + return self.best_estimator_.transform(Xt) + + def _fit(self, X, y, labels, parameter_iterable): + """Actual fitting, performing the search over parameters.""" + + estimator = self.estimator + cv = check_cv(self.cv, y, classifier=is_classifier(estimator)) + self.scorer_ = check_scoring(self.estimator, scoring=self.scoring) + + n_samples = _num_samples(X) + X, y, labels = indexable(X, y, labels) + + if y is not None: + if len(y) != n_samples: + raise ValueError('Target variable (y) has a different number ' + 'of samples (%i) than data (X: %i samples)' + % (len(y), n_samples)) + n_splits = cv.get_n_splits(X, y, labels) + + if self.verbose > 0 and isinstance(parameter_iterable, Sized): + n_candidates = len(parameter_iterable) + print("Fitting {0} folds for each of {1} candidates, totalling" + " {2} fits".format(n_splits, n_candidates, + n_candidates * n_splits)) + + base_estimator = clone(self.estimator) + + pre_dispatch = self.pre_dispatch + + out = Parallel( + n_jobs=self.n_jobs, verbose=self.verbose, + pre_dispatch=pre_dispatch + )(delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_, + train, test, self.verbose, parameters, + self.fit_params, return_parameters=True, + error_score=self.error_score) + for parameters in parameter_iterable + for train, test in cv.split(X, y, labels)) + + # Out is a list of triplet: score, estimator, n_test_samples + n_fits = len(out) + + scores = list() + grid_scores = list() + for grid_start in range(0, n_fits, n_splits): + n_test_samples = 0 + score = 0 + all_scores = [] + for this_score, this_n_test_samples, _, parameters in \ + out[grid_start:grid_start + n_splits]: + all_scores.append(this_score) + if self.iid: + this_score *= this_n_test_samples + n_test_samples += this_n_test_samples + score += this_score + if self.iid: + score /= float(n_test_samples) + else: + score /= float(n_splits) + scores.append((score, parameters)) + # TODO: shall we also store the test_fold_sizes? + grid_scores.append(_CVScoreTuple( + parameters, + score, + np.array(all_scores))) + # Store the computed scores + self.grid_scores_ = grid_scores + + # Find the best parameters by comparing on the mean validation score: + # note that `sorted` is deterministic in the way it breaks ties + best = sorted(grid_scores, key=lambda x: x.mean_validation_score, + reverse=True)[0] + self.best_params_ = best.parameters + self.best_score_ = best.mean_validation_score + + if self.refit: + # fit the best estimator using the entire dataset + # clone first to work around broken estimators + best_estimator = clone(base_estimator).set_params( + **best.parameters) + if y is not None: + best_estimator.fit(X, y, **self.fit_params) + else: + best_estimator.fit(X, **self.fit_params) + self.best_estimator_ = best_estimator + return self + + +class GridSearchCV(BaseSearchCV): + """Exhaustive search over specified parameter values for an estimator. + + Important members are fit, predict. + + GridSearchCV implements a "fit" and a "score" method. + It also implements "predict", "predict_proba", "decision_function", + "transform" and "inverse_transform" if they are implemented in the + estimator used. + + The parameters of the estimator used to apply these methods are optimized + by cross-validated grid-search over a parameter grid. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + estimator : estimator object. + This is assumed to implement the scikit-learn estimator interface. + Either estimator needs to provide a ``score`` function, + or ``scoring`` must be passed. 
+ + param_grid : dict or list of dictionaries + Dictionary with parameters names (string) as keys and lists of + parameter settings to try as values, or a list of such + dictionaries, in which case the grids spanned by each dictionary + in the list are explored. This enables searching over any sequence + of parameter settings. + + scoring : string, callable or None, default=None + A string (see model evaluation documentation) or + a scorer callable object / function with signature + ``scorer(estimator, X, y)``. + If ``None``, the ``score`` method of the estimator is used. + + fit_params : dict, optional + Parameters to pass to the fit method. + + n_jobs : int, default=1 + Number of jobs to run in parallel. + + pre_dispatch : int, or string, optional + Controls the number of jobs that get dispatched during parallel + execution. Reducing this number can be useful to avoid an + explosion of memory consumption when more jobs get dispatched + than CPUs can process. This parameter can be: + + - None, in which case all the jobs are immediately + created and spawned. Use this for lightweight and + fast-running jobs, to avoid delays due to on-demand + spawning of the jobs + + - An int, giving the exact number of total jobs that are + spawned + + - A string, giving an expression as a function of n_jobs, + as in '2*n_jobs' + + iid : boolean, default=True + If True, the data is assumed to be identically distributed across + the folds, and the loss minimized is the total loss per sample, + and not the mean loss across the folds. + + cv : int, cross-validation generator or an iterable, optional + Determines the cross-validation splitting strategy. + Possible inputs for cv are: + - None, to use the default 3-fold cross validation, + - integer, to specify the number of folds in a `(Stratified)KFold`, + - An object to be used as a cross-validation generator. + - An iterable yielding train, test splits. + + For integer/None inputs, ``StratifiedKFold`` is used for classification + tasks, when ``y`` is binary or multiclass. + + See the :mod:`sklearn.model_selection` module for the list of + cross-validation strategies that can be used here. + + Also refer :ref:`cross-validation documentation ` + + refit : boolean, default=True + Refit the best estimator with the entire dataset. + If "False", it is impossible to make predictions using + this GridSearchCV instance after fitting. + + verbose : integer + Controls the verbosity: the higher, the more messages. + + error_score : 'raise' (default) or numeric + Value to assign to the score if an error occurs in estimator fitting. + If set to 'raise', the error is raised. If a numeric value is given, + FitFailedWarning is raised. This parameter does not affect the refit + step, which will always raise the error. + + + Examples + -------- + >>> from sklearn import svm, datasets + >>> from sklearn.model_selection import GridSearchCV + >>> iris = datasets.load_iris() + >>> parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]} + >>> svr = svm.SVC() + >>> clf = GridSearchCV(svr, parameters) + >>> clf.fit(iris.data, iris.target) + ... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS + GridSearchCV(cv=None, error_score=..., + estimator=SVC(C=1.0, cache_size=..., class_weight=..., coef0=..., + decision_function_shape=None, degree=..., gamma=..., + kernel='rbf', max_iter=-1, probability=False, + random_state=None, shrinking=True, tol=..., + verbose=False), + fit_params={}, iid=..., n_jobs=1, + param_grid=..., pre_dispatch=..., refit=..., + scoring=..., verbose=...) 
+ + + Attributes + ---------- + grid_scores_ : list of named tuples + Contains scores for all parameter combinations in param_grid. + Each entry corresponds to one parameter setting. + Each named tuple has the attributes: + + * ``parameters``, a dict of parameter settings + * ``mean_validation_score``, the mean score over the + cross-validation folds + * ``cv_validation_scores``, the list of scores for each fold + + best_estimator_ : estimator + Estimator that was chosen by the search, i.e. estimator + which gave highest score (or smallest loss if specified) + on the left out data. Not available if refit=False. + + best_score_ : float + Score of best_estimator on the left out data. + + best_params_ : dict + Parameter setting that gave the best results on the hold out data. + + scorer_ : function + Scorer function used on the held out data to choose the best + parameters for the model. + + Notes + ------ + The parameters selected are those that maximize the score of the left out + data, unless an explicit score is passed in which case it is used instead. + + If `n_jobs` was set to a value higher than one, the data is copied for each + point in the grid (and not `n_jobs` times). This is done for efficiency + reasons if individual jobs take very little time, but may raise errors if + the dataset is large and not enough memory is available. A workaround in + this case is to set `pre_dispatch`. Then, the memory is copied only + `pre_dispatch` many times. A reasonable value for `pre_dispatch` is `2 * + n_jobs`. + + See Also + --------- + :class:`ParameterGrid`: + generates all the combinations of a an hyperparameter grid. + + :func:`sklearn.model_selection.train_test_split`: + utility function to split the data into a development set usable + for fitting a GridSearchCV instance and an evaluation set for + its final evaluation. + + :func:`sklearn.metrics.make_scorer`: + Make a scorer from a performance metric or loss function. + + """ + + def __init__(self, estimator, param_grid, scoring=None, fit_params=None, + n_jobs=1, iid=True, refit=True, cv=None, verbose=0, + pre_dispatch='2*n_jobs', error_score='raise'): + + super(GridSearchCV, self).__init__( + estimator=estimator, scoring=scoring, fit_params=fit_params, + n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose, + pre_dispatch=pre_dispatch, error_score=error_score) + self.param_grid = param_grid + _check_param_grid(param_grid) + + def fit(self, X, y=None, labels=None): + """Run fit with all sets of parameters. + + Parameters + ---------- + + X : array-like, shape = [n_samples, n_features] + Training vector, where n_samples is the number of samples and + n_features is the number of features. + + y : array-like, shape = [n_samples] or [n_samples, n_output], optional + Target relative to X for classification or regression; + None for unsupervised learning. + + labels : array-like, with shape (n_samples,), optional + Group labels for the samples used while splitting the dataset into + train/test set. + """ + return self._fit(X, y, labels, ParameterGrid(self.param_grid)) + + +class RandomizedSearchCV(BaseSearchCV): + """Randomized search on hyper parameters. + + RandomizedSearchCV implements a "fit" and a "score" method. + It also implements "predict", "predict_proba", "decision_function", + "transform" and "inverse_transform" if they are implemented in the + estimator used. + + The parameters of the estimator used to apply these methods are optimized + by cross-validated search over parameter settings. 
+ + In contrast to GridSearchCV, not all parameter values are tried out, but + rather a fixed number of parameter settings is sampled from the specified + distributions. The number of parameter settings that are tried is + given by n_iter. + + If all parameters are presented as a list, + sampling without replacement is performed. If at least one parameter + is given as a distribution, sampling with replacement is used. + It is highly recommended to use continuous distributions for continuous + parameters. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + estimator : estimator object. + A object of that type is instantiated for each grid point. + This is assumed to implement the scikit-learn estimator interface. + Either estimator needs to provide a ``score`` function, + or ``scoring`` must be passed. + + param_distributions : dict + Dictionary with parameters names (string) as keys and distributions + or lists of parameters to try. Distributions must provide a ``rvs`` + method for sampling (such as those from scipy.stats.distributions). + If a list is given, it is sampled uniformly. + + n_iter : int, default=10 + Number of parameter settings that are sampled. n_iter trades + off runtime vs quality of the solution. + + scoring : string, callable or None, default=None + A string (see model evaluation documentation) or + a scorer callable object / function with signature + ``scorer(estimator, X, y)``. + If ``None``, the ``score`` method of the estimator is used. + + fit_params : dict, optional + Parameters to pass to the fit method. + + n_jobs : int, default=1 + Number of jobs to run in parallel. + + pre_dispatch : int, or string, optional + Controls the number of jobs that get dispatched during parallel + execution. Reducing this number can be useful to avoid an + explosion of memory consumption when more jobs get dispatched + than CPUs can process. This parameter can be: + + - None, in which case all the jobs are immediately + created and spawned. Use this for lightweight and + fast-running jobs, to avoid delays due to on-demand + spawning of the jobs + + - An int, giving the exact number of total jobs that are + spawned + + - A string, giving an expression as a function of n_jobs, + as in '2*n_jobs' + + iid : boolean, default=True + If True, the data is assumed to be identically distributed across + the folds, and the loss minimized is the total loss per sample, + and not the mean loss across the folds. + + cv : int, cross-validation generator or an iterable, optional + Determines the cross-validation splitting strategy. + Possible inputs for cv are: + - None, to use the default 3-fold cross validation, + - integer, to specify the number of folds in a `(Stratified)KFold`, + - An object to be used as a cross-validation generator. + - An iterable yielding train, test splits. + + For integer/None inputs, ``StratifiedKFold`` is used for classification + tasks, when ``y`` is binary or multiclass. + + See the :mod:`sklearn.model_selection` module for the list of + cross-validation strategies that can be used here. + + Also refer :ref:`cross-validation documentation ` + + refit : boolean, default=True + Refit the best estimator with the entire dataset. + If "False", it is impossible to make predictions using + this RandomizedSearchCV instance after fitting. + + verbose : integer + Controls the verbosity: the higher, the more messages. 
+ + random_state : int or RandomState + Pseudo random number generator state used for random uniform sampling + from lists of possible values instead of scipy.stats distributions. + + error_score : 'raise' (default) or numeric + Value to assign to the score if an error occurs in estimator fitting. + If set to 'raise', the error is raised. If a numeric value is given, + FitFailedWarning is raised. This parameter does not affect the refit + step, which will always raise the error. + + Attributes + ---------- + grid_scores_ : list of named tuples + Contains scores for all parameter combinations in param_grid. + Each entry corresponds to one parameter setting. + Each named tuple has the attributes: + + * ``parameters``, a dict of parameter settings + * ``mean_validation_score``, the mean score over the + cross-validation folds + * ``cv_validation_scores``, the list of scores for each fold + + best_estimator_ : estimator + Estimator that was chosen by the search, i.e. estimator + which gave highest score (or smallest loss if specified) + on the left out data. Not available if refit=False. + + best_score_ : float + Score of best_estimator on the left out data. + + best_params_ : dict + Parameter setting that gave the best results on the hold out data. + + Notes + ----- + The parameters selected are those that maximize the score of the held-out + data, according to the scoring parameter. + + If `n_jobs` was set to a value higher than one, the data is copied for each + parameter setting(and not `n_jobs` times). This is done for efficiency + reasons if individual jobs take very little time, but may raise errors if + the dataset is large and not enough memory is available. A workaround in + this case is to set `pre_dispatch`. Then, the memory is copied only + `pre_dispatch` many times. A reasonable value for `pre_dispatch` is `2 * + n_jobs`. + + See Also + -------- + :class:`GridSearchCV`: + Does exhaustive search over a grid of parameters. + + :class:`ParameterSampler`: + A generator over parameter settins, constructed from + param_distributions. + + """ + + def __init__(self, estimator, param_distributions, n_iter=10, scoring=None, + fit_params=None, n_jobs=1, iid=True, refit=True, cv=None, + verbose=0, pre_dispatch='2*n_jobs', random_state=None, + error_score='raise'): + + self.param_distributions = param_distributions + self.n_iter = n_iter + self.random_state = random_state + super(RandomizedSearchCV, self).__init__( + estimator=estimator, scoring=scoring, fit_params=fit_params, + n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose, + pre_dispatch=pre_dispatch, error_score=error_score) + + def fit(self, X, y=None, labels=None): + """Run fit on the estimator with randomly drawn parameters. + + Parameters + ---------- + X : array-like, shape = [n_samples, n_features] + Training vector, where n_samples in the number of samples and + n_features is the number of features. + + y : array-like, shape = [n_samples] or [n_samples, n_output], optional + Target relative to X for classification or regression; + None for unsupervised learning. + + labels : array-like, with shape (n_samples,), optional + Group labels for the samples used while splitting the dataset into + train/test set. 
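For reference, a hedged end-to-end sketch of ``RandomizedSearchCV`` driven by a scipy.stats distribution (which provides the ``rvs`` method required above); the estimator and parameter values are arbitrary::

    # Illustrative sketch, not part of the diff.
    import scipy.stats as stats
    from sklearn.datasets import load_iris
    from sklearn.linear_model import SGDClassifier
    from sklearn.model_selection import RandomizedSearchCV

    iris = load_iris()
    search = RandomizedSearchCV(
        SGDClassifier(random_state=0),
        param_distributions={'alpha': stats.expon(scale=1e-3)},  # has .rvs()
        n_iter=5, random_state=0)
    search.fit(iris.data, iris.target)
    print(search.best_params_, search.best_score_)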
+ """ + sampled_params = ParameterSampler(self.param_distributions, + self.n_iter, + random_state=self.random_state) + return self._fit(X, y, labels, sampled_params) diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py new file mode 100644 index 0000000000000..8922cb8356d4d --- /dev/null +++ b/sklearn/model_selection/_split.py @@ -0,0 +1,1530 @@ +""" +The :mod:`sklearn.model_selection._split` module includes classes and +functions to split the data based on a preset strategy. +""" + +# Author: Alexandre Gramfort , +# Gael Varoquaux , +# Olivier Girsel +# Raghav R V +# License: BSD 3 clause + + +from __future__ import print_function +from __future__ import division + +import warnings +import inspect +from itertools import chain, combinations +from collections import Iterable +from math import ceil, floor +import numbers +from abc import ABCMeta, abstractmethod + +import numpy as np + +from scipy.misc import comb +from ..utils import indexable, check_random_state, safe_indexing +from ..utils.validation import _num_samples, column_or_1d +from ..utils.multiclass import type_of_target +from ..externals.six import with_metaclass +from ..externals.six.moves import zip +from ..utils.fixes import bincount +from ..base import _pprint +from ..gaussian_process.kernels import Kernel as GPKernel + +__all__ = ['BaseCrossValidator', + 'KFold', + 'LabelKFold', + 'LeaveOneLabelOut', + 'LeaveOneOut', + 'LeavePLabelOut', + 'LeavePOut', + 'ShuffleSplit', + 'LabelShuffleSplit', + 'StratifiedKFold', + 'StratifiedShuffleSplit', + 'PredefinedSplit', + 'train_test_split', + 'check_cv'] + + +class BaseCrossValidator(with_metaclass(ABCMeta)): + """Base class for all cross-validators + + Implementations must define `_iter_test_masks` or `_iter_test_indices`. + """ + + def split(self, X, y=None, labels=None): + """Generate indices to split data into training and test set. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + Training data, where n_samples is the number of samples + and n_features is the number of features. + + y : array-like, shape (n_samples,) + The target variable for supervised learning problems. + + labels : array-like, with shape (n_samples,), optional + Group labels for the samples used while splitting the dataset into + train/test set. + + Returns + ------- + train : ndarray + The training set indices for that split. + + test : ndarray + The testing set indices for that split. + """ + X, y, labels = indexable(X, y, labels) + indices = np.arange(_num_samples(X)) + for test_index in self._iter_test_masks(X, y, labels): + train_index = indices[np.logical_not(test_index)] + test_index = indices[test_index] + yield train_index, test_index + + # Since subclasses must implement either _iter_test_masks or + # _iter_test_indices, neither can be abstract. + def _iter_test_masks(self, X=None, y=None, labels=None): + """Generates boolean masks corresponding to test sets. 
+ + By default, delegates to _iter_test_indices(X, y, labels) + """ + for test_index in self._iter_test_indices(X, y, labels): + test_mask = np.zeros(_num_samples(X), dtype=np.bool) + test_mask[test_index] = True + yield test_mask + + def _iter_test_indices(self, X=None, y=None, labels=None): + """Generates integer indices corresponding to test sets.""" + raise NotImplementedError + + @abstractmethod + def get_n_splits(self, X=None, y=None, labels=None): + """Returns the number of splitting iterations in the cross-validator""" + + def __repr__(self): + return _build_repr(self) + + +class LeaveOneOut(BaseCrossValidator): + """Leave-One-Out cross-validator + + Provides train/test indices to split data in train/test sets. Each + sample is used once as a test set (singleton) while the remaining + samples form the training set. + + Note: ``LeaveOneOut()`` is equivalent to ``KFold(n_folds=n)`` and + ``LeavePOut(p=1)`` where ``n`` is the number of samples. + + Due to the high number of test sets (which is the same as the + number of samples) this cross-validation method can be very costly. + For large datasets one should favor :class:`KFold`, :class:`ShuffleSplit` + or :class:`StratifiedKFold`. + + Read more in the :ref:`User Guide `. + + Examples + -------- + >>> from sklearn.model_selection import LeaveOneOut + >>> X = np.array([[1, 2], [3, 4]]) + >>> y = np.array([1, 2]) + >>> loo = LeaveOneOut() + >>> loo.get_n_splits(X) + 2 + >>> print(loo) + LeaveOneOut() + >>> for train_index, test_index in loo.split(X): + ... print("TRAIN:", train_index, "TEST:", test_index) + ... X_train, X_test = X[train_index], X[test_index] + ... y_train, y_test = y[train_index], y[test_index] + ... print(X_train, X_test, y_train, y_test) + TRAIN: [1] TEST: [0] + [[3 4]] [[1 2]] [2] [1] + TRAIN: [0] TEST: [1] + [[1 2]] [[3 4]] [1] [2] + + See also + -------- + LeaveOneLabelOut + For splitting the data according to explicit, domain-specific + stratification of the dataset. + + LabelKFold: K-fold iterator variant with non-overlapping labels. + """ + + def _iter_test_indices(self, X, y=None, labels=None): + return range(_num_samples(X)) + + def get_n_splits(self, X, y=None, labels=None): + """Returns the number of splitting iterations in the cross-validator + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + Training data, where n_samples is the number of samples + and n_features is the number of features. + + y : object + Always ignored, exists for compatibility. + + labels : object + Always ignored, exists for compatibility. + + Returns + ------- + n_splits : int + Returns the number of splitting iterations in the cross-validator. + """ + if X is None: + raise ValueError("The X parameter should not be None") + return _num_samples(X) + + +class LeavePOut(BaseCrossValidator): + """Leave-P-Out cross-validator + + Provides train/test indices to split data in train/test sets. This results + in testing on all distinct samples of size p, while the remaining n - p + samples form the training set in each iteration. + + Note: ``LeavePOut(p)`` is NOT equivalent to + ``KFold(n_folds=n_samples // p)`` which creates non-overlapping test sets. + + Due to the high number of iterations which grows combinatorically with the + number of samples this cross-validation method can be very costly. For + large datasets one should favor :class:`KFold`, :class:`StratifiedKFold` + or :class:`ShuffleSplit`. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + p : int + Size of the test sets. 
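The ``BaseCrossValidator`` contract described above (subclasses supply ``_iter_test_indices`` or ``_iter_test_masks``) also covers user-defined splitters. A hedged sketch; the toy class below and the private import paths are illustrative, not part of the diff::

    # Toy splitter built on the BaseCrossValidator contract; illustrative only.
    import numpy as np
    from sklearn.model_selection._split import BaseCrossValidator
    from sklearn.utils.validation import _num_samples

    class EvenOddSplit(BaseCrossValidator):
        """Two splits: even-indexed samples held out, then odd-indexed ones."""

        def _iter_test_indices(self, X, y=None, labels=None):
            n_samples = _num_samples(X)
            yield np.arange(0, n_samples, 2)
            yield np.arange(1, n_samples, 2)

        def get_n_splits(self, X=None, y=None, labels=None):
            return 2

    X = np.arange(20).reshape(10, 2)
    for train_index, test_index in EvenOddSplit().split(X):
        print("TRAIN:", train_index, "TEST:", test_index)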
+ + Examples + -------- + >>> from sklearn.model_selection import LeavePOut + >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + >>> y = np.array([1, 2, 3, 4]) + >>> lpo = LeavePOut(2) + >>> lpo.get_n_splits(X) + 6 + >>> print(lpo) + LeavePOut(p=2) + >>> for train_index, test_index in lpo.split(X): + ... print("TRAIN:", train_index, "TEST:", test_index) + ... X_train, X_test = X[train_index], X[test_index] + ... y_train, y_test = y[train_index], y[test_index] + TRAIN: [2 3] TEST: [0 1] + TRAIN: [1 3] TEST: [0 2] + TRAIN: [1 2] TEST: [0 3] + TRAIN: [0 3] TEST: [1 2] + TRAIN: [0 2] TEST: [1 3] + TRAIN: [0 1] TEST: [2 3] + """ + + def __init__(self, p): + self.p = p + + def _iter_test_indices(self, X, y=None, labels=None): + for combination in combinations(range(_num_samples(X)), self.p): + yield np.array(combination) + + def get_n_splits(self, X, y=None, labels=None): + """Returns the number of splitting iterations in the cross-validator + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + Training data, where n_samples is the number of samples + and n_features is the number of features. + + y : object + Always ignored, exists for compatibility. + + labels : object + Always ignored, exists for compatibility. + """ + if X is None: + raise ValueError("The X parameter should not be None") + return int(comb(_num_samples(X), self.p, exact=True)) + + +class _BaseKFold(with_metaclass(ABCMeta, BaseCrossValidator)): + """Base class for KFold and StratifiedKFold""" + + @abstractmethod + def __init__(self, n_folds, shuffle, random_state): + if not isinstance(n_folds, numbers.Integral): + raise ValueError('The number of folds must be of Integral type. ' + '%s of type %s was passed.' + % (n_folds, type(n_folds))) + n_folds = int(n_folds) + + if n_folds <= 1: + raise ValueError( + "k-fold cross-validation requires at least one" + " train/test split by setting n_folds=2 or more," + " got n_folds={0}.".format(n_folds)) + + if not isinstance(shuffle, bool): + raise TypeError("shuffle must be True or False;" + " got {0}".format(shuffle)) + + self.n_folds = n_folds + self.shuffle = shuffle + self.random_state = random_state + + def split(self, X, y=None, labels=None): + """Generate indices to split data into training and test set. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + Training data, where n_samples is the number of samples + and n_features is the number of features. + + y : array-like, shape (n_samples,), optional + The target variable for supervised learning problems. + + labels : array-like, with shape (n_samples,), optional + Group labels for the samples used while splitting the dataset into + train/test set. + + Returns + ------- + train : ndarray + The training set indices for that split. + + test : ndarray + The testing set indices for that split. + """ + X, y, labels = indexable(X, y, labels) + n_samples = _num_samples(X) + if self.n_folds > n_samples: + raise ValueError( + ("Cannot have number of folds n_folds={0} greater" + " than the number of samples: {1}.").format(self.n_folds, + n_samples)) + + for train, test in super(_BaseKFold, self).split(X, y, labels): + yield train, test + + def get_n_splits(self, X=None, y=None, labels=None): + """Returns the number of splitting iterations in the cross-validator + + Parameters + ---------- + X : object + Always ignored, exists for compatibility. + + y : object + Always ignored, exists for compatibility. + + labels : object + Always ignored, exists for compatibility. 
+ + Returns + ------- + n_splits : int + Returns the number of splitting iterations in the cross-validator. + """ + return self.n_folds + + +class KFold(_BaseKFold): + """K-Folds cross-validator + + Provides train/test indices to split data in train/test sets. Split + dataset into k consecutive folds (without shuffling by default). + + Each fold is then used a validation set once while the k - 1 remaining + fold form the training set. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + n_folds : int, default=3 + Number of folds. Must be at least 2. + + shuffle : boolean, optional + Whether to shuffle the data before splitting into batches. + + random_state : None, int or RandomState + When shuffle=True, pseudo-random number generator state used for + shuffling. If None, use default numpy RNG for shuffling. + + Examples + -------- + >>> from sklearn.model_selection import KFold + >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]]) + >>> y = np.array([1, 2, 3, 4]) + >>> kf = KFold(n_folds=2) + >>> kf.get_n_splits(X) + 2 + >>> print(kf) # doctest: +NORMALIZE_WHITESPACE + KFold(n_folds=2, random_state=None, shuffle=False) + >>> for train_index, test_index in kf.split(X): + ... print("TRAIN:", train_index, "TEST:", test_index) + ... X_train, X_test = X[train_index], X[test_index] + ... y_train, y_test = y[train_index], y[test_index] + TRAIN: [2 3] TEST: [0 1] + TRAIN: [0 1] TEST: [2 3] + + Notes + ----- + The first ``n_samples % n_folds`` folds have size + ``n_samples // n_folds + 1``, other folds have size + ``n_samples // n_folds``, where ``n_samples`` is the number of samples. + + See also + -------- + StratifiedKFold + For taking label information into account to avoid building folds with + imbalanced class distributions (for binary or multiclass + classification tasks). + + LabelKFold: K-fold iterator variant with non-overlapping labels. + """ + + def __init__(self, n_folds=3, shuffle=False, + random_state=None): + super(KFold, self).__init__(n_folds, shuffle, random_state) + self.shuffle = shuffle + + def _iter_test_indices(self, X, y=None, labels=None): + n_samples = _num_samples(X) + indices = np.arange(n_samples) + if self.shuffle: + check_random_state(self.random_state).shuffle(indices) + + n_folds = self.n_folds + fold_sizes = (n_samples // n_folds) * np.ones(n_folds, dtype=np.int) + fold_sizes[:n_samples % n_folds] += 1 + current = 0 + for fold_size in fold_sizes: + start, stop = current, current + fold_size + yield indices[start:stop] + current = stop + + +class LabelKFold(_BaseKFold): + """K-fold iterator variant with non-overlapping labels. + + The same label will not appear in two different folds (the number of + distinct labels has to be at least equal to the number of folds). + + The folds are approximately balanced in the sense that the number of + distinct labels is approximately the same in each fold. + + Parameters + ---------- + n_folds : int, default=3 + Number of folds. Must be at least 2. + + Examples + -------- + >>> from sklearn.model_selection import LabelKFold + >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + >>> y = np.array([1, 2, 3, 4]) + >>> labels = np.array([0, 0, 2, 2]) + >>> label_kfold = LabelKFold(n_folds=2) + >>> label_kfold.get_n_splits(X, y, labels) + 2 + >>> print(label_kfold) + LabelKFold(n_folds=2) + >>> for train_index, test_index in label_kfold.split(X, y, labels): + ... print("TRAIN:", train_index, "TEST:", test_index) + ... X_train, X_test = X[train_index], X[test_index] + ... 
y_train, y_test = y[train_index], y[test_index] + ... print(X_train, X_test, y_train, y_test) + ... + TRAIN: [0 1] TEST: [2 3] + [[1 2] + [3 4]] [[5 6] + [7 8]] [1 2] [3 4] + TRAIN: [2 3] TEST: [0 1] + [[5 6] + [7 8]] [[1 2] + [3 4]] [3 4] [1 2] + + See also + -------- + LeaveOneLabelOut + For splitting the data according to explicit domain-specific + stratification of the dataset. + """ + def __init__(self, n_folds=3): + super(LabelKFold, self).__init__(n_folds, shuffle=False, + random_state=None) + + def _iter_test_indices(self, X, y, labels): + if labels is None: + raise ValueError("The labels parameter should not be None") + + unique_labels, labels = np.unique(labels, return_inverse=True) + n_labels = len(unique_labels) + + if self.n_folds > n_labels: + raise ValueError("Cannot have number of folds n_folds=%d greater" + " than the number of labels: %d." + % (self.n_folds, n_labels)) + + # Weight labels by their number of occurences + n_samples_per_label = np.bincount(labels) + + # Distribute the most frequent labels first + indices = np.argsort(n_samples_per_label)[::-1] + n_samples_per_label = n_samples_per_label[indices] + + # Total weight of each fold + n_samples_per_fold = np.zeros(self.n_folds) + + # Mapping from label index to fold index + label_to_fold = np.zeros(len(unique_labels)) + + # Distribute samples by adding the largest weight to the lightest fold + for label_index, weight in enumerate(n_samples_per_label): + lightest_fold = np.argmin(n_samples_per_fold) + n_samples_per_fold[lightest_fold] += weight + label_to_fold[indices[label_index]] = lightest_fold + + indices = label_to_fold[labels] + + for f in range(self.n_folds): + yield np.where(indices == f)[0] + + +class StratifiedKFold(_BaseKFold): + """Stratified K-Folds cross-validator + + Provides train/test indices to split data in train/test sets. + + This cross-validation object is a variation of KFold that returns + stratified folds. The folds are made by preserving the percentage of + samples for each class. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + n_folds : int, default=3 + Number of folds. Must be at least 2. + + shuffle : boolean, optional + Whether to shuffle each stratification of the data before splitting + into batches. + + random_state : None, int or RandomState + When shuffle=True, pseudo-random number generator state used for + shuffling. If None, use default numpy RNG for shuffling. + + Examples + -------- + >>> from sklearn.model_selection import StratifiedKFold + >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]]) + >>> y = np.array([0, 0, 1, 1]) + >>> skf = StratifiedKFold(n_folds=2) + >>> skf.get_n_splits(X, y) + 2 + >>> print(skf) # doctest: +NORMALIZE_WHITESPACE + StratifiedKFold(n_folds=2, random_state=None, shuffle=False) + >>> for train_index, test_index in skf.split(X, y): + ... print("TRAIN:", train_index, "TEST:", test_index) + ... X_train, X_test = X[train_index], X[test_index] + ... y_train, y_test = y[train_index], y[test_index] + TRAIN: [1 3] TEST: [0 2] + TRAIN: [0 2] TEST: [1 3] + + Notes + ----- + All the folds have size ``trunc(n_samples / n_folds)``, the last one has + the complementary. 
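As a standalone illustration of the greedy balancing used by ``LabelKFold`` earlier in this hunk (labels are visited from most to least frequent and each one is assigned to the currently lightest fold), a sketch with made-up counts; the helper below is hypothetical, not library code::

    # Hypothetical helper mirroring LabelKFold's greedy balancing.
    import numpy as np

    def assign_labels_to_folds(n_samples_per_label, n_folds):
        order = np.argsort(n_samples_per_label)[::-1]   # most frequent first
        fold_sizes = np.zeros(n_folds)
        label_to_fold = np.zeros(len(n_samples_per_label), dtype=int)
        for label_index in order:
            lightest_fold = np.argmin(fold_sizes)
            fold_sizes[lightest_fold] += n_samples_per_label[label_index]
            label_to_fold[label_index] = lightest_fold
        return label_to_fold, fold_sizes

    mapping, sizes = assign_labels_to_folds(np.array([5, 3, 3, 2, 1]), 2)
    print(mapping, sizes)   # both folds end up holding 7 samples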
+ + """ + + def __init__(self, n_folds=3, shuffle=False, random_state=None): + super(StratifiedKFold, self).__init__(n_folds, shuffle, random_state) + self.shuffle = shuffle + + def _make_test_folds(self, X, y=None, labels=None): + if self.shuffle: + rng = check_random_state(self.random_state) + else: + rng = self.random_state + y = np.asarray(y) + n_samples = y.shape[0] + unique_y, y_inversed = np.unique(y, return_inverse=True) + y_counts = bincount(y_inversed) + min_labels = np.min(y_counts) + if self.n_folds > min_labels: + warnings.warn(("The least populated class in y has only %d" + " members, which is too few. The minimum" + " number of labels for any class cannot" + " be less than n_folds=%d." + % (min_labels, self.n_folds)), Warning) + + # pre-assign each sample to a test fold index using individual KFold + # splitting strategies for each class so as to respect the balance of + # classes + # NOTE: Passing the data corresponding to ith class say X[y==class_i] + # will break when the data is not 100% stratifiable for all classes. + # So we pass np.zeroes(max(c, n_folds)) as data to the KFold + per_cls_cvs = [ + KFold(self.n_folds, shuffle=self.shuffle, + random_state=rng).split(np.zeros(max(count, self.n_folds))) + for count in y_counts] + + test_folds = np.zeros(n_samples, dtype=np.int) + for test_fold_indices, per_cls_splits in enumerate(zip(*per_cls_cvs)): + for cls, (_, test_split) in zip(unique_y, per_cls_splits): + cls_test_folds = test_folds[y == cls] + # the test split can be too big because we used + # KFold(...).split(X[:max(c, n_folds)]) when data is not 100% + # stratifiable for all the classes + # (we use a warning instead of raising an exception) + # If this is the case, let's trim it: + test_split = test_split[test_split < len(cls_test_folds)] + cls_test_folds[test_split] = test_fold_indices + test_folds[y == cls] = cls_test_folds + + return test_folds + + def _iter_test_masks(self, X, y=None, labels=None): + test_folds = self._make_test_folds(X, y) + for i in range(self.n_folds): + yield test_folds == i + + +class LeaveOneLabelOut(BaseCrossValidator): + """Leave One Label Out cross-validator + + Provides train/test indices to split data according to a third-party + provided label. This label information can be used to encode arbitrary + domain specific stratifications of the samples as integers. + + For instance the labels could be the year of collection of the samples + and thus allow for cross-validation against time-based splits. + + Read more in the :ref:`User Guide `. + + Examples + -------- + >>> from sklearn.model_selection import LeaveOneLabelOut + >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + >>> y = np.array([1, 2, 1, 2]) + >>> labels = np.array([1, 1, 2, 2]) + >>> lol = LeaveOneLabelOut() + >>> lol.get_n_splits(X, y, labels) + 2 + >>> print(lol) + LeaveOneLabelOut() + >>> for train_index, test_index in lol.split(X, y, labels): + ... print("TRAIN:", train_index, "TEST:", test_index) + ... X_train, X_test = X[train_index], X[test_index] + ... y_train, y_test = y[train_index], y[test_index] + ... 
print(X_train, X_test, y_train, y_test) + TRAIN: [2 3] TEST: [0 1] + [[5 6] + [7 8]] [[1 2] + [3 4]] [1 2] [1 2] + TRAIN: [0 1] TEST: [2 3] + [[1 2] + [3 4]] [[5 6] + [7 8]] [1 2] [1 2] + + """ + + def _iter_test_masks(self, X, y, labels): + if labels is None: + raise ValueError("The labels parameter should not be None") + # We make a copy of labels to avoid side-effects during iteration + labels = np.array(labels, copy=True) + unique_labels = np.unique(labels) + for i in unique_labels: + yield labels == i + + def get_n_splits(self, X, y, labels): + """Returns the number of splitting iterations in the cross-validator + + Parameters + ---------- + X : object + Always ignored, exists for compatibility. + + y : object + Always ignored, exists for compatibility. + + labels : array-like, with shape (n_samples,), optional + Group labels for the samples used while splitting the dataset into + train/test set. + + Returns + ------- + n_splits : int + Returns the number of splitting iterations in the cross-validator. + """ + if labels is None: + raise ValueError("The labels parameter should not be None") + return len(np.unique(labels)) + + +class LeavePLabelOut(BaseCrossValidator): + """Leave P Labels Out cross-validator + + Provides train/test indices to split data according to a third-party + provided label. This label information can be used to encode arbitrary + domain specific stratifications of the samples as integers. + + For instance the labels could be the year of collection of the samples + and thus allow for cross-validation against time-based splits. + + The difference between LeavePLabelOut and LeaveOneLabelOut is that + the former builds the test sets with all the samples assigned to + ``p`` different values of the labels while the latter uses samples + all assigned the same labels. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + n_labels : int + Number of labels (``p``) to leave out in the test split. + + Examples + -------- + >>> from sklearn.model_selection import LeavePLabelOut + >>> X = np.array([[1, 2], [3, 4], [5, 6]]) + >>> y = np.array([1, 2, 1]) + >>> labels = np.array([1, 2, 3]) + >>> lpl = LeavePLabelOut(n_labels=2) + >>> lpl.get_n_splits(X, y, labels) + 3 + >>> print(lpl) + LeavePLabelOut(n_labels=2) + >>> for train_index, test_index in lpl.split(X, y, labels): + ... print("TRAIN:", train_index, "TEST:", test_index) + ... X_train, X_test = X[train_index], X[test_index] + ... y_train, y_test = y[train_index], y[test_index] + ... print(X_train, X_test, y_train, y_test) + TRAIN: [2] TEST: [0 1] + [[5 6]] [[1 2] + [3 4]] [1] [1 2] + TRAIN: [1] TEST: [0 2] + [[3 4]] [[1 2] + [5 6]] [2] [1 1] + TRAIN: [0] TEST: [1 2] + [[1 2]] [[3 4] + [5 6]] [1] [2 1] + + See also + -------- + LabelKFold: K-fold iterator variant with non-overlapping labels. + """ + + def __init__(self, n_labels): + self.n_labels = n_labels + + def _iter_test_masks(self, X, y, labels): + if labels is None: + raise ValueError("The labels parameter should not be None") + labels = np.array(labels, copy=True) + unique_labels = np.unique(labels) + combi = combinations(range(len(unique_labels)), self.n_labels) + for indices in combi: + test_index = np.zeros(_num_samples(X), dtype=np.bool) + for l in unique_labels[np.array(indices)]: + test_index[labels == l] = True + yield test_index + + def get_n_splits(self, X, y, labels): + """Returns the number of splitting iterations in the cross-validator + + Parameters + ---------- + X : object + Always ignored, exists for compatibility. 
+ + y : object + Always ignored, exists for compatibility. + + labels : array-like, with shape (n_samples,), optional + Group labels for the samples used while splitting the dataset into + train/test set. + + Returns + ------- + n_splits : int + Returns the number of splitting iterations in the cross-validator. + """ + if labels is None: + raise ValueError("The labels parameter should not be None") + return int(comb(len(np.unique(labels)), self.n_labels, exact=True)) + + +class BaseShuffleSplit(with_metaclass(ABCMeta)): + """Base class for ShuffleSplit and StratifiedShuffleSplit""" + + def __init__(self, n_iter=10, test_size=0.1, train_size=None, + random_state=None): + _validate_shuffle_split_init(test_size, train_size) + self.n_iter = n_iter + self.test_size = test_size + self.train_size = train_size + self.random_state = random_state + + def split(self, X, y=None, labels=None): + """Generate indices to split data into training and test set. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + Training data, where n_samples is the number of samples + and n_features is the number of features. + + y : array-like, shape (n_samples,) + The target variable for supervised learning problems. + + labels : array-like, with shape (n_samples,), optional + Group labels for the samples used while splitting the dataset into + train/test set. + + Returns + ------- + train : ndarray + The training set indices for that split. + + test : ndarray + The testing set indices for that split. + """ + X, y, labels = indexable(X, y, labels) + for train, test in self._iter_indices(X, y, labels): + yield train, test + + @abstractmethod + def _iter_indices(self, X, y=None, labels=None): + """Generate (train, test) indices""" + + def get_n_splits(self, X=None, y=None, labels=None): + """Returns the number of splitting iterations in the cross-validator + + Parameters + ---------- + X : object + Always ignored, exists for compatibility. + + y : object + Always ignored, exists for compatibility. + + labels : object + Always ignored, exists for compatibility. + + Returns + ------- + n_splits : int + Returns the number of splitting iterations in the cross-validator. + """ + return self.n_iter + + def __repr__(self): + return _build_repr(self) + + +class ShuffleSplit(BaseShuffleSplit): + """Random permutation cross-validator + + Yields indices to split data into training and test sets. + + Note: contrary to other cross-validation strategies, random splits + do not guarantee that all folds will be different, although this is + still very likely for sizeable datasets. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + n_iter : int (default 10) + Number of re-shuffling & splitting iterations. + + test_size : float, int, or None, default 0.1 + If float, should be between 0.0 and 1.0 and represent the + proportion of the dataset to include in the test split. If + int, represents the absolute number of test samples. If None, + the value is automatically set to the complement of the train size. + + train_size : float, int, or None (default is None) + If float, should be between 0.0 and 1.0 and represent the + proportion of the dataset to include in the train split. If + int, represents the absolute number of train samples. If None, + the value is automatically set to the complement of the test size. + + random_state : int or RandomState + Pseudo-random number generator state used for random sampling. 
+ + Examples + -------- + >>> from sklearn.model_selection import ShuffleSplit + >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + >>> y = np.array([1, 2, 1, 2]) + >>> rs = ShuffleSplit(n_iter=3, test_size=.25, random_state=0) + >>> rs.get_n_splits(X) + 3 + >>> print(rs) + ShuffleSplit(n_iter=3, random_state=0, test_size=0.25, train_size=None) + >>> for train_index, test_index in rs.split(X): + ... print("TRAIN:", train_index, "TEST:", test_index) + ... # doctest: +ELLIPSIS + TRAIN: [3 1 0] TEST: [2] + TRAIN: [2 1 3] TEST: [0] + TRAIN: [0 2 1] TEST: [3] + >>> rs = ShuffleSplit(n_iter=3, train_size=0.5, test_size=.25, + ... random_state=0) + >>> for train_index, test_index in rs.split(X): + ... print("TRAIN:", train_index, "TEST:", test_index) + ... # doctest: +ELLIPSIS + TRAIN: [3 1] TEST: [2] + TRAIN: [2 1] TEST: [0] + TRAIN: [0 2] TEST: [3] + """ + + def _iter_indices(self, X, y=None, labels=None): + n_samples = _num_samples(X) + n_train, n_test = _validate_shuffle_split(n_samples, self.test_size, + self.train_size) + rng = check_random_state(self.random_state) + for i in range(self.n_iter): + # random partition + permutation = rng.permutation(n_samples) + ind_test = permutation[:n_test] + ind_train = permutation[n_test:(n_test + n_train)] + yield ind_train, ind_test + + +class LabelShuffleSplit(ShuffleSplit): + '''Shuffle-Labels-Out cross-validation iterator + + Provides randomized train/test indices to split data according to a + third-party provided label. This label information can be used to encode + arbitrary domain specific stratifications of the samples as integers. + + For instance the labels could be the year of collection of the samples + and thus allow for cross-validation against time-based splits. + + The difference between LeavePLabelOut and LabelShuffleSplit is that + the former generates splits using all subsets of size ``p`` unique labels, + whereas LabelShuffleSplit generates a user-determined number of random + test splits, each with a user-determined fraction of unique labels. + + For example, a less computationally intensive alternative to + ``LeavePLabelOut(p=10)`` would be + ``LabelShuffleSplit(test_size=10, n_iter=100)``. + + Note: The parameters ``test_size`` and ``train_size`` refer to labels, and + not to samples, as in ShuffleSplit. + + + Parameters + ---------- + n_iter : int (default 5) + Number of re-shuffling & splitting iterations. + + test_size : float (default 0.2), int, or None + If float, should be between 0.0 and 1.0 and represent the + proportion of the labels to include in the test split. If + int, represents the absolute number of test labels. If None, + the value is automatically set to the complement of the train size. + + train_size : float, int, or None (default is None) + If float, should be between 0.0 and 1.0 and represent the + proportion of the labels to include in the train split. If + int, represents the absolute number of train labels. If None, + the value is automatically set to the complement of the test size. + + random_state : int or RandomState + Pseudo-random number generator state used for random sampling. 
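``LabelShuffleSplit`` carries no doctest of its own above, so a hedged usage sketch follows; it assumes the class is re-exported from ``sklearn.model_selection`` as listed in ``__all__`` of ``_split.py``::

    # Illustrative only: whole labels, not individual samples, are held out.
    import numpy as np
    from sklearn.model_selection import LabelShuffleSplit

    X = np.arange(16).reshape(8, 2)
    y = np.array([0, 0, 1, 1, 0, 0, 1, 1])
    labels = np.array([1, 1, 2, 2, 3, 3, 4, 4])

    lss = LabelShuffleSplit(n_iter=3, test_size=0.5, random_state=0)
    for train_index, test_index in lss.split(X, y, labels):
        # No label value appears on both sides of a split.
        print("TRAIN labels:", labels[train_index],
              "TEST labels:", labels[test_index])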
+ ''' + + def __init__(self, n_iter=5, test_size=0.2, train_size=None, + random_state=None): + super(LabelShuffleSplit, self).__init__( + n_iter=n_iter, + test_size=test_size, + train_size=train_size, + random_state=random_state) + + def _iter_indices(self, X, y, labels): + if labels is None: + raise ValueError("The labels parameter should not be None") + classes, label_indices = np.unique(labels, return_inverse=True) + for label_train, label_test in super( + LabelShuffleSplit, self)._iter_indices(X=classes): + # these are the indices of classes in the partition + # invert them into data indices + + train = np.flatnonzero(np.in1d(label_indices, label_train)) + test = np.flatnonzero(np.in1d(label_indices, label_test)) + + yield train, test + + +class StratifiedShuffleSplit(BaseShuffleSplit): + """Stratified ShuffleSplit cross-validator + + Provides train/test indices to split data in train/test sets. + + This cross-validation object is a merge of StratifiedKFold and + ShuffleSplit, which returns stratified randomized folds. The folds + are made by preserving the percentage of samples for each class. + + Note: like the ShuffleSplit strategy, stratified random splits + do not guarantee that all folds will be different, although this is + still very likely for sizeable datasets. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + n_iter : int (default 10) + Number of re-shuffling & splitting iterations. + + test_size : float (default 0.1), int, or None + If float, should be between 0.0 and 1.0 and represent the + proportion of the dataset to include in the test split. If + int, represents the absolute number of test samples. If None, + the value is automatically set to the complement of the train size. + + train_size : float, int, or None (default is None) + If float, should be between 0.0 and 1.0 and represent the + proportion of the dataset to include in the train split. If + int, represents the absolute number of train samples. If None, + the value is automatically set to the complement of the test size. + + random_state : int or RandomState + Pseudo-random number generator state used for random sampling. + + Examples + -------- + >>> from sklearn.model_selection import StratifiedShuffleSplit + >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]]) + >>> y = np.array([0, 0, 1, 1]) + >>> sss = StratifiedShuffleSplit(n_iter=3, test_size=0.5, random_state=0) + >>> sss.get_n_splits(X, y) + 3 + >>> print(sss) # doctest: +ELLIPSIS + StratifiedShuffleSplit(n_iter=3, random_state=0, ...) + >>> for train_index, test_index in sss.split(X, y): + ... print("TRAIN:", train_index, "TEST:", test_index) + ... X_train, X_test = X[train_index], X[test_index] + ... y_train, y_test = y[train_index], y[test_index] + TRAIN: [1 2] TEST: [3 0] + TRAIN: [0 2] TEST: [1 3] + TRAIN: [0 2] TEST: [3 1] + """ + + def __init__(self, n_iter=10, test_size=0.1, train_size=None, + random_state=None): + super(StratifiedShuffleSplit, self).__init__( + n_iter, test_size, train_size, random_state) + + def _iter_indices(self, X, y, labels=None): + n_samples = _num_samples(X) + n_train, n_test = _validate_shuffle_split(n_samples, self.test_size, + self.train_size) + classes, y_indices = np.unique(y, return_inverse=True) + n_classes = classes.shape[0] + + class_counts = bincount(y_indices) + if np.min(class_counts) < 2: + raise ValueError("The least populated class in y has only 1" + " member, which is too few. 
The minimum" + " number of labels for any class cannot" + " be less than 2.") + + if n_train < n_classes: + raise ValueError('The train_size = %d should be greater or ' + 'equal to the number of classes = %d' % + (n_train, n_classes)) + if n_test < n_classes: + raise ValueError('The test_size = %d should be greater or ' + 'equal to the number of classes = %d' % + (n_test, n_classes)) + + rng = check_random_state(self.random_state) + p_i = class_counts / float(n_samples) + n_i = np.round(n_train * p_i).astype(int) + t_i = np.minimum(class_counts - n_i, + np.round(n_test * p_i).astype(int)) + + for _ in range(self.n_iter): + train = [] + test = [] + + for i, class_i in enumerate(classes): + permutation = rng.permutation(class_counts[i]) + perm_indices_class_i = np.where((y == class_i))[0][permutation] + + train.extend(perm_indices_class_i[:n_i[i]]) + test.extend(perm_indices_class_i[n_i[i]:n_i[i] + t_i[i]]) + + # Because of rounding issues (as n_train and n_test are not + # dividers of the number of elements per class), we may end + # up here with less samples in train and test than asked for. + if len(train) < n_train or len(test) < n_test: + # We complete by affecting randomly the missing indexes + missing_indices = np.where(bincount(train + test, + minlength=len(y)) == 0)[0] + missing_indices = rng.permutation(missing_indices) + train.extend(missing_indices[:(n_train - len(train))]) + test.extend(missing_indices[-(n_test - len(test)):]) + + train = rng.permutation(train) + test = rng.permutation(test) + + yield train, test + + +def _validate_shuffle_split_init(test_size, train_size): + """Validation helper to check the test_size and train_size at init + + NOTE This does not take into account the number of samples which is known + only at split + """ + if test_size is None and train_size is None: + raise ValueError('test_size and train_size can not both be None') + + if test_size is not None: + if np.asarray(test_size).dtype.kind == 'f': + if test_size >= 1.: + raise ValueError( + 'test_size=%f should be smaller ' + 'than 1.0 or be an integer' % test_size) + elif np.asarray(test_size).dtype.kind != 'i': + # int values are checked during split based on the input + raise ValueError("Invalid value for test_size: %r" % test_size) + + if train_size is not None: + if np.asarray(train_size).dtype.kind == 'f': + if train_size >= 1.: + raise ValueError("train_size=%f should be smaller " + "than 1.0 or be an integer" % train_size) + elif (np.asarray(test_size).dtype.kind == 'f' and + (train_size + test_size) > 1.): + raise ValueError('The sum of test_size and train_size = %f, ' + 'should be smaller than 1.0. Reduce ' + 'test_size and/or train_size.' 
% + (train_size + test_size)) + elif np.asarray(train_size).dtype.kind != 'i': + # int values are checked during split based on the input + raise ValueError("Invalid value for train_size: %r" % train_size) + + +def _validate_shuffle_split(n_samples, test_size, train_size): + """ + Validation helper to check if the test/test sizes are meaningful wrt to the + size of the data (n_samples) + """ + if (test_size is not None and np.asarray(test_size).dtype.kind == 'i' + and test_size >= n_samples): + raise ValueError('test_size=%d should be smaller than the number of ' + 'samples %d' % (test_size, n_samples)) + + if (train_size is not None and np.asarray(train_size).dtype.kind == 'i' + and train_size >= n_samples): + raise ValueError("train_size=%d should be smaller than the number of" + " samples %d" % (train_size, n_samples)) + + if np.asarray(test_size).dtype.kind == 'f': + n_test = ceil(test_size * n_samples) + elif np.asarray(test_size).dtype.kind == 'i': + n_test = float(test_size) + + if train_size is None: + n_train = n_samples - n_test + elif np.asarray(train_size).dtype.kind == 'f': + n_train = floor(train_size * n_samples) + else: + n_train = float(train_size) + + if test_size is None: + n_test = n_samples - n_train + + if n_train + n_test > n_samples: + raise ValueError('The sum of train_size and test_size = %d, ' + 'should be smaller than the number of ' + 'samples %d. Reduce test_size and/or ' + 'train_size.' % (n_train + n_test, n_samples)) + + return int(n_train), int(n_test) + + +class PredefinedSplit(BaseCrossValidator): + """Predefined split cross-validator + + Splits the data into training/test set folds according to a predefined + scheme. Each sample can be assigned to at most one test set fold, as + specified by the user through the ``test_fold`` parameter. + + Read more in the :ref:`User Guide `. + + Examples + -------- + >>> from sklearn.model_selection import PredefinedSplit + >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]]) + >>> y = np.array([0, 0, 1, 1]) + >>> test_fold = [0, 1, -1, 1] + >>> ps = PredefinedSplit(test_fold) + >>> ps.get_n_splits() + 2 + >>> print(ps) # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS + PredefinedSplit(test_fold=array([ 0, 1, -1, 1])) + >>> for train_index, test_index in ps.split(): + ... print("TRAIN:", train_index, "TEST:", test_index) + ... X_train, X_test = X[train_index], X[test_index] + ... y_train, y_test = y[train_index], y[test_index] + TRAIN: [1 2 3] TEST: [0] + TRAIN: [0 2] TEST: [1 3] + """ + + def __init__(self, test_fold): + self.test_fold = np.array(test_fold, dtype=np.int) + self.test_fold = column_or_1d(self.test_fold) + self.unique_folds = np.unique(self.test_fold) + self.unique_folds = self.unique_folds[self.unique_folds != -1] + + def split(self, X=None, y=None, labels=None): + """Generate indices to split data into training and test set. + + Parameters + ---------- + X : object + Always ignored, exists for compatibility. + + y : object + Always ignored, exists for compatibility. + + labels : object + Always ignored, exists for compatibility. + + Returns + ------- + train : ndarray + The training set indices for that split. + + test : ndarray + The testing set indices for that split. 
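For concreteness, a worked example of the size computation performed by the ``_validate_shuffle_split`` helper earlier in this hunk; illustrative only, and it calls a private helper from this diff directly::

    from sklearn.model_selection._split import _validate_shuffle_split

    # float test_size: n_test = ceil(0.25 * 10) = 3, train is the complement.
    print(_validate_shuffle_split(10, test_size=0.25, train_size=None))  # (7, 3)

    # int test_size with float train_size: n_train = floor(0.5 * 10) = 5.
    print(_validate_shuffle_split(10, test_size=3, train_size=0.5))      # (5, 3)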
+ """ + ind = np.arange(len(self.test_fold)) + for test_index in self._iter_test_masks(): + train_index = ind[np.logical_not(test_index)] + test_index = ind[test_index] + yield train_index, test_index + + def _iter_test_masks(self): + """Generates boolean masks corresponding to test sets.""" + for f in self.unique_folds: + test_index = np.where(self.test_fold == f)[0] + test_mask = np.zeros(len(self.test_fold), dtype=np.bool) + test_mask[test_index] = True + yield test_mask + + def get_n_splits(self, X=None, y=None, labels=None): + """Returns the number of splitting iterations in the cross-validator + + Parameters + ---------- + X : object + Always ignored, exists for compatibility. + + y : object + Always ignored, exists for compatibility. + + labels : object + Always ignored, exists for compatibility. + + Returns + ------- + n_splits : int + Returns the number of splitting iterations in the cross-validator. + """ + return len(self.unique_folds) + + +class _CVIterableWrapper(BaseCrossValidator): + """Wrapper class for old style cv objects and iterables.""" + def __init__(self, cv): + self.cv = cv + + def get_n_splits(self, X=None, y=None, labels=None): + """Returns the number of splitting iterations in the cross-validator + + Parameters + ---------- + X : object + Always ignored, exists for compatibility. + + y : object + Always ignored, exists for compatibility. + + labels : object + Always ignored, exists for compatibility. + + Returns + ------- + n_splits : int + Returns the number of splitting iterations in the cross-validator. + """ + return len(self.cv) # Both iterables and old-cv objects support len + + def split(self, X=None, y=None, labels=None): + """Generate indices to split data into training and test set. + + Parameters + ---------- + X : object + Always ignored, exists for compatibility. + + y : object + Always ignored, exists for compatibility. + + labels : object + Always ignored, exists for compatibility. + + Returns + ------- + train : ndarray + The training set indices for that split. + + test : ndarray + The testing set indices for that split. + """ + for train, test in self.cv: + yield train, test + + +def check_cv(cv=3, y=None, classifier=False): + """Input checker utility for building a cross-validator + + Parameters + ---------- + cv : int, cross-validation generator or an iterable, optional + Determines the cross-validation splitting strategy. + Possible inputs for cv are: + - None, to use the default 3-fold cross-validation, + - integer, to specify the number of folds. + - An object to be used as a cross-validation generator. + - An iterable yielding train/test splits. + + For integer/None inputs, if ``y`` is binary or multiclass, + :class:`StratifiedKFold` used. If classifier is False or if ``y`` is + neither binary nor multiclass, :class:`KFold` is used. + + Refer :ref:`User Guide ` for the various + cross-validation strategies that can be used here. + + y : array-like, optional + The target variable for supervised learning problems. + + classifier : boolean, optional, default False + Whether the task is a classification task, in which case + stratified KFold will be used. + + Returns + ------- + checked_cv : a cross-validator instance. + The return value is a cross-validator which generates the train/test + splits via the ``split`` method. 
+ """ + if cv is None: + cv = 3 + + if isinstance(cv, numbers.Integral): + if (classifier and (y is not None) and + (type_of_target(y) in ('binary', 'multiclass'))): + return StratifiedKFold(cv) + else: + return KFold(cv) + + if not hasattr(cv, 'split') or isinstance(cv, str): + if not isinstance(cv, Iterable) or isinstance(cv, str): + raise ValueError("Expected cv as an integer, cross-validation " + "object (from sklearn.model_selection) " + "or and iterable. Got %s." % cv) + return _CVIterableWrapper(cv) + + return cv # New style cv objects are passed without any modification + + +def train_test_split(*arrays, **options): + """Split arrays or matrices into random train and test subsets + + Quick utility that wraps input validation and + ``next(ShuffleSplit().split(X, y))`` and application to input data + into a single call for splitting (and optionally subsampling) data in a + oneliner. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + *arrays : sequence of arrays or scipy.sparse matrices with same shape[0] + Python lists or tuples occurring in arrays are converted to 1D numpy + arrays. + + test_size : float, int, or None (default is None) + If float, should be between 0.0 and 1.0 and represent the + proportion of the dataset to include in the test split. If + int, represents the absolute number of test samples. If None, + the value is automatically set to the complement of the train size. + If train size is also None, test size is set to 0.25. + + train_size : float, int, or None (default is None) + If float, should be between 0.0 and 1.0 and represent the + proportion of the dataset to include in the train split. If + int, represents the absolute number of train samples. If None, + the value is automatically set to the complement of the test size. + + random_state : int or RandomState + Pseudo-random number generator state used for random sampling. + + stratify : array-like or None (default is None) + If not None, data is split in a stratified fashion, using this as + the labels array. + + Returns + ------- + splitting : list of arrays, length=2 * len(arrays) + List containing train-test split of input array. + + Examples + -------- + >>> import numpy as np + >>> from sklearn.model_selection import train_test_split + >>> X, y = np.arange(10).reshape((5, 2)), range(5) + >>> X + array([[0, 1], + [2, 3], + [4, 5], + [6, 7], + [8, 9]]) + >>> list(y) + [0, 1, 2, 3, 4] + + >>> X_train, X_test, y_train, y_test = train_test_split( + ... X, y, test_size=0.33, random_state=42) + ... 
+ >>> X_train + array([[4, 5], + [0, 1], + [6, 7]]) + >>> y_train + [2, 0, 3] + >>> X_test + array([[2, 3], + [8, 9]]) + >>> y_test + [1, 4] + + """ + n_arrays = len(arrays) + if n_arrays == 0: + raise ValueError("At least one array required as input") + test_size = options.pop('test_size', None) + train_size = options.pop('train_size', None) + random_state = options.pop('random_state', None) + stratify = options.pop('stratify', None) + + if options: + raise TypeError("Invalid parameters passed: %s" % str(options)) + + if test_size is None and train_size is None: + test_size = 0.25 + + arrays = indexable(*arrays) + + if stratify is not None: + CVClass = StratifiedShuffleSplit + else: + CVClass = ShuffleSplit + + cv = CVClass(test_size=test_size, + train_size=train_size, + random_state=random_state) + + train, test = next(cv.split(X=arrays[0], y=stratify)) + return list(chain.from_iterable((safe_indexing(a, train), + safe_indexing(a, test)) for a in arrays)) + + +train_test_split.__test__ = False # to avoid a pb with nosetests + + +def _safe_split(estimator, X, y, indices, train_indices=None): + """Create subset of dataset and properly handle kernels.""" + if (hasattr(estimator, 'kernel') and callable(estimator.kernel) and + not isinstance(estimator.kernel, GPKernel)): + # cannot compute the kernel values with custom function + raise ValueError("Cannot use a custom kernel function. " + "Precompute the kernel matrix instead.") + + if not hasattr(X, "shape"): + if getattr(estimator, "_pairwise", False): + raise ValueError("Precomputed kernels or affinity matrices have " + "to be passed as arrays or sparse matrices.") + X_subset = [X[index] for index in indices] + else: + if getattr(estimator, "_pairwise", False): + # X is a precomputed square kernel matrix + if X.shape[0] != X.shape[1]: + raise ValueError("X should be a square kernel matrix") + if train_indices is None: + X_subset = X[np.ix_(indices, indices)] + else: + X_subset = X[np.ix_(indices, train_indices)] + else: + X_subset = safe_indexing(X, indices) + + if y is not None: + y_subset = safe_indexing(y, indices) + else: + y_subset = None + + return X_subset, y_subset + + +def _build_repr(self): + # XXX This is copied from BaseEstimator's get_params + cls = self.__class__ + init = getattr(cls.__init__, 'deprecated_original', cls.__init__) + # Ignore varargs, kw and default values and pop self + if init is object.__init__: + # No explicit constructor to introspect + args = [] + else: + args = sorted(inspect.getargspec(init)[0]) + if 'self' in args: + args.remove('self') + class_name = self.__class__.__name__ + params = dict() + for key in args: + # We need deprecation warnings to always be on in order to + # catch deprecated param values. + # This is set in utils/__init__.py but it gets overwritten + # when running under python3 somehow. 
+ warnings.simplefilter("always", DeprecationWarning) + try: + with warnings.catch_warnings(record=True) as w: + value = getattr(self, key, None) + if len(w) and w[0].category == DeprecationWarning: + # if the parameter is deprecated, don't show it + continue + finally: + warnings.filters.pop(0) + params[key] = value + + return '%s(%s)' % (class_name, _pprint(params, offset=len(class_name))) diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py new file mode 100644 index 0000000000000..05673b954e8a6 --- /dev/null +++ b/sklearn/model_selection/_validation.py @@ -0,0 +1,928 @@ +""" +The :mod:`sklearn.model_selection._validation` module includes classes and +functions to validate the model. +""" + +# Author: Alexandre Gramfort , +# Gael Varoquaux , +# Olivier Grisel +# License: BSD 3 clause + + +from __future__ import print_function +from __future__ import division + +import warnings +import numbers +import time + +import numpy as np +import scipy.sparse as sp + +from ..base import is_classifier, clone +from ..utils import indexable, check_random_state, safe_indexing +from ..utils.fixes import astype +from ..utils.validation import _is_arraylike, _num_samples +from ..externals.joblib import Parallel, delayed, logger +from ..metrics.scorer import check_scoring +from ..exceptions import FitFailedWarning + +from ._split import KFold +from ._split import LabelKFold +from ._split import LeaveOneLabelOut +from ._split import LeaveOneOut +from ._split import LeavePLabelOut +from ._split import LeavePOut +from ._split import ShuffleSplit +from ._split import LabelShuffleSplit +from ._split import StratifiedKFold +from ._split import StratifiedShuffleSplit +from ._split import PredefinedSplit +from ._split import check_cv, _safe_split + +__all__ = ['cross_val_score', 'cross_val_predict', 'permutation_test_score', + 'learning_curve', 'validation_curve'] + +ALL_CVS = {'KFold': KFold, + 'LabelKFold': LabelKFold, + 'LeaveOneLabelOut': LeaveOneLabelOut, + 'LeaveOneOut': LeaveOneOut, + 'LeavePLabelOut': LeavePLabelOut, + 'LeavePOut': LeavePOut, + 'ShuffleSplit': ShuffleSplit, + 'LabelShuffleSplit': LabelShuffleSplit, + 'StratifiedKFold': StratifiedKFold, + 'StratifiedShuffleSplit': StratifiedShuffleSplit, + 'PredefinedSplit': PredefinedSplit} + +LABEL_CVS = {'LabelKFold': LabelKFold, + 'LeaveOneLabelOut': LeaveOneLabelOut, + 'LeavePLabelOut': LeavePLabelOut, + 'LabelShuffleSplit': LabelShuffleSplit} + + +def cross_val_score(estimator, X, y=None, labels=None, scoring=None, cv=None, + n_jobs=1, verbose=0, fit_params=None, + pre_dispatch='2*n_jobs'): + """Evaluate a score by cross-validation + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + estimator : estimator object implementing 'fit' + The object to use to fit the data. + + X : array-like + The data to fit. Can be, for example a list, or an array at least 2d. + + y : array-like, optional, default: None + The target variable to try to predict in the case of + supervised learning. + + labels : array-like, with shape (n_samples,), optional + Group labels for the samples used while splitting the dataset into + train/test set. + + scoring : string, callable or None, optional, default: None + A string (see model evaluation documentation) or + a scorer callable object / function with signature + ``scorer(estimator, X, y)``. + + cv : int, cross-validation generator or an iterable, optional + Determines the cross-validation splitting strategy. 
+ Possible inputs for cv are: + - None, to use the default 3-fold cross validation, + - integer, to specify the number of folds in a `(Stratified)KFold`, + - An object to be used as a cross-validation generator. + - An iterable yielding train, test splits. + + For integer/None inputs, ``StratifiedKFold`` is used for classification + tasks, when ``y`` is binary or multiclass. + + See the :mod:`sklearn.model_selection` module for the list of + cross-validation strategies that can be used here. + + Also refer :ref:`cross-validation documentation ` + + n_jobs : integer, optional + The number of CPUs to use to do the computation. -1 means + 'all CPUs'. + + verbose : integer, optional + The verbosity level. + + fit_params : dict, optional + Parameters to pass to the fit method of the estimator. + + pre_dispatch : int, or string, optional + Controls the number of jobs that get dispatched during parallel + execution. Reducing this number can be useful to avoid an + explosion of memory consumption when more jobs get dispatched + than CPUs can process. This parameter can be: + + - None, in which case all the jobs are immediately + created and spawned. Use this for lightweight and + fast-running jobs, to avoid delays due to on-demand + spawning of the jobs + + - An int, giving the exact number of total jobs that are + spawned + + - A string, giving an expression as a function of n_jobs, + as in '2*n_jobs' + + Returns + ------- + scores : array of float, shape=(len(list(cv)),) + Array of scores of the estimator for each run of the cross validation. + """ + X, y, labels = indexable(X, y, labels) + + cv = check_cv(cv, y, classifier=is_classifier(estimator)) + scorer = check_scoring(estimator, scoring=scoring) + # We clone the estimator to make sure that all the folds are + # independent, and that it is pickle-able. + parallel = Parallel(n_jobs=n_jobs, verbose=verbose, + pre_dispatch=pre_dispatch) + scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer, + train, test, verbose, None, + fit_params) + for train, test in cv.split(X, y, labels)) + return np.array(scores)[:, 0] + + +def _fit_and_score(estimator, X, y, scorer, train, test, verbose, + parameters, fit_params, return_train_score=False, + return_parameters=False, error_score='raise'): + """Fit estimator and compute scores for a given dataset split. + + Parameters + ---------- + estimator : estimator object implementing 'fit' + The object to use to fit the data. + + X : array-like of shape at least 2D + The data to fit. + + y : array-like, optional, default: None + The target variable to try to predict in the case of + supervised learning. + + scorer : callable + A scorer callable object / function with signature + ``scorer(estimator, X, y)``. + + train : array-like, shape (n_train_samples,) + Indices of training samples. + + test : array-like, shape (n_test_samples,) + Indices of test samples. + + verbose : integer + The verbosity level. + + error_score : 'raise' (default) or numeric + Value to assign to the score if an error occurs in estimator fitting. + If set to 'raise', the error is raised. If a numeric value is given, + FitFailedWarning is raised. This parameter does not affect the refit + step, which will always raise the error. + + parameters : dict or None + Parameters to be set on the estimator. + + fit_params : dict or None + Parameters that will be passed to ``estimator.fit``. + + return_train_score : boolean, optional, default: False + Compute and return score on training set. 
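The splitter objects accepted by ``cross_val_score`` above make nested cross-validation straightforward: pass one splitter to the search estimator and another to the outer scoring loop. A hedged sketch using the ``n_folds`` constructor argument as defined in this diff::

    # Illustrative nested cross-validation sketch.
    from sklearn.datasets import load_iris
    from sklearn.svm import SVC
    from sklearn.model_selection import GridSearchCV, KFold, cross_val_score

    iris = load_iris()
    inner_cv = KFold(n_folds=3, shuffle=True, random_state=0)   # tunes C
    outer_cv = KFold(n_folds=3, shuffle=True, random_state=1)   # estimates error

    clf = GridSearchCV(SVC(), param_grid={'C': [1, 10, 100]}, cv=inner_cv)
    nested_scores = cross_val_score(clf, iris.data, iris.target, cv=outer_cv)
    print(nested_scores.mean())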
+ + return_parameters : boolean, optional, default: False + Return parameters that has been used for the estimator. + + Returns + ------- + train_score : float, optional + Score on training set, returned only if `return_train_score` is `True`. + + test_score : float + Score on test set. + + n_test_samples : int + Number of test samples. + + scoring_time : float + Time spent for fitting and scoring in seconds. + + parameters : dict or None, optional + The parameters that have been evaluated. + """ + if verbose > 1: + if parameters is None: + msg = "no parameters to be set" + else: + msg = '%s' % (', '.join('%s=%s' % (k, v) + for k, v in parameters.items())) + print("[CV] %s %s" % (msg, (64 - len(msg)) * '.')) + + # Adjust length of sample weights + fit_params = fit_params if fit_params is not None else {} + fit_params = dict([(k, _index_param_value(X, v, train)) + for k, v in fit_params.items()]) + + if parameters is not None: + estimator.set_params(**parameters) + + start_time = time.time() + + X_train, y_train = _safe_split(estimator, X, y, train) + X_test, y_test = _safe_split(estimator, X, y, test, train) + + try: + if y_train is None: + estimator.fit(X_train, **fit_params) + else: + estimator.fit(X_train, y_train, **fit_params) + + except Exception as e: + if error_score == 'raise': + raise + elif isinstance(error_score, numbers.Number): + test_score = error_score + if return_train_score: + train_score = error_score + warnings.warn("Classifier fit failed. The score on this train-test" + " partition for these parameters will be set to %f. " + "Details: \n%r" % (error_score, e), FitFailedWarning) + else: + raise ValueError("error_score must be the string 'raise' or a" + " numeric value. (Hint: if using 'raise', please" + " make sure that it has been spelled correctly.)") + + else: + test_score = _score(estimator, X_test, y_test, scorer) + if return_train_score: + train_score = _score(estimator, X_train, y_train, scorer) + + scoring_time = time.time() - start_time + + if verbose > 2: + msg += ", score=%f" % test_score + if verbose > 1: + end_msg = "%s -%s" % (msg, logger.short_format_time(scoring_time)) + print("[CV] %s %s" % ((64 - len(end_msg)) * '.', end_msg)) + + ret = [train_score] if return_train_score else [] + ret.extend([test_score, _num_samples(X_test), scoring_time]) + if return_parameters: + ret.append(parameters) + return ret + + +def _score(estimator, X_test, y_test, scorer): + """Compute the score of an estimator on a given test set.""" + if y_test is None: + score = scorer(estimator, X_test) + else: + score = scorer(estimator, X_test, y_test) + if not isinstance(score, numbers.Number): + raise ValueError("scoring must return a number, got %s (%s) instead." + % (str(score), type(score))) + return score + + +def cross_val_predict(estimator, X, y=None, labels=None, cv=None, n_jobs=1, + verbose=0, fit_params=None, pre_dispatch='2*n_jobs'): + """Generate cross-validated estimates for each input data point + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + estimator : estimator object implementing 'fit' and 'predict' + The object to use to fit the data. + + X : array-like + The data to fit. Can be, for example a list, or an array at least 2d. + + y : array-like, optional, default: None + The target variable to try to predict in the case of + supervised learning. + + labels : array-like, with shape (n_samples,), optional + Group labels for the samples used while splitting the dataset into + train/test set. 
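The fit-parameter handling above slices any sample-aligned value (for example ``sample_weight``) down to the current training indices via ``_index_param_value``, while scalars and arrays of a different length pass through unchanged. A hedged sketch of the caller-side effect, with an arbitrary toy dataset:

    import numpy as np
    from sklearn.svm import SVC
    from sklearn.model_selection import cross_val_score

    rng = np.random.RandomState(0)
    X = rng.rand(30, 4)
    y = np.arange(30) % 2
    weights = rng.rand(30)        # same length as X, so it is re-indexed per fold
    cross_val_score(SVC(), X, y, fit_params={'sample_weight': weights})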
+ + cv : int, cross-validation generator or an iterable, optional + Determines the cross-validation splitting strategy. + Possible inputs for cv are: + - None, to use the default 3-fold cross validation, + - integer, to specify the number of folds in a `(Stratified)KFold`, + - An object to be used as a cross-validation generator. + - An iterable yielding train, test splits. + + For integer/None inputs, ``StratifiedKFold`` is used for classification + tasks, when ``y`` is binary or multiclass. + + See the :mod:`sklearn.model_selection` module for the list of + cross-validation strategies that can be used here. + + Also refer :ref:`cross-validation documentation <_cross_validation>` + + n_jobs : integer, optional + The number of CPUs to use to do the computation. -1 means + 'all CPUs'. + + verbose : integer, optional + The verbosity level. + + fit_params : dict, optional + Parameters to pass to the fit method of the estimator. + + pre_dispatch : int, or string, optional + Controls the number of jobs that get dispatched during parallel + execution. Reducing this number can be useful to avoid an + explosion of memory consumption when more jobs get dispatched + than CPUs can process. This parameter can be: + + - None, in which case all the jobs are immediately + created and spawned. Use this for lightweight and + fast-running jobs, to avoid delays due to on-demand + spawning of the jobs + + - An int, giving the exact number of total jobs that are + spawned + + - A string, giving an expression as a function of n_jobs, + as in '2*n_jobs' + + Returns + ------- + predictions : ndarray + This is the result of calling 'predict' + """ + X, y, labels = indexable(X, y, labels) + + cv = check_cv(cv, y, classifier=is_classifier(estimator)) + # We clone the estimator to make sure that all the folds are + # independent, and that it is pickle-able. + parallel = Parallel(n_jobs=n_jobs, verbose=verbose, + pre_dispatch=pre_dispatch) + prediction_blocks = parallel(delayed(_fit_and_predict)( + clone(estimator), X, y, train, test, verbose, fit_params) + for train, test in cv.split(X, y, labels)) + + # Concatenate the predictions + predictions = [pred_block_i for pred_block_i, _ in prediction_blocks] + test_indices = np.concatenate([indices_i + for _, indices_i in prediction_blocks]) + + if not _check_is_permutation(test_indices, _num_samples(X)): + raise ValueError('cross_val_predict only works for partitions') + + inv_test_indices = np.empty(len(test_indices), dtype=int) + inv_test_indices[test_indices] = np.arange(len(test_indices)) + + # Check for sparse predictions + if sp.issparse(predictions[0]): + predictions = sp.vstack(predictions, format=predictions[0].format) + else: + predictions = np.concatenate(predictions) + return predictions[inv_test_indices] + + +def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params): + """Fit estimator and predict values for a given dataset split. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + estimator : estimator object implementing 'fit' and 'predict' + The object to use to fit the data. + + X : array-like of shape at least 2D + The data to fit. + + y : array-like, optional, default: None + The target variable to try to predict in the case of + supervised learning. + + train : array-like, shape (n_train_samples,) + Indices of training samples. + + test : array-like, shape (n_test_samples,) + Indices of test samples. + + verbose : integer + The verbosity level. 
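A usage sketch of ``cross_val_predict``: every sample receives exactly one out-of-fold prediction, which is why the implementation above insists that the test splits form a partition of the data. The estimator and dataset are illustrative:

    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_predict

    iris = load_iris()
    pred = cross_val_predict(LogisticRegression(), iris.data, iris.target, cv=5)
    assert pred.shape[0] == iris.data.shape[0]   # one prediction per input sample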
+ + fit_params : dict or None + Parameters that will be passed to ``estimator.fit``. + + Returns + ------- + predictions : sequence + Result of calling 'estimator.predict' + + test : array-like + This is the value of the test parameter + """ + # Adjust length of sample weights + fit_params = fit_params if fit_params is not None else {} + fit_params = dict([(k, _index_param_value(X, v, train)) + for k, v in fit_params.items()]) + + X_train, y_train = _safe_split(estimator, X, y, train) + X_test, _ = _safe_split(estimator, X, y, test, train) + + if y_train is None: + estimator.fit(X_train, **fit_params) + else: + estimator.fit(X_train, y_train, **fit_params) + predictions = estimator.predict(X_test) + return predictions, test + + +def _check_is_permutation(indices, n_samples): + """Check whether indices is a reordering of the array np.arange(n_samples) + + Parameters + ---------- + indices : ndarray + integer array to test + n_samples : int + number of expected elements + + Returns + ------- + is_partition : bool + True iff sorted(locs) is range(n) + """ + if len(indices) != n_samples: + return False + hit = np.zeros(n_samples, bool) + hit[indices] = True + if not np.all(hit): + return False + return True + + +def _index_param_value(X, v, indices): + """Private helper function for parameter value indexing.""" + if not _is_arraylike(v) or _num_samples(v) != _num_samples(X): + # pass through: skip indexing + return v + if sp.issparse(v): + v = v.tocsr() + return safe_indexing(v, indices) + + +def permutation_test_score(estimator, X, y, labels=None, cv=None, + n_permutations=100, n_jobs=1, random_state=0, + verbose=0, scoring=None): + """Evaluate the significance of a cross-validated score with permutations + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + estimator : estimator object implementing 'fit' + The object to use to fit the data. + + X : array-like of shape at least 2D + The data to fit. + + y : array-like + The target variable to try to predict in the case of + supervised learning. + + labels : array-like, with shape (n_samples,), optional + Group labels for the samples used while splitting the dataset into + train/test set. + + scoring : string, callable or None, optional, default: None + A string (see model evaluation documentation) or + a scorer callable object / function with signature + ``scorer(estimator, X, y)``. + + cv : int, cross-validation generator or an iterable, optional + Determines the cross-validation splitting strategy. + Possible inputs for cv are: + - None, to use the default 3-fold cross validation, + - integer, to specify the number of folds in a `(Stratified)KFold`, + - An object to be used as a cross-validation generator. + - An iterable yielding train, test splits. + + For integer/None inputs, ``StratifiedKFold`` is used for classification + tasks, when ``y`` is binary or multiclass. + + See the :mod:`sklearn.model_selection` module for the list of + cross-validation strategies that can be used here. + + Also refer :ref:`cross-validation documentation <_cross_validation>` + + n_permutations : integer, optional + Number of times to permute ``y``. + + n_jobs : integer, optional + The number of CPUs to use to do the computation. -1 means + 'all CPUs'. + + random_state : RandomState or an int seed (0 by default) + A random number generator instance to define the state of the + random permutations generator. + + verbose : integer, optional + The verbosity level. + + Returns + ------- + score : float + The true score without permuting targets. 
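The ``_check_is_permutation`` helper above only has to verify coverage: the length check plus the boolean hit-mask together guarantee that the concatenated test indices are a reordering of ``arange(n_samples)``. Two illustrative direct calls to the private helper:

    import numpy as np

    _check_is_permutation(np.array([2, 0, 1]), 3)   # True: every index hit exactly once
    _check_is_permutation(np.array([0, 0, 1]), 3)   # False: index 2 is never hit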
+ + permutation_scores : array, shape (n_permutations,) + The scores obtained for each permutations. + + pvalue : float + The returned value equals p-value if `scoring` returns bigger + numbers for better scores (e.g., accuracy_score). If `scoring` is + rather a loss function (i.e. when lower is better such as with + `mean_squared_error`) then this is actually the complement of the + p-value: 1 - p-value. + + Notes + ----- + This function implements Test 1 in: + + Ojala and Garriga. Permutation Tests for Studying Classifier + Performance. The Journal of Machine Learning Research (2010) + vol. 11 + + """ + X, y, labels = indexable(X, y, labels) + + cv = check_cv(cv, y, classifier=is_classifier(estimator)) + scorer = check_scoring(estimator, scoring=scoring) + random_state = check_random_state(random_state) + + # We clone the estimator to make sure that all the folds are + # independent, and that it is pickle-able. + score = _permutation_test_score(clone(estimator), X, y, labels, cv, scorer) + permutation_scores = Parallel(n_jobs=n_jobs, verbose=verbose)( + delayed(_permutation_test_score)( + clone(estimator), X, _shuffle(y, labels, random_state), + labels, cv, scorer) + for _ in range(n_permutations)) + permutation_scores = np.array(permutation_scores) + pvalue = (np.sum(permutation_scores >= score) + 1.0) / (n_permutations + 1) + return score, permutation_scores, pvalue + + +permutation_test_score.__test__ = False # to avoid a pb with nosetests + + +def _permutation_test_score(estimator, X, y, labels, cv, scorer): + """Auxiliary function for permutation_test_score""" + avg_score = [] + for train, test in cv.split(X, y, labels): + estimator.fit(X[train], y[train]) + avg_score.append(scorer(estimator, X[test], y[test])) + return np.mean(avg_score) + + +def _shuffle(y, labels, random_state): + """Return a shuffled copy of y eventually shuffle among same labels.""" + if labels is None: + indices = random_state.permutation(len(y)) + else: + indices = np.arange(len(labels)) + for label in np.unique(labels): + this_mask = (labels == label) + indices[this_mask] = random_state.permutation(indices[this_mask]) + return y[indices] + + +def learning_curve(estimator, X, y, labels=None, + train_sizes=np.linspace(0.1, 1.0, 5), cv=None, scoring=None, + exploit_incremental_learning=False, n_jobs=1, + pre_dispatch="all", verbose=0): + """Learning curve. + + Determines cross-validated training and test scores for different training + set sizes. + + A cross-validation generator splits the whole dataset k times in training + and test data. Subsets of the training set with varying sizes will be used + to train the estimator and a score for each training subset size and the + test set will be computed. Afterwards, the scores will be averaged over + all k runs for each training subset size. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + estimator : object type that implements the "fit" and "predict" methods + An object of that type which is cloned for each validation. + + X : array-like, shape (n_samples, n_features) + Training vector, where n_samples is the number of samples and + n_features is the number of features. + + y : array-like, shape (n_samples) or (n_samples, n_features), optional + Target relative to X for classification or regression; + None for unsupervised learning. + + labels : array-like, with shape (n_samples,), optional + Group labels for the samples used while splitting the dataset into + train/test set. 
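Note the small-sample correction in the p-value above: the observed score is counted as one extra permutation, so the smallest reachable p-value is ``1 / (n_permutations + 1)``. A hedged usage sketch with an illustrative classifier and dataset:

    from sklearn.datasets import load_iris
    from sklearn.svm import LinearSVC
    from sklearn.model_selection import permutation_test_score

    iris = load_iris()
    score, perm_scores, pvalue = permutation_test_score(
        LinearSVC(random_state=0), iris.data, iris.target,
        cv=5, n_permutations=30, random_state=0)
    # a pvalue near 1 / 31 suggests the score is unlikely under shuffled labels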
+ + train_sizes : array-like, shape (n_ticks,), dtype float or int + Relative or absolute numbers of training examples that will be used to + generate the learning curve. If the dtype is float, it is regarded as a + fraction of the maximum size of the training set (that is determined + by the selected validation method), i.e. it has to be within (0, 1]. + Otherwise it is interpreted as absolute sizes of the training sets. + Note that for classification the number of samples usually have to + be big enough to contain at least one sample from each class. + (default: np.linspace(0.1, 1.0, 5)) + + cv : int, cross-validation generator or an iterable, optional + Determines the cross-validation splitting strategy. + Possible inputs for cv are: + - None, to use the default 3-fold cross validation, + - integer, to specify the number of folds in a `(Stratified)KFold`, + - An object to be used as a cross-validation generator. + - An iterable yielding train, test splits. + + For integer/None inputs, ``StratifiedKFold`` is used for classification + tasks, when ``y`` is binary or multiclass. + + See the :mod:`sklearn.model_selection` module for the list of + cross-validation strategies that can be used here. + + Also refer :ref:`cross-validation documentation <_cross_validation>` + + scoring : string, callable or None, optional, default: None + A string (see model evaluation documentation) or + a scorer callable object / function with signature + ``scorer(estimator, X, y)``. + + exploit_incremental_learning : boolean, optional, default: False + If the estimator supports incremental learning, this will be + used to speed up fitting for different training set sizes. + + n_jobs : integer, optional + Number of jobs to run in parallel (default 1). + + pre_dispatch : integer or string, optional + Number of predispatched jobs for parallel execution (default is + all). The option can reduce the allocated memory. The string can + be an expression like '2*n_jobs'. + + verbose : integer, optional + Controls the verbosity: the higher, the more messages. + + Returns + ------- + train_sizes_abs : array, shape = (n_unique_ticks,), dtype int + Numbers of training examples that has been used to generate the + learning curve. Note that the number of ticks might be less + than n_ticks because duplicate entries will be removed. + + train_scores : array, shape (n_ticks, n_cv_folds) + Scores on training sets. + + test_scores : array, shape (n_ticks, n_cv_folds) + Scores on test set. + + Notes + ----- + See :ref:`examples/model_selection/plot_learning_curve.py + ` + """ + if exploit_incremental_learning and not hasattr(estimator, "partial_fit"): + raise ValueError("An estimator must support the partial_fit interface " + "to exploit incremental learning") + X, y, labels = indexable(X, y, labels) + + cv = check_cv(cv, y, classifier=is_classifier(estimator)) + cv_iter = cv.split(X, y, labels) + # Make a list since we will be iterating multiple times over the folds + cv_iter = list(cv_iter) + scorer = check_scoring(estimator, scoring=scoring) + + n_max_training_samples = len(cv_iter[0][0]) + # Because the lengths of folds can be significantly different, it is + # not guaranteed that we use all of the available training data when we + # use the first 'n_max_training_samples' samples. 
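The ``list(cv_iter)`` materialisation above matters because the new splitters return one-shot generators from ``split``, and ``learning_curve`` walks the same folds once per training-set size. A small illustration of the generator behaviour on toy data:

    import numpy as np
    from sklearn.model_selection import KFold

    X = np.ones(6)
    gen = KFold(n_folds=3).split(X)
    list(gen)   # three (train, test) index pairs
    list(gen)   # [] -- the generator is exhausted, hence the list() call above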
+ train_sizes_abs = _translate_train_sizes(train_sizes, + n_max_training_samples) + n_unique_ticks = train_sizes_abs.shape[0] + if verbose > 0: + print("[learning_curve] Training set sizes: " + str(train_sizes_abs)) + + parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch, + verbose=verbose) + if exploit_incremental_learning: + classes = np.unique(y) if is_classifier(estimator) else None + out = parallel(delayed(_incremental_fit_estimator)( + clone(estimator), X, y, classes, train, test, train_sizes_abs, + scorer, verbose) for train, test in cv.split(X, y, labels)) + else: + out = parallel(delayed(_fit_and_score)( + clone(estimator), X, y, scorer, train[:n_train_samples], test, + verbose, parameters=None, fit_params=None, return_train_score=True) + for train, test in cv_iter + for n_train_samples in train_sizes_abs) + out = np.array(out)[:, :2] + n_cv_folds = out.shape[0] // n_unique_ticks + out = out.reshape(n_cv_folds, n_unique_ticks, 2) + + out = np.asarray(out).transpose((2, 1, 0)) + + return train_sizes_abs, out[0], out[1] + + +def _translate_train_sizes(train_sizes, n_max_training_samples): + """Determine absolute sizes of training subsets and validate 'train_sizes'. + + Examples: + _translate_train_sizes([0.5, 1.0], 10) -> [5, 10] + _translate_train_sizes([5, 10], 10) -> [5, 10] + + Parameters + ---------- + train_sizes : array-like, shape (n_ticks,), dtype float or int + Numbers of training examples that will be used to generate the + learning curve. If the dtype is float, it is regarded as a + fraction of 'n_max_training_samples', i.e. it has to be within (0, 1]. + + n_max_training_samples : int + Maximum number of training samples (upper bound of 'train_sizes'). + + Returns + ------- + train_sizes_abs : array, shape (n_unique_ticks,), dtype int + Numbers of training examples that will be used to generate the + learning curve. Note that the number of ticks might be less + than n_ticks because duplicate entries will be removed. + """ + train_sizes_abs = np.asarray(train_sizes) + n_ticks = train_sizes_abs.shape[0] + n_min_required_samples = np.min(train_sizes_abs) + n_max_required_samples = np.max(train_sizes_abs) + if np.issubdtype(train_sizes_abs.dtype, np.float): + if n_min_required_samples <= 0.0 or n_max_required_samples > 1.0: + raise ValueError("train_sizes has been interpreted as fractions " + "of the maximum number of training samples and " + "must be within (0, 1], but is within [%f, %f]." + % (n_min_required_samples, + n_max_required_samples)) + train_sizes_abs = astype(train_sizes_abs * n_max_training_samples, + dtype=np.int, copy=False) + train_sizes_abs = np.clip(train_sizes_abs, 1, + n_max_training_samples) + else: + if (n_min_required_samples <= 0 or + n_max_required_samples > n_max_training_samples): + raise ValueError("train_sizes has been interpreted as absolute " + "numbers of training samples and must be within " + "(0, %d], but is within [%d, %d]." + % (n_max_training_samples, + n_min_required_samples, + n_max_required_samples)) + + train_sizes_abs = np.unique(train_sizes_abs) + if n_ticks > train_sizes_abs.shape[0]: + warnings.warn("Removed duplicate entries from 'train_sizes'. Number " + "of ticks will be less than than the size of " + "'train_sizes' %d instead of %d)." 
+ % (train_sizes_abs.shape[0], n_ticks), RuntimeWarning) + + return train_sizes_abs + + +def _incremental_fit_estimator(estimator, X, y, classes, train, test, + train_sizes, scorer, verbose): + """Train estimator on training subsets incrementally and compute scores.""" + train_scores, test_scores = [], [] + partitions = zip(train_sizes, np.split(train, train_sizes)[:-1]) + for n_train_samples, partial_train in partitions: + train_subset = train[:n_train_samples] + X_train, y_train = _safe_split(estimator, X, y, train_subset) + X_partial_train, y_partial_train = _safe_split(estimator, X, y, + partial_train) + X_test, y_test = _safe_split(estimator, X, y, test, train_subset) + if y_partial_train is None: + estimator.partial_fit(X_partial_train, classes=classes) + else: + estimator.partial_fit(X_partial_train, y_partial_train, + classes=classes) + train_scores.append(_score(estimator, X_train, y_train, scorer)) + test_scores.append(_score(estimator, X_test, y_test, scorer)) + return np.array((train_scores, test_scores)).T + + +def validation_curve(estimator, X, y, param_name, param_range, labels=None, + cv=None, scoring=None, n_jobs=1, pre_dispatch="all", + verbose=0): + """Validation curve. + + Determine training and test scores for varying parameter values. + + Compute scores for an estimator with different values of a specified + parameter. This is similar to grid search with one parameter. However, this + will also compute training scores and is merely a utility for plotting the + results. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + estimator : object type that implements the "fit" and "predict" methods + An object of that type which is cloned for each validation. + + X : array-like, shape (n_samples, n_features) + Training vector, where n_samples is the number of samples and + n_features is the number of features. + + y : array-like, shape (n_samples) or (n_samples, n_features), optional + Target relative to X for classification or regression; + None for unsupervised learning. + + param_name : string + Name of the parameter that will be varied. + + param_range : array-like, shape (n_values,) + The values of the parameter that will be evaluated. + + labels : array-like, with shape (n_samples,), optional + Group labels for the samples used while splitting the dataset into + train/test set. + + cv : int, cross-validation generator or an iterable, optional + Determines the cross-validation splitting strategy. + Possible inputs for cv are: + - None, to use the default 3-fold cross validation, + - integer, to specify the number of folds in a `(Stratified)KFold`, + - An object to be used as a cross-validation generator. + - An iterable yielding train, test splits. + + For integer/None inputs, ``StratifiedKFold`` is used for classification + tasks, when ``y`` is binary or multiclass. + + See the :mod:`sklearn.model_selection` module for the list of + cross-validation strategies that can be used here. + + Also refer :ref:`cross-validation documentation <_cross_validation>` + + scoring : string, callable or None, optional, default: None + A string (see model evaluation documentation) or + a scorer callable object / function with signature + ``scorer(estimator, X, y)``. + + n_jobs : integer, optional + Number of jobs to run in parallel (default 1). + + pre_dispatch : integer or string, optional + Number of predispatched jobs for parallel execution (default is + all). The option can reduce the allocated memory. The string can + be an expression like '2*n_jobs'. 
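A usage sketch of ``validation_curve`` as described above, sweeping a single ``SVC`` hyper-parameter; the estimator and parameter range are illustrative:

    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.svm import SVC
    from sklearn.model_selection import validation_curve

    iris = load_iris()
    param_range = np.logspace(-3, 2, 6)
    train_scores, test_scores = validation_curve(
        SVC(), iris.data, iris.target, param_name="gamma",
        param_range=param_range, cv=5)
    # both arrays have shape (len(param_range), n_cv_folds); average over axis 1 to plot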
+ + verbose : integer, optional + Controls the verbosity: the higher, the more messages. + + Returns + ------- + train_scores : array, shape (n_ticks, n_cv_folds) + Scores on training sets. + + test_scores : array, shape (n_ticks, n_cv_folds) + Scores on test set. + + Notes + ----- + See + :ref:`examples/model_selection/plot_validation_curve.py + ` + """ + X, y, labels = indexable(X, y, labels) + + cv = check_cv(cv, y, classifier=is_classifier(estimator)) + + scorer = check_scoring(estimator, scoring=scoring) + + parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch, + verbose=verbose) + out = parallel(delayed(_fit_and_score)( + estimator, X, y, scorer, train, test, verbose, + parameters={param_name: v}, fit_params=None, return_train_score=True) + for train, test in cv.split(X, y, labels) for v in param_range) + + out = np.asarray(out)[:, :2] + n_params = len(param_range) + n_cv_folds = out.shape[0] // n_params + out = out.reshape(n_cv_folds, n_params, 2).transpose((2, 1, 0)) + + return out[0], out[1] diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py new file mode 100644 index 0000000000000..91d81873a1d8e --- /dev/null +++ b/sklearn/model_selection/tests/test_search.py @@ -0,0 +1,819 @@ +"""Test the search module""" + +from collections import Iterable, Sized +from sklearn.externals.six.moves import cStringIO as StringIO +from sklearn.externals.six.moves import xrange +from itertools import chain, product +import pickle +import sys + +import numpy as np +import scipy.sparse as sp + +from sklearn.utils.testing import assert_equal +from sklearn.utils.testing import assert_not_equal +from sklearn.utils.testing import assert_raises +from sklearn.utils.testing import assert_warns +from sklearn.utils.testing import assert_raise_message +from sklearn.utils.testing import assert_false, assert_true +from sklearn.utils.testing import assert_array_equal +from sklearn.utils.testing import assert_almost_equal +from sklearn.utils.testing import assert_array_almost_equal +from sklearn.utils.testing import assert_no_warnings +from sklearn.utils.testing import ignore_warnings +from sklearn.utils.mocking import CheckingClassifier, MockDataFrame + +from scipy.stats import bernoulli, expon, uniform + +from sklearn.externals.six.moves import zip +from sklearn.base import BaseEstimator +from sklearn.datasets import make_classification +from sklearn.datasets import make_blobs +from sklearn.datasets import make_multilabel_classification + +from sklearn.model_selection import KFold +from sklearn.model_selection import StratifiedKFold +from sklearn.model_selection import StratifiedShuffleSplit +from sklearn.model_selection import LeaveOneLabelOut +from sklearn.model_selection import LeavePLabelOut +from sklearn.model_selection import LabelKFold +from sklearn.model_selection import LabelShuffleSplit +from sklearn.model_selection import GridSearchCV +from sklearn.model_selection import RandomizedSearchCV +from sklearn.model_selection import ParameterGrid +from sklearn.model_selection import ParameterSampler + +# TODO Import from sklearn.exceptions once merged. 
+from sklearn.base import ChangedBehaviorWarning +from sklearn.model_selection._validation import FitFailedWarning + +from sklearn.svm import LinearSVC, SVC +from sklearn.tree import DecisionTreeRegressor +from sklearn.tree import DecisionTreeClassifier +from sklearn.cluster import KMeans +from sklearn.neighbors import KernelDensity +from sklearn.metrics import f1_score +from sklearn.metrics import make_scorer +from sklearn.metrics import roc_auc_score +from sklearn.preprocessing import Imputer +from sklearn.pipeline import Pipeline + + +# Neither of the following two estimators inherit from BaseEstimator, +# to test hyperparameter search on user-defined classifiers. +class MockClassifier(object): + """Dummy classifier to test the parameter search algorithms""" + def __init__(self, foo_param=0): + self.foo_param = foo_param + + def fit(self, X, Y): + assert_true(len(X) == len(Y)) + return self + + def predict(self, T): + return T.shape[0] + + predict_proba = predict + decision_function = predict + transform = predict + + def score(self, X=None, Y=None): + if self.foo_param > 1: + score = 1. + else: + score = 0. + return score + + def get_params(self, deep=False): + return {'foo_param': self.foo_param} + + def set_params(self, **params): + self.foo_param = params['foo_param'] + return self + + +class LinearSVCNoScore(LinearSVC): + """An LinearSVC classifier that has no score method.""" + @property + def score(self): + raise AttributeError + +X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]]) +y = np.array([1, 1, 2, 2]) + + +def assert_grid_iter_equals_getitem(grid): + assert_equal(list(grid), [grid[i] for i in range(len(grid))]) + + +def test_parameter_grid(): + # Test basic properties of ParameterGrid. + params1 = {"foo": [1, 2, 3]} + grid1 = ParameterGrid(params1) + assert_true(isinstance(grid1, Iterable)) + assert_true(isinstance(grid1, Sized)) + assert_equal(len(grid1), 3) + assert_grid_iter_equals_getitem(grid1) + + params2 = {"foo": [4, 2], + "bar": ["ham", "spam", "eggs"]} + grid2 = ParameterGrid(params2) + assert_equal(len(grid2), 6) + + # loop to assert we can iterate over the grid multiple times + for i in xrange(2): + # tuple + chain transforms {"a": 1, "b": 2} to ("a", 1, "b", 2) + points = set(tuple(chain(*(sorted(p.items())))) for p in grid2) + assert_equal(points, + set(("bar", x, "foo", y) + for x, y in product(params2["bar"], params2["foo"]))) + assert_grid_iter_equals_getitem(grid2) + + # Special case: empty grid (useful to get default estimator settings) + empty = ParameterGrid({}) + assert_equal(len(empty), 1) + assert_equal(list(empty), [{}]) + assert_grid_iter_equals_getitem(empty) + assert_raises(IndexError, lambda: empty[1]) + + has_empty = ParameterGrid([{'C': [1, 10]}, {}, {'C': [.5]}]) + assert_equal(len(has_empty), 4) + assert_equal(list(has_empty), [{'C': 1}, {'C': 10}, {}, {'C': .5}]) + assert_grid_iter_equals_getitem(has_empty) + + +def test_grid_search(): + # Test that the best estimator contains the right value for foo_param + clf = MockClassifier() + grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, verbose=3) + # make sure it selects the smallest parameter in case of ties + old_stdout = sys.stdout + sys.stdout = StringIO() + grid_search.fit(X, y) + sys.stdout = old_stdout + assert_equal(grid_search.best_estimator_.foo_param, 2) + + for i, foo_i in enumerate([1, 2, 3]): + assert_true(grid_search.grid_scores_[i][0] + == {'foo_param': foo_i}) + # Smoke test the score etc: + grid_search.score(X, y) + grid_search.predict_proba(X) + 
grid_search.decision_function(X) + grid_search.transform(X) + + # Test exception handling on scoring + grid_search.scoring = 'sklearn' + assert_raises(ValueError, grid_search.fit, X, y) + + +@ignore_warnings +def test_grid_search_no_score(): + # Test grid-search on classifier that has no score function. + clf = LinearSVC(random_state=0) + X, y = make_blobs(random_state=0, centers=2) + Cs = [.1, 1, 10] + clf_no_score = LinearSVCNoScore(random_state=0) + grid_search = GridSearchCV(clf, {'C': Cs}, scoring='accuracy') + grid_search.fit(X, y) + + grid_search_no_score = GridSearchCV(clf_no_score, {'C': Cs}, + scoring='accuracy') + # smoketest grid search + grid_search_no_score.fit(X, y) + + # check that best params are equal + assert_equal(grid_search_no_score.best_params_, grid_search.best_params_) + # check that we can call score and that it gives the correct result + assert_equal(grid_search.score(X, y), grid_search_no_score.score(X, y)) + + # giving no scoring function raises an error + grid_search_no_score = GridSearchCV(clf_no_score, {'C': Cs}) + assert_raise_message(TypeError, "no scoring", grid_search_no_score.fit, + [[1]]) + + +def test_grid_search_score_method(): + X, y = make_classification(n_samples=100, n_classes=2, flip_y=.2, + random_state=0) + clf = LinearSVC(random_state=0) + grid = {'C': [.1]} + + search_no_scoring = GridSearchCV(clf, grid, scoring=None).fit(X, y) + search_accuracy = GridSearchCV(clf, grid, scoring='accuracy').fit(X, y) + search_no_score_method_auc = GridSearchCV(LinearSVCNoScore(), grid, + scoring='roc_auc').fit(X, y) + search_auc = GridSearchCV(clf, grid, scoring='roc_auc').fit(X, y) + + # Check warning only occurs in situation where behavior changed: + # estimator requires score method to compete with scoring parameter + score_no_scoring = assert_no_warnings(search_no_scoring.score, X, y) + score_accuracy = assert_warns(ChangedBehaviorWarning, + search_accuracy.score, X, y) + score_no_score_auc = assert_no_warnings(search_no_score_method_auc.score, + X, y) + score_auc = assert_warns(ChangedBehaviorWarning, + search_auc.score, X, y) + # ensure the test is sane + assert_true(score_auc < 1.0) + assert_true(score_accuracy < 1.0) + assert_not_equal(score_auc, score_accuracy) + + assert_almost_equal(score_accuracy, score_no_scoring) + assert_almost_equal(score_auc, score_no_score_auc) + + +def test_grid_search_labels(): + # Check if ValueError (when labels is None) propagates to GridSearchCV + # And also check if labels is correctly passed to the cv object + rng = np.random.RandomState(0) + + X, y = make_classification(n_samples=15, n_classes=2, random_state=0) + labels = rng.randint(0, 3, 15) + + clf = LinearSVC(random_state=0) + grid = {'C': [1]} + + label_cvs = [LeaveOneLabelOut(), LeavePLabelOut(2), LabelKFold(), + LabelShuffleSplit()] + for cv in label_cvs: + gs = GridSearchCV(clf, grid, cv=cv) + assert_raise_message(ValueError, + "The labels parameter should not be None", + gs.fit, X, y) + gs.fit(X, y, labels) + + non_label_cvs = [StratifiedKFold(), StratifiedShuffleSplit()] + for cv in non_label_cvs: + print(cv) + gs = GridSearchCV(clf, grid, cv=cv) + # Should not raise an error + gs.fit(X, y) + + +def test_trivial_grid_scores(): + # Test search over a "grid" with only one point. + # Non-regression test: grid_scores_ wouldn't be set by GridSearchCV. 
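The labels test above reflects a behaviour worth calling out: with label-aware splitters such as ``LabelKFold``, the group labels are passed to ``fit`` rather than to the splitter constructor. A hedged sketch with toy data and arbitrary groups:

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.model_selection import GridSearchCV, LabelKFold
    from sklearn.svm import LinearSVC

    X, y = make_classification(n_samples=15, random_state=0)
    labels = np.random.RandomState(0).randint(0, 3, 15)
    gs = GridSearchCV(LinearSVC(random_state=0), {'C': [1]}, cv=LabelKFold())
    gs.fit(X, y, labels)    # omitting labels here raises a ValueError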
+ clf = MockClassifier() + grid_search = GridSearchCV(clf, {'foo_param': [1]}) + grid_search.fit(X, y) + assert_true(hasattr(grid_search, "grid_scores_")) + + random_search = RandomizedSearchCV(clf, {'foo_param': [0]}, n_iter=1) + random_search.fit(X, y) + assert_true(hasattr(random_search, "grid_scores_")) + + +def test_no_refit(): + # Test that grid search can be used for model selection only + clf = MockClassifier() + grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=False) + grid_search.fit(X, y) + assert_true(hasattr(grid_search, "best_params_")) + + +def test_grid_search_error(): + # Test that grid search will capture errors on data with different + # length + X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) + + clf = LinearSVC() + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}) + assert_raises(ValueError, cv.fit, X_[:180], y_) + + +def test_grid_search_iid(): + # test the iid parameter + # noise-free simple 2d-data + X, y = make_blobs(centers=[[0, 0], [1, 0], [0, 1], [1, 1]], random_state=0, + cluster_std=0.1, shuffle=False, n_samples=80) + # split dataset into two folds that are not iid + # first one contains data of all 4 blobs, second only from two. + mask = np.ones(X.shape[0], dtype=np.bool) + mask[np.where(y == 1)[0][::2]] = 0 + mask[np.where(y == 2)[0][::2]] = 0 + # this leads to perfect classification on one fold and a score of 1/3 on + # the other + svm = SVC(kernel='linear') + # create "cv" for splits + cv = [[mask, ~mask], [~mask, mask]] + # once with iid=True (default) + grid_search = GridSearchCV(svm, param_grid={'C': [1, 10]}, cv=cv) + grid_search.fit(X, y) + first = grid_search.grid_scores_[0] + assert_equal(first.parameters['C'], 1) + assert_array_almost_equal(first.cv_validation_scores, [1, 1. / 3.]) + # for first split, 1/4 of dataset is in test, for second 3/4. + # take weighted average + assert_almost_equal(first.mean_validation_score, + 1 * 1. / 4. + 1. / 3. * 3. / 4.) + + # once with iid=False + grid_search = GridSearchCV(svm, param_grid={'C': [1, 10]}, cv=cv, + iid=False) + grid_search.fit(X, y) + first = grid_search.grid_scores_[0] + assert_equal(first.parameters['C'], 1) + # scores are the same as above + assert_array_almost_equal(first.cv_validation_scores, [1, 1. 
/ 3.]) + # averaged score is just mean of scores + assert_almost_equal(first.mean_validation_score, + np.mean(first.cv_validation_scores)) + + +def test_grid_search_one_grid_point(): + X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) + param_dict = {"C": [1.0], "kernel": ["rbf"], "gamma": [0.1]} + + clf = SVC() + cv = GridSearchCV(clf, param_dict) + cv.fit(X_, y_) + + clf = SVC(C=1.0, kernel="rbf", gamma=0.1) + clf.fit(X_, y_) + + assert_array_equal(clf.dual_coef_, cv.best_estimator_.dual_coef_) + + +def test_grid_search_bad_param_grid(): + param_dict = {"C": 1.0} + clf = SVC() + assert_raises(ValueError, GridSearchCV, clf, param_dict) + + param_dict = {"C": []} + clf = SVC() + assert_raises(ValueError, GridSearchCV, clf, param_dict) + + param_dict = {"C": np.ones(6).reshape(3, 2)} + clf = SVC() + assert_raises(ValueError, GridSearchCV, clf, param_dict) + + +def test_grid_search_sparse(): + # Test that grid search works with both dense and sparse matrices + X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) + + clf = LinearSVC() + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}) + cv.fit(X_[:180], y_[:180]) + y_pred = cv.predict(X_[180:]) + C = cv.best_estimator_.C + + X_ = sp.csr_matrix(X_) + clf = LinearSVC() + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}) + cv.fit(X_[:180].tocoo(), y_[:180]) + y_pred2 = cv.predict(X_[180:]) + C2 = cv.best_estimator_.C + + assert_true(np.mean(y_pred == y_pred2) >= .9) + assert_equal(C, C2) + + +def test_grid_search_sparse_scoring(): + X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) + + clf = LinearSVC() + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, scoring="f1") + cv.fit(X_[:180], y_[:180]) + y_pred = cv.predict(X_[180:]) + C = cv.best_estimator_.C + + X_ = sp.csr_matrix(X_) + clf = LinearSVC() + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, scoring="f1") + cv.fit(X_[:180], y_[:180]) + y_pred2 = cv.predict(X_[180:]) + C2 = cv.best_estimator_.C + + assert_array_equal(y_pred, y_pred2) + assert_equal(C, C2) + # Smoke test the score + # np.testing.assert_allclose(f1_score(cv.predict(X_[:180]), y[:180]), + # cv.score(X_[:180], y[:180])) + + # test loss where greater is worse + def f1_loss(y_true_, y_pred_): + return -f1_score(y_true_, y_pred_) + F1Loss = make_scorer(f1_loss, greater_is_better=False) + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}, scoring=F1Loss) + cv.fit(X_[:180], y_[:180]) + y_pred3 = cv.predict(X_[180:]) + C3 = cv.best_estimator_.C + + assert_equal(C, C3) + assert_array_equal(y_pred, y_pred3) + + +def test_grid_search_precomputed_kernel(): + # Test that grid search works when the input features are given in the + # form of a precomputed kernel matrix + X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) + + # compute the training kernel matrix corresponding to the linear kernel + K_train = np.dot(X_[:180], X_[:180].T) + y_train = y_[:180] + + clf = SVC(kernel='precomputed') + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}) + cv.fit(K_train, y_train) + + assert_true(cv.best_score_ >= 0) + + # compute the test kernel matrix + K_test = np.dot(X_[180:], X_[:180].T) + y_test = y_[180:] + + y_pred = cv.predict(K_test) + + assert_true(np.mean(y_pred == y_test) >= 0) + + # test error is raised when the precomputed kernel is not array-like + # or sparse + assert_raises(ValueError, cv.fit, K_train.tolist(), y_train) + + +def test_grid_search_precomputed_kernel_error_nonsquare(): + # Test that grid search returns an error with a non-square precomputed + # training kernel 
matrix + K_train = np.zeros((10, 20)) + y_train = np.ones((10, )) + clf = SVC(kernel='precomputed') + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}) + assert_raises(ValueError, cv.fit, K_train, y_train) + + +def test_grid_search_precomputed_kernel_error_kernel_function(): + # Test that grid search returns an error when using a kernel_function + X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) + kernel_function = lambda x1, x2: np.dot(x1, x2.T) + clf = SVC(kernel=kernel_function) + cv = GridSearchCV(clf, {'C': [0.1, 1.0]}) + assert_raises(ValueError, cv.fit, X_, y_) + + +class BrokenClassifier(BaseEstimator): + """Broken classifier that cannot be fit twice""" + + def __init__(self, parameter=None): + self.parameter = parameter + + def fit(self, X, y): + assert_true(not hasattr(self, 'has_been_fit_')) + self.has_been_fit_ = True + + def predict(self, X): + return np.zeros(X.shape[0]) + + +@ignore_warnings +def test_refit(): + # Regression test for bug in refitting + # Simulates re-fitting a broken estimator; this used to break with + # sparse SVMs. + X = np.arange(100).reshape(10, 10) + y = np.array([0] * 5 + [1] * 5) + + clf = GridSearchCV(BrokenClassifier(), [{'parameter': [0, 1]}], + scoring="precision", refit=True) + clf.fit(X, y) + + +def test_gridsearch_nd(): + # Pass X as list in GridSearchCV + X_4d = np.arange(10 * 5 * 3 * 2).reshape(10, 5, 3, 2) + y_3d = np.arange(10 * 7 * 11).reshape(10, 7, 11) + check_X = lambda x: x.shape[1:] == (5, 3, 2) + check_y = lambda x: x.shape[1:] == (7, 11) + clf = CheckingClassifier(check_X=check_X, check_y=check_y) + grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}) + grid_search.fit(X_4d, y_3d).score(X, y) + assert_true(hasattr(grid_search, "grid_scores_")) + + +def test_X_as_list(): + # Pass X as list in GridSearchCV + X = np.arange(100).reshape(10, 10) + y = np.array([0] * 5 + [1] * 5) + + clf = CheckingClassifier(check_X=lambda x: isinstance(x, list)) + cv = KFold(n_folds=3) + grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv) + grid_search.fit(X.tolist(), y).score(X, y) + assert_true(hasattr(grid_search, "grid_scores_")) + + +def test_y_as_list(): + # Pass y as list in GridSearchCV + X = np.arange(100).reshape(10, 10) + y = np.array([0] * 5 + [1] * 5) + + clf = CheckingClassifier(check_y=lambda x: isinstance(x, list)) + cv = KFold(n_folds=3) + grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv) + grid_search.fit(X, y.tolist()).score(X, y) + assert_true(hasattr(grid_search, "grid_scores_")) + + +@ignore_warnings +def test_pandas_input(): + # check cross_val_score doesn't destroy pandas dataframe + types = [(MockDataFrame, MockDataFrame)] + try: + from pandas import Series, DataFrame + types.append((DataFrame, Series)) + except ImportError: + pass + + X = np.arange(100).reshape(10, 10) + y = np.array([0] * 5 + [1] * 5) + + for InputFeatureType, TargetType in types: + # X dataframe, y series + X_df, y_ser = InputFeatureType(X), TargetType(y) + check_df = lambda x: isinstance(x, InputFeatureType) + check_series = lambda x: isinstance(x, TargetType) + clf = CheckingClassifier(check_X=check_df, check_y=check_series) + + grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}) + grid_search.fit(X_df, y_ser).score(X_df, y_ser) + grid_search.predict(X_df) + assert_true(hasattr(grid_search, "grid_scores_")) + + +def test_unsupervised_grid_search(): + # test grid-search with unsupervised estimator + X, y = make_blobs(random_state=0) + km = KMeans(random_state=0) + grid_search = GridSearchCV(km, 
param_grid=dict(n_clusters=[2, 3, 4]), + scoring='adjusted_rand_score') + grid_search.fit(X, y) + # ARI can find the right number :) + assert_equal(grid_search.best_params_["n_clusters"], 3) + + # Now without a score, and without y + grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4])) + grid_search.fit(X) + assert_equal(grid_search.best_params_["n_clusters"], 4) + + +def test_gridsearch_no_predict(): + # test grid-search with an estimator without predict. + # slight duplication of a test from KDE + def custom_scoring(estimator, X): + return 42 if estimator.bandwidth == .1 else 0 + X, _ = make_blobs(cluster_std=.1, random_state=1, + centers=[[0, 1], [1, 0], [0, 0]]) + search = GridSearchCV(KernelDensity(), + param_grid=dict(bandwidth=[.01, .1, 1]), + scoring=custom_scoring) + search.fit(X) + assert_equal(search.best_params_['bandwidth'], .1) + assert_equal(search.best_score_, 42) + + +def test_param_sampler(): + # test basic properties of param sampler + param_distributions = {"kernel": ["rbf", "linear"], + "C": uniform(0, 1)} + sampler = ParameterSampler(param_distributions=param_distributions, + n_iter=10, random_state=0) + samples = [x for x in sampler] + assert_equal(len(samples), 10) + for sample in samples: + assert_true(sample["kernel"] in ["rbf", "linear"]) + assert_true(0 <= sample["C"] <= 1) + + +def test_randomized_search_grid_scores(): + # Make a dataset with a lot of noise to get various kind of prediction + # errors across CV folds and parameter settings + X, y = make_classification(n_samples=200, n_features=100, n_informative=3, + random_state=0) + + # XXX: as of today (scipy 0.12) it's not possible to set the random seed + # of scipy.stats distributions: the assertions in this test should thus + # not depend on the randomization + params = dict(C=expon(scale=10), + gamma=expon(scale=0.1)) + n_cv_iter = 3 + n_search_iter = 30 + search = RandomizedSearchCV(SVC(), n_iter=n_search_iter, cv=n_cv_iter, + param_distributions=params, iid=False) + search.fit(X, y) + assert_equal(len(search.grid_scores_), n_search_iter) + + # Check consistency of the structure of each cv_score item + for cv_score in search.grid_scores_: + assert_equal(len(cv_score.cv_validation_scores), n_cv_iter) + # Because we set iid to False, the mean_validation score is the + # mean of the fold mean scores instead of the aggregate sample-wise + # mean score + assert_almost_equal(np.mean(cv_score.cv_validation_scores), + cv_score.mean_validation_score) + assert_equal(list(sorted(cv_score.parameters.keys())), + list(sorted(params.keys()))) + + # Check the consistency with the best_score_ and best_params_ attributes + sorted_grid_scores = list(sorted(search.grid_scores_, + key=lambda x: x.mean_validation_score)) + best_score = sorted_grid_scores[-1].mean_validation_score + assert_equal(search.best_score_, best_score) + + tied_best_params = [s.parameters for s in sorted_grid_scores + if s.mean_validation_score == best_score] + assert_true(search.best_params_ in tied_best_params, + "best_params_={0} is not part of the" + " tied best models: {1}".format( + search.best_params_, tied_best_params)) + + +def test_grid_search_score_consistency(): + # test that correct scores are used + clf = LinearSVC(random_state=0) + X, y = make_blobs(random_state=0, centers=2) + Cs = [.1, 1, 10] + for score in ['f1', 'roc_auc']: + grid_search = GridSearchCV(clf, {'C': Cs}, scoring=score) + grid_search.fit(X, y) + cv = StratifiedKFold(n_folds=3) + for C, scores in zip(Cs, grid_search.grid_scores_): + clf.set_params(C=C) 
+ scores = scores[2] # get the separate runs from grid scores + i = 0 + for train, test in cv.split(X, y): + clf.fit(X[train], y[train]) + if score == "f1": + correct_score = f1_score(y[test], clf.predict(X[test])) + elif score == "roc_auc": + dec = clf.decision_function(X[test]) + correct_score = roc_auc_score(y[test], dec) + assert_almost_equal(correct_score, scores[i]) + i += 1 + + +def test_pickle(): + # Test that a fit search can be pickled + clf = MockClassifier() + grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=True) + grid_search.fit(X, y) + pickle.dumps(grid_search) # smoke test + + random_search = RandomizedSearchCV(clf, {'foo_param': [1, 2, 3]}, + refit=True, n_iter=3) + random_search.fit(X, y) + pickle.dumps(random_search) # smoke test + + +def test_grid_search_with_multioutput_data(): + # Test search with multi-output estimator + + X, y = make_multilabel_classification(return_indicator=True, + random_state=0) + + est_parameters = {"max_depth": [1, 2, 3, 4]} + cv = KFold(random_state=0) + + estimators = [DecisionTreeRegressor(random_state=0), + DecisionTreeClassifier(random_state=0)] + + # Test with grid search cv + for est in estimators: + grid_search = GridSearchCV(est, est_parameters, cv=cv) + grid_search.fit(X, y) + for parameters, _, cv_validation_scores in grid_search.grid_scores_: + est.set_params(**parameters) + + for i, (train, test) in enumerate(cv.split(X, y)): + est.fit(X[train], y[train]) + correct_score = est.score(X[test], y[test]) + assert_almost_equal(correct_score, + cv_validation_scores[i]) + + # Test with a randomized search + for est in estimators: + random_search = RandomizedSearchCV(est, est_parameters, + cv=cv, n_iter=3) + random_search.fit(X, y) + for parameters, _, cv_validation_scores in random_search.grid_scores_: + est.set_params(**parameters) + + for i, (train, test) in enumerate(cv.split(X, y)): + est.fit(X[train], y[train]) + correct_score = est.score(X[test], y[test]) + assert_almost_equal(correct_score, + cv_validation_scores[i]) + + +def test_predict_proba_disabled(): + # Test predict_proba when disabled on estimator. + X = np.arange(20).reshape(5, -1) + y = [0, 0, 1, 1, 1] + clf = SVC(probability=False) + gs = GridSearchCV(clf, {}, cv=2).fit(X, y) + assert_false(hasattr(gs, "predict_proba")) + + +def test_grid_search_allows_nans(): + # Test GridSearchCV with Imputer + X = np.arange(20, dtype=np.float64).reshape(5, -1) + X[2, :] = np.nan + y = [0, 0, 1, 1, 1] + p = Pipeline([ + ('imputer', Imputer(strategy='mean', missing_values='NaN')), + ('classifier', MockClassifier()), + ]) + GridSearchCV(p, {'classifier__foo_param': [1, 2, 3]}, cv=2).fit(X, y) + + +class FailingClassifier(BaseEstimator): + """Classifier that raises a ValueError on fit()""" + + FAILING_PARAMETER = 2 + + def __init__(self, parameter=None): + self.parameter = parameter + + def fit(self, X, y=None): + if self.parameter == FailingClassifier.FAILING_PARAMETER: + raise ValueError("Failing classifier failed as required") + + def predict(self, X): + return np.zeros(X.shape[0]) + + +def test_grid_search_failing_classifier(): + # GridSearchCV with on_error != 'raise' + # Ensures that a warning is raised and score reset where appropriate. + + X, y = make_classification(n_samples=20, n_features=10, random_state=0) + + clf = FailingClassifier() + + # refit=False because we only want to check that errors caused by fits + # to individual folds will be caught and warnings raised instead. 
If + # refit was done, then an exception would be raised on refit and not + # caught by grid_search (expected behavior), and this would cause an + # error in this test. + gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy', + refit=False, error_score=0.0) + + assert_warns(FitFailedWarning, gs.fit, X, y) + + # Ensure that grid scores were set to zero as required for those fits + # that are expected to fail. + assert all(np.all(this_point.cv_validation_scores == 0.0) + for this_point in gs.grid_scores_ + if this_point.parameters['parameter'] == + FailingClassifier.FAILING_PARAMETER) + + gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy', + refit=False, error_score=float('nan')) + assert_warns(FitFailedWarning, gs.fit, X, y) + assert all(np.all(np.isnan(this_point.cv_validation_scores)) + for this_point in gs.grid_scores_ + if this_point.parameters['parameter'] == + FailingClassifier.FAILING_PARAMETER) + + +def test_grid_search_failing_classifier_raise(): + # GridSearchCV with on_error == 'raise' raises the error + + X, y = make_classification(n_samples=20, n_features=10, random_state=0) + + clf = FailingClassifier() + + # refit=False because we want to test the behaviour of the grid search part + gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy', + refit=False, error_score='raise') + + # FailingClassifier issues a ValueError so this is what we look for. + assert_raises(ValueError, gs.fit, X, y) + + +def test_parameters_sampler_replacement(): + # raise error if n_iter too large + params = {'first': [0, 1], 'second': ['a', 'b', 'c']} + sampler = ParameterSampler(params, n_iter=7) + assert_raises(ValueError, list, sampler) + # degenerates to GridSearchCV if n_iter the same as grid_size + sampler = ParameterSampler(params, n_iter=6) + samples = list(sampler) + assert_equal(len(samples), 6) + for values in ParameterGrid(params): + assert_true(values in samples) + + # test sampling without replacement in a large grid + params = {'a': range(10), 'b': range(10), 'c': range(10)} + sampler = ParameterSampler(params, n_iter=99, random_state=42) + samples = list(sampler) + assert_equal(len(samples), 99) + hashable_samples = ["a%db%dc%d" % (p['a'], p['b'], p['c']) + for p in samples] + assert_equal(len(set(hashable_samples)), 99) + + # doesn't go into infinite loops + params_distribution = {'first': bernoulli(.5), 'second': ['a', 'b', 'c']} + sampler = ParameterSampler(params_distribution, n_iter=7) + samples = list(sampler) + assert_equal(len(samples), 7) diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py new file mode 100644 index 0000000000000..4689ead6007d0 --- /dev/null +++ b/sklearn/model_selection/tests/test_split.py @@ -0,0 +1,968 @@ +"""Test the split module""" +from __future__ import division +import warnings + +import numpy as np +from scipy.sparse import coo_matrix +from scipy import stats +from scipy.misc import comb +from itertools import combinations +from sklearn.utils.fixes import combinations_with_replacement + +from sklearn.utils.testing import assert_true +from sklearn.utils.testing import assert_false +from sklearn.utils.testing import assert_equal +from sklearn.utils.testing import assert_almost_equal +from sklearn.utils.testing import assert_raises +from sklearn.utils.testing import assert_raises_regexp +from sklearn.utils.testing import assert_greater +from sklearn.utils.testing import assert_greater_equal +from sklearn.utils.testing import assert_not_equal +from 
sklearn.utils.testing import assert_array_almost_equal +from sklearn.utils.testing import assert_array_equal +from sklearn.utils.testing import assert_warns_message +from sklearn.utils.testing import ignore_warnings +from sklearn.utils.validation import _num_samples +from sklearn.utils.mocking import MockDataFrame + +from sklearn.model_selection import cross_val_score +from sklearn.model_selection import KFold +from sklearn.model_selection import StratifiedKFold +from sklearn.model_selection import LabelKFold +from sklearn.model_selection import LeaveOneOut +from sklearn.model_selection import LeaveOneLabelOut +from sklearn.model_selection import LeavePOut +from sklearn.model_selection import LeavePLabelOut +from sklearn.model_selection import ShuffleSplit +from sklearn.model_selection import LabelShuffleSplit +from sklearn.model_selection import StratifiedShuffleSplit +from sklearn.model_selection import PredefinedSplit +from sklearn.model_selection import check_cv +from sklearn.model_selection import train_test_split +from sklearn.model_selection import GridSearchCV + +from sklearn.svm import LinearSVC + +from sklearn.model_selection._split import _safe_split +from sklearn.model_selection._split import _validate_shuffle_split +from sklearn.model_selection._split import _CVIterableWrapper +from sklearn.model_selection._split import _build_repr + +from sklearn.datasets import load_digits +from sklearn.datasets import load_iris +from sklearn.datasets import make_classification + +from sklearn.externals import six +from sklearn.externals.six.moves import zip + +from sklearn.svm import SVC + +X = np.ones(10) +y = np.arange(10) // 2 +P_sparse = coo_matrix(np.eye(5)) +iris = load_iris() +digits = load_digits() + + +class MockClassifier(object): + """Dummy classifier to test the cross-validation""" + + def __init__(self, a=0, allow_nd=False): + self.a = a + self.allow_nd = allow_nd + + def fit(self, X, Y=None, sample_weight=None, class_prior=None, + sparse_sample_weight=None, sparse_param=None, dummy_int=None, + dummy_str=None, dummy_obj=None, callback=None): + """The dummy arguments are to test that this fit function can + accept non-array arguments through cross-validation, such as: + - int + - str (this is actually array-like) + - object + - function + """ + self.dummy_int = dummy_int + self.dummy_str = dummy_str + self.dummy_obj = dummy_obj + if callback is not None: + callback(self) + + if self.allow_nd: + X = X.reshape(len(X), -1) + if X.ndim >= 3 and not self.allow_nd: + raise ValueError('X cannot be d') + if sample_weight is not None: + assert_true(sample_weight.shape[0] == X.shape[0], + 'MockClassifier extra fit_param sample_weight.shape[0]' + ' is {0}, should be {1}'.format(sample_weight.shape[0], + X.shape[0])) + if class_prior is not None: + assert_true(class_prior.shape[0] == len(np.unique(y)), + 'MockClassifier extra fit_param class_prior.shape[0]' + ' is {0}, should be {1}'.format(class_prior.shape[0], + len(np.unique(y)))) + if sparse_sample_weight is not None: + fmt = ('MockClassifier extra fit_param sparse_sample_weight' + '.shape[0] is {0}, should be {1}') + assert_true(sparse_sample_weight.shape[0] == X.shape[0], + fmt.format(sparse_sample_weight.shape[0], X.shape[0])) + if sparse_param is not None: + fmt = ('MockClassifier extra fit_param sparse_param.shape ' + 'is ({0}, {1}), should be ({2}, {3})') + assert_true(sparse_param.shape == P_sparse.shape, + fmt.format(sparse_param.shape[0], + sparse_param.shape[1], + P_sparse.shape[0], P_sparse.shape[1])) + return self + + def 
predict(self, T): + if self.allow_nd: + T = T.reshape(len(T), -1) + return T[:, 0] + + def score(self, X=None, Y=None): + return 1. / (1 + np.abs(self.a)) + + def get_params(self, deep=False): + return {'a': self.a, 'allow_nd': self.allow_nd} + + +@ignore_warnings +def test_cross_validator_with_default_indices(): + n_samples = 4 + n_unique_labels = 4 + n_folds = 2 + p = 2 + n_iter = 10 # (the default value) + + X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + X_1d = np.array([1, 2, 3, 4]) + y = np.array([1, 1, 2, 2]) + labels = np.array([1, 2, 3, 4]) + loo = LeaveOneOut() + lpo = LeavePOut(p) + kf = KFold(n_folds) + skf = StratifiedKFold(n_folds) + lolo = LeaveOneLabelOut() + lopo = LeavePLabelOut(p) + ss = ShuffleSplit(random_state=0) + ps = PredefinedSplit([1, 1, 2, 2]) # n_splits = np of unique folds = 2 + + n_splits = [n_samples, comb(n_samples, p), n_folds, n_folds, + n_unique_labels, comb(n_unique_labels, p), n_iter, 2] + + for i, cv in enumerate([loo, lpo, kf, skf, lolo, lopo, ss, ps]): + # Test if get_n_splits works correctly + assert_equal(n_splits[i], cv.get_n_splits(X, y, labels)) + + # Test if the cross-validator works as expected even if + # the data is 1d + np.testing.assert_equal(list(cv.split(X, y, labels)), + list(cv.split(X_1d, y, labels))) + # Test that train, test indices returned are integers + for train, test in cv.split(X, y, labels): + assert_equal(np.asarray(train).dtype.kind, 'i') + assert_equal(np.asarray(train).dtype.kind, 'i') + + +def check_valid_split(train, test, n_samples=None): + # Use python sets to get more informative assertion failure messages + train, test = set(train), set(test) + + # Train and test split should not overlap + assert_equal(train.intersection(test), set()) + + if n_samples is not None: + # Check that the union of train an test split cover all the indices + assert_equal(train.union(test), set(range(n_samples))) + + +def check_cv_coverage(cv, X, y, labels, expected_n_iter=None): + n_samples = _num_samples(X) + # Check that a all the samples appear at least once in a test fold + if expected_n_iter is not None: + assert_equal(cv.get_n_splits(X, y, labels), expected_n_iter) + else: + expected_n_iter = cv.get_n_splits(X, y, labels) + + collected_test_samples = set() + iterations = 0 + for train, test in cv.split(X, y, labels): + check_valid_split(train, test, n_samples=n_samples) + iterations += 1 + collected_test_samples.update(test) + + # Check that the accumulated test samples cover the whole dataset + assert_equal(iterations, expected_n_iter) + if n_samples is not None: + assert_equal(collected_test_samples, set(range(n_samples))) + + +def test_kfold_valueerrors(): + X1 = np.array([[1, 2], [3, 4], [5, 6]]) + X2 = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]) + # Check that errors are raised if there is not enough samples + assert_raises(ValueError, next, KFold(4).split(X1)) + + # Check that a warning is raised if the least populated class has too few + # members. 
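+    # (in y below, class 2 occurs once and classes 3 and -1 twice, so every
+    # class has fewer members than the 3 folds requested)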
+ y = np.array([3, 3, -1, -1, 2]) + + skf_3 = StratifiedKFold(3) + assert_warns_message(Warning, "The least populated class", + next, skf_3.split(X2, y)) + + # Check that despite the warning the folds are still computed even + # though all the classes are not necessarily represented at on each + # side of the split at each split + with warnings.catch_warnings(): + check_cv_coverage(skf_3, X2, y, labels=None, expected_n_iter=3) + + # Error when number of folds is <= 1 + assert_raises(ValueError, KFold, 0) + assert_raises(ValueError, KFold, 1) + assert_raises(ValueError, StratifiedKFold, 0) + assert_raises(ValueError, StratifiedKFold, 1) + + # When n_folds is not integer: + assert_raises(ValueError, KFold, 1.5) + assert_raises(ValueError, KFold, 2.0) + assert_raises(ValueError, StratifiedKFold, 1.5) + assert_raises(ValueError, StratifiedKFold, 2.0) + + # When shuffle is not a bool: + assert_raises(TypeError, KFold, n_folds=4, shuffle=None) + + +def test_kfold_indices(): + # Check all indices are returned in the test folds + X1 = np.ones(18) + kf = KFold(3) + check_cv_coverage(kf, X1, y=None, labels=None, expected_n_iter=3) + + # Check all indices are returned in the test folds even when equal-sized + # folds are not possible + X2 = np.ones(17) + kf = KFold(3) + check_cv_coverage(kf, X2, y=None, labels=None, expected_n_iter=3) + + # Check if get_n_splits returns the number of folds + assert_equal(5, KFold(5).get_n_splits(X2)) + + +def test_kfold_no_shuffle(): + # Manually check that KFold preserves the data ordering on toy datasets + X2 = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] + + splits = KFold(2).split(X2[:-1]) + train, test = next(splits) + assert_array_equal(test, [0, 1]) + assert_array_equal(train, [2, 3]) + + train, test = next(splits) + assert_array_equal(test, [2, 3]) + assert_array_equal(train, [0, 1]) + + splits = KFold(2).split(X2) + train, test = next(splits) + assert_array_equal(test, [0, 1, 2]) + assert_array_equal(train, [3, 4]) + + train, test = next(splits) + assert_array_equal(test, [3, 4]) + assert_array_equal(train, [0, 1, 2]) + + +def test_stratified_kfold_no_shuffle(): + # Manually check that StratifiedKFold preserves the data ordering as much + # as possible on toy datasets in order to avoid hiding sample dependencies + # when possible + X, y = np.ones(4), [1, 1, 0, 0] + splits = StratifiedKFold(2).split(X, y) + train, test = next(splits) + assert_array_equal(test, [0, 2]) + assert_array_equal(train, [1, 3]) + + train, test = next(splits) + assert_array_equal(test, [1, 3]) + assert_array_equal(train, [0, 2]) + + X, y = np.ones(7), [1, 1, 1, 0, 0, 0, 0] + splits = StratifiedKFold(2).split(X, y) + train, test = next(splits) + assert_array_equal(test, [0, 1, 3, 4]) + assert_array_equal(train, [2, 5, 6]) + + train, test = next(splits) + assert_array_equal(test, [2, 5, 6]) + assert_array_equal(train, [0, 1, 3, 4]) + + # Check if get_n_splits returns the number of folds + assert_equal(5, StratifiedKFold(5).get_n_splits(X, y)) + + +def test_stratified_kfold_ratios(): + # Check that stratified kfold preserves class ratios in individual splits + # Repeat with shuffling turned off and on + n_samples = 1000 + X = np.ones(n_samples) + y = np.array([4] * int(0.10 * n_samples) + + [0] * int(0.89 * n_samples) + + [1] * int(0.01 * n_samples)) + + for shuffle in (False, True): + for train, test in StratifiedKFold(5, shuffle=shuffle).split(X, y): + assert_almost_equal(np.sum(y[train] == 4) / len(train), 0.10, 2) + assert_almost_equal(np.sum(y[train] == 0) / len(train), 0.89, 2) + 
assert_almost_equal(np.sum(y[train] == 1) / len(train), 0.01, 2) + assert_almost_equal(np.sum(y[test] == 4) / len(test), 0.10, 2) + assert_almost_equal(np.sum(y[test] == 0) / len(test), 0.89, 2) + assert_almost_equal(np.sum(y[test] == 1) / len(test), 0.01, 2) + + +def test_kfold_balance(): + # Check that KFold returns folds with balanced sizes + for i in range(11, 17): + kf = KFold(5).split(X=np.ones(i)) + sizes = [] + for _, test in kf: + sizes.append(len(test)) + + assert_true((np.max(sizes) - np.min(sizes)) <= 1) + assert_equal(np.sum(sizes), i) + + +def test_stratifiedkfold_balance(): + # Check that KFold returns folds with balanced sizes (only when + # stratification is possible) + # Repeat with shuffling turned off and on + X = np.ones(17) + y = [0] * 3 + [1] * 14 + + for shuffle in (True, False): + cv = StratifiedKFold(3, shuffle=shuffle) + for i in range(11, 17): + skf = cv.split(X[:i], y[:i]) + sizes = [] + for _, test in skf: + sizes.append(len(test)) + + assert_true((np.max(sizes) - np.min(sizes)) <= 1) + assert_equal(np.sum(sizes), i) + + +def test_shuffle_kfold(): + # Check the indices are shuffled properly + kf = KFold(3) + kf2 = KFold(3, shuffle=True, random_state=0) + kf3 = KFold(3, shuffle=True, random_state=1) + + X = np.ones(300) + + all_folds = np.zeros(300) + for (tr1, te1), (tr2, te2), (tr3, te3) in zip( + kf.split(X), kf2.split(X), kf3.split(X)): + for tr_a, tr_b in combinations((tr1, tr2, tr3), 2): + # Assert that there is no complete overlap + assert_not_equal(len(np.intersect1d(tr_a, tr_b)), len(tr1)) + + # Set all test indices in successive iterations of kf2 to 1 + all_folds[te2] = 1 + + # Check that all indices are returned in the different test folds + assert_equal(sum(all_folds), 300) + + +def test_shuffle_kfold_stratifiedkfold_reproducibility(): + # Check that when the shuffle is True multiple split calls produce the + # same split when random_state is set + X = np.ones(15) # Divisible by 3 + y = [0] * 7 + [1] * 8 + X2 = np.ones(16) # Not divisible by 3 + y2 = [0] * 8 + [1] * 8 + + kf = KFold(3, shuffle=True, random_state=0) + skf = StratifiedKFold(3, shuffle=True, random_state=0) + + for cv in (kf, skf): + np.testing.assert_equal(list(cv.split(X, y)), list(cv.split(X, y))) + np.testing.assert_equal(list(cv.split(X2, y2)), list(cv.split(X2, y2))) + + kf = KFold(3, shuffle=True) + skf = StratifiedKFold(3, shuffle=True) + + for cv in (kf, skf): + for data in zip((X, X2), (y, y2)): + try: + np.testing.assert_equal(list(cv.split(*data)), + list(cv.split(*data))) + except AssertionError: + pass + else: + raise AssertionError("The splits for data, %s, are same even " + "when random state is not set" % data) + + +def test_shuffle_stratifiedkfold(): + # Check that shuffling is happening when requested, and for proper + # sample coverage + X_40 = np.ones(40) + y = [0] * 20 + [1] * 20 + kf0 = StratifiedKFold(5, shuffle=True, random_state=0) + kf1 = StratifiedKFold(5, shuffle=True, random_state=1) + for (_, test0), (_, test1) in zip(kf0.split(X_40, y), + kf1.split(X_40, y)): + assert_not_equal(set(test0), set(test1)) + check_cv_coverage(kf0, X_40, y, labels=None, expected_n_iter=5) + + +def test_kfold_can_detect_dependent_samples_on_digits(): # see #2372 + # The digits samples are dependent: they are apparently grouped by authors + # although we don't have any information on the groups segment locations + # for this data. 
We can highlight this fact be computing k-fold cross- + # validation with and without shuffling: we observe that the shuffling case + # wrongly makes the IID assumption and is therefore too optimistic: it + # estimates a much higher accuracy (around 0.96) than than the non + # shuffling variant (around 0.86). + + X, y = digits.data[:800], digits.target[:800] + model = SVC(C=10, gamma=0.005) + + cv = KFold(n_folds=5, shuffle=False) + mean_score = cross_val_score(model, X, y, cv=cv).mean() + assert_greater(0.88, mean_score) + assert_greater(mean_score, 0.85) + + # Shuffling the data artificially breaks the dependency and hides the + # overfitting of the model with regards to the writing style of the authors + # by yielding a seriously overestimated score: + + cv = KFold(5, shuffle=True, random_state=0) + mean_score = cross_val_score(model, X, y, cv=cv).mean() + assert_greater(mean_score, 0.95) + + cv = KFold(5, shuffle=True, random_state=1) + mean_score = cross_val_score(model, X, y, cv=cv).mean() + assert_greater(mean_score, 0.95) + + # Similarly, StratifiedKFold should try to shuffle the data as little + # as possible (while respecting the balanced class constraints) + # and thus be able to detect the dependency by not overestimating + # the CV score either. As the digits dataset is approximately balanced + # the estimated mean score is close to the score measured with + # non-shuffled KFold + + cv = StratifiedKFold(5) + mean_score = cross_val_score(model, X, y, cv=cv).mean() + assert_greater(0.88, mean_score) + assert_greater(mean_score, 0.85) + + +def test_shuffle_split(): + ss1 = ShuffleSplit(test_size=0.2, random_state=0).split(X) + ss2 = ShuffleSplit(test_size=2, random_state=0).split(X) + ss3 = ShuffleSplit(test_size=np.int32(2), random_state=0).split(X) + for typ in six.integer_types: + ss4 = ShuffleSplit(test_size=typ(2), random_state=0).split(X) + for t1, t2, t3, t4 in zip(ss1, ss2, ss3, ss4): + assert_array_equal(t1[0], t2[0]) + assert_array_equal(t2[0], t3[0]) + assert_array_equal(t3[0], t4[0]) + assert_array_equal(t1[1], t2[1]) + assert_array_equal(t2[1], t3[1]) + assert_array_equal(t3[1], t4[1]) + + +def test_stratified_shuffle_split_init(): + X = np.arange(7) + y = np.asarray([0, 1, 1, 1, 2, 2, 2]) + # Check that error is raised if there is a class with only one sample + assert_raises(ValueError, next, + StratifiedShuffleSplit(3, 0.2).split(X, y)) + + # Check that error is raised if the test set size is smaller than n_classes + assert_raises(ValueError, next, StratifiedShuffleSplit(3, 2).split(X, y)) + # Check that error is raised if the train set size is smaller than + # n_classes + assert_raises(ValueError, next, + StratifiedShuffleSplit(3, 3, 2).split(X, y)) + + X = np.arange(9) + y = np.asarray([0, 0, 0, 1, 1, 1, 2, 2, 2]) + # Check that errors are raised if there is not enough samples + assert_raises(ValueError, StratifiedShuffleSplit, 3, 0.5, 0.6) + assert_raises(ValueError, next, + StratifiedShuffleSplit(3, 8, 0.6).split(X, y)) + assert_raises(ValueError, next, + StratifiedShuffleSplit(3, 0.6, 8).split(X, y)) + + # Train size or test size too small + assert_raises(ValueError, next, + StratifiedShuffleSplit(train_size=2).split(X, y)) + assert_raises(ValueError, next, + StratifiedShuffleSplit(test_size=2).split(X, y)) + + +def test_stratified_shuffle_split_iter(): + ys = [np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]), + np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]), + np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]), + np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 
4, 4, 4, 4, 4]), + np.array([-1] * 800 + [1] * 50) + ] + + for y in ys: + sss = StratifiedShuffleSplit(6, test_size=0.33, + random_state=0).split(np.ones(len(y)), y) + for train, test in sss: + assert_array_equal(np.unique(y[train]), np.unique(y[test])) + # Checks if folds keep classes proportions + p_train = (np.bincount(np.unique(y[train], return_inverse=True)[1]) + / float(len(y[train]))) + p_test = (np.bincount(np.unique(y[test], return_inverse=True)[1]) + / float(len(y[test]))) + assert_array_almost_equal(p_train, p_test, 1) + assert_equal(y[train].size + y[test].size, y.size) + assert_array_equal(np.lib.arraysetops.intersect1d(train, test), []) + + +def test_stratified_shuffle_split_even(): + # Test the StratifiedShuffleSplit, indices are drawn with a + # equal chance + n_folds = 5 + n_iter = 1000 + + def assert_counts_are_ok(idx_counts, p): + # Here we test that the distribution of the counts + # per index is close enough to a binomial + threshold = 0.05 / n_splits + bf = stats.binom(n_splits, p) + for count in idx_counts: + p = bf.pmf(count) + assert_true(p > threshold, + "An index is not drawn with chance corresponding " + "to even draws") + + for n_samples in (6, 22): + labels = np.array((n_samples // 2) * [0, 1]) + splits = StratifiedShuffleSplit(n_iter=n_iter, + test_size=1. / n_folds, + random_state=0) + + train_counts = [0] * n_samples + test_counts = [0] * n_samples + n_splits = 0 + for train, test in splits.split(X=np.ones(n_samples), y=labels): + n_splits += 1 + for counter, ids in [(train_counts, train), (test_counts, test)]: + for id in ids: + counter[id] += 1 + assert_equal(n_splits, n_iter) + + n_train, n_test = _validate_shuffle_split(n_samples, + test_size=1./n_folds, + train_size=1.-(1./n_folds)) + + assert_equal(len(train), n_train) + assert_equal(len(test), n_test) + assert_equal(len(set(train).intersection(test)), 0) + + label_counts = np.unique(labels) + assert_equal(splits.test_size, 1.0 / n_folds) + assert_equal(n_train + n_test, len(labels)) + assert_equal(len(label_counts), 2) + ex_test_p = float(n_test) / n_samples + ex_train_p = float(n_train) / n_samples + + assert_counts_are_ok(train_counts, ex_train_p) + assert_counts_are_ok(test_counts, ex_test_p) + + +def test_predefinedsplit_with_kfold_split(): + # Check that PredefinedSplit can reproduce a split generated by Kfold. 
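+    # Record, for every sample, the index of the KFold test fold it lands
+    # in, then check that a PredefinedSplit built from that array yields
+    # exactly the same train/test indices.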
+ folds = -1 * np.ones(10) + kf_train = [] + kf_test = [] + for i, (train_ind, test_ind) in enumerate(KFold(5, shuffle=True).split(X)): + kf_train.append(train_ind) + kf_test.append(test_ind) + folds[test_ind] = i + ps_train = [] + ps_test = [] + ps = PredefinedSplit(folds) + # n_splits is simply the no of unique folds + assert_equal(len(np.unique(folds)), ps.get_n_splits()) + for train_ind, test_ind in ps.split(): + ps_train.append(train_ind) + ps_test.append(test_ind) + assert_array_equal(ps_train, kf_train) + assert_array_equal(ps_test, kf_test) + + +def test_label_shuffle_split(): + labels = [np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]), + np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]), + np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]), + np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4])] + + for l in labels: + X = y = np.ones(len(l)) + n_iter = 6 + test_size = 1./3 + slo = LabelShuffleSplit(n_iter, test_size=test_size, random_state=0) + + # Make sure the repr works + repr(slo) + + # Test that the length is correct + assert_equal(slo.get_n_splits(X, y, labels=l), n_iter) + + l_unique = np.unique(l) + + for train, test in slo.split(X, y, labels=l): + # First test: no train label is in the test set and vice versa + l_train_unique = np.unique(l[train]) + l_test_unique = np.unique(l[test]) + assert_false(np.any(np.in1d(l[train], l_test_unique))) + assert_false(np.any(np.in1d(l[test], l_train_unique))) + + # Second test: train and test add up to all the data + assert_equal(l[train].size + l[test].size, l.size) + + # Third test: train and test are disjoint + assert_array_equal(np.intersect1d(train, test), []) + + # Fourth test: + # unique train and test labels are correct, +- 1 for rounding error + assert_true(abs(len(l_test_unique) - + round(test_size * len(l_unique))) <= 1) + assert_true(abs(len(l_train_unique) - + round((1.0 - test_size) * len(l_unique))) <= 1) + + +def test_leave_label_out_changing_labels(): + # Check that LeaveOneLabelOut and LeavePLabelOut work normally if + # the labels variable is changed before calling split + labels = np.array([0, 1, 2, 1, 1, 2, 0, 0]) + X = np.ones(len(labels)) + labels_changing = np.array(labels, copy=True) + lolo = LeaveOneLabelOut().split(X, labels=labels) + lolo_changing = LeaveOneLabelOut().split(X, labels=labels) + lplo = LeavePLabelOut(n_labels=2).split(X, labels=labels) + lplo_changing = LeavePLabelOut(n_labels=2).split(X, labels=labels) + labels_changing[:] = 0 + for llo, llo_changing in [(lolo, lolo_changing), (lplo, lplo_changing)]: + for (train, test), (train_chan, test_chan) in zip(llo, llo_changing): + assert_array_equal(train, train_chan) + assert_array_equal(test, test_chan) + + # n_splits = no of 2 (p) label combinations of the unique labels = 3C2 = 3 + assert_equal(3, LeavePLabelOut(n_labels=2).get_n_splits(X, y, labels)) + # n_splits = no of unique labels (C(uniq_lbls, 1) = n_unique_labels) + assert_equal(3, LeaveOneLabelOut().get_n_splits(X, y, labels)) + + +def test_train_test_split_errors(): + assert_raises(ValueError, train_test_split) + assert_raises(ValueError, train_test_split, range(3), train_size=1.1) + assert_raises(ValueError, train_test_split, range(3), test_size=0.6, + train_size=0.6) + assert_raises(ValueError, train_test_split, range(3), + test_size=np.float32(0.6), train_size=np.float32(0.6)) + assert_raises(ValueError, train_test_split, range(3), + test_size="wrong_type") + assert_raises(ValueError, train_test_split, range(3), test_size=2, + train_size=4) + assert_raises(TypeError, 
train_test_split, range(3), + some_argument=1.1) + assert_raises(ValueError, train_test_split, range(3), range(42)) + + +def test_train_test_split(): + X = np.arange(100).reshape((10, 10)) + X_s = coo_matrix(X) + y = np.arange(10) + + # simple test + split = train_test_split(X, y, test_size=None, train_size=.5) + X_train, X_test, y_train, y_test = split + assert_equal(len(y_test), len(y_train)) + # test correspondence of X and y + assert_array_equal(X_train[:, 0], y_train * 10) + assert_array_equal(X_test[:, 0], y_test * 10) + + # don't convert lists to anything else by default + split = train_test_split(X, X_s, y.tolist()) + X_train, X_test, X_s_train, X_s_test, y_train, y_test = split + assert_true(isinstance(y_train, list)) + assert_true(isinstance(y_test, list)) + + # allow nd-arrays + X_4d = np.arange(10 * 5 * 3 * 2).reshape(10, 5, 3, 2) + y_3d = np.arange(10 * 7 * 11).reshape(10, 7, 11) + split = train_test_split(X_4d, y_3d) + assert_equal(split[0].shape, (7, 5, 3, 2)) + assert_equal(split[1].shape, (3, 5, 3, 2)) + assert_equal(split[2].shape, (7, 7, 11)) + assert_equal(split[3].shape, (3, 7, 11)) + + # test stratification option + y = np.array([1, 1, 1, 1, 2, 2, 2, 2]) + for test_size, exp_test_size in zip([2, 4, 0.25, 0.5, 0.75], + [2, 4, 2, 4, 6]): + train, test = train_test_split(y, test_size=test_size, + stratify=y, + random_state=0) + assert_equal(len(test), exp_test_size) + assert_equal(len(test) + len(train), len(y)) + # check the 1:1 ratio of ones and twos in the data is preserved + assert_equal(np.sum(train == 1), np.sum(train == 2)) + + +@ignore_warnings +def train_test_split_pandas(): + # check train_test_split doesn't destroy pandas dataframe + types = [MockDataFrame] + try: + from pandas import DataFrame + types.append(DataFrame) + except ImportError: + pass + for InputFeatureType in types: + # X dataframe + X_df = InputFeatureType(X) + X_train, X_test = train_test_split(X_df) + assert_true(isinstance(X_train, InputFeatureType)) + assert_true(isinstance(X_test, InputFeatureType)) + + +def train_test_split_mock_pandas(): + # X mock dataframe + X_df = MockDataFrame(X) + X_train, X_test = train_test_split(X_df) + assert_true(isinstance(X_train, MockDataFrame)) + assert_true(isinstance(X_test, MockDataFrame)) + X_train_arr, X_test_arr = train_test_split(X_df) + + +def test_shufflesplit_errors(): + # When the {test|train}_size is a float/invalid, error is raised at init + assert_raises(ValueError, ShuffleSplit, test_size=None, train_size=None) + assert_raises(ValueError, ShuffleSplit, test_size=2.0) + assert_raises(ValueError, ShuffleSplit, test_size=1.0) + assert_raises(ValueError, ShuffleSplit, test_size=0.1, train_size=0.95) + assert_raises(ValueError, ShuffleSplit, train_size=1j) + + # When the {test|train}_size is an int, validation is based on the input X + # and happens at split(...) 
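+    # The module-level X has 10 samples, so requesting 11 test samples,
+    # all 10 of them (leaving nothing to train on), or 8 test plus 3 train
+    # samples (8 + 3 > 10) must fail once split(...) is called.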
+ assert_raises(ValueError, next, ShuffleSplit(test_size=11).split(X)) + assert_raises(ValueError, next, ShuffleSplit(test_size=10).split(X)) + assert_raises(ValueError, next, ShuffleSplit(test_size=8, + train_size=3).split(X)) + + +def test_shufflesplit_reproducible(): + # Check that iterating twice on the ShuffleSplit gives the same + # sequence of train-test when the random_state is given + ss = ShuffleSplit(random_state=21) + assert_array_equal(list(a for a, b in ss.split(X)), + list(a for a, b in ss.split(X))) + + +def test_safe_split_with_precomputed_kernel(): + clf = SVC() + clfp = SVC(kernel="precomputed") + + X, y = iris.data, iris.target + K = np.dot(X, X.T) + + cv = ShuffleSplit(test_size=0.25, random_state=0) + tr, te = list(cv.split(X))[0] + + X_tr, y_tr = _safe_split(clf, X, y, tr) + K_tr, y_tr2 = _safe_split(clfp, K, y, tr) + assert_array_almost_equal(K_tr, np.dot(X_tr, X_tr.T)) + + X_te, y_te = _safe_split(clf, X, y, te, tr) + K_te, y_te2 = _safe_split(clfp, K, y, te, tr) + assert_array_almost_equal(K_te, np.dot(X_te, X_tr.T)) + + +def test_train_test_split_allow_nans(): + # Check that train_test_split allows input data with NaNs + X = np.arange(200, dtype=np.float64).reshape(10, -1) + X[2, :] = np.nan + y = np.repeat([0, 1], X.shape[0] / 2) + train_test_split(X, y, test_size=0.2, random_state=42) + + +def test_check_cv(): + X = np.ones(9) + cv = check_cv(3, classifier=False) + # Use numpy.testing.assert_equal which recursively compares + # lists of lists + np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X))) + + y_binary = np.array([0, 1, 0, 1, 0, 0, 1, 1, 1]) + cv = check_cv(3, y_binary, classifier=True) + np.testing.assert_equal(list(StratifiedKFold(3).split(X, y_binary)), + list(cv.split(X, y_binary))) + + y_multiclass = np.array([0, 1, 0, 1, 2, 1, 2, 0, 2]) + cv = check_cv(3, y_multiclass, classifier=True) + np.testing.assert_equal(list(StratifiedKFold(3).split(X, y_multiclass)), + list(cv.split(X, y_multiclass))) + + X = np.ones(5) + y_multilabel = np.array([[0, 0, 0, 0], [0, 1, 1, 0], [0, 0, 0, 1], + [1, 1, 0, 1], [0, 0, 1, 0]]) + cv = check_cv(3, y_multilabel, classifier=True) + np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X))) + + y_multioutput = np.array([[1, 2], [0, 3], [0, 0], [3, 1], [2, 0]]) + cv = check_cv(3, y_multioutput, classifier=True) + np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X))) + + # Check if the old style classes are wrapped to have a split method + X = np.ones(9) + y_multiclass = np.array([0, 1, 0, 1, 2, 1, 2, 0, 2]) + cv1 = check_cv(3, y_multiclass, classifier=True) + + with warnings.catch_warnings(record=True): + from sklearn.cross_validation import StratifiedKFold as OldSKF + + cv2 = check_cv(OldSKF(y_multiclass, n_folds=3)) + np.testing.assert_equal(list(cv1.split(X, y_multiclass)), + list(cv2.split())) + + assert_raises(ValueError, check_cv, cv="lolo") + + +def test_cv_iterable_wrapper(): + y_multiclass = np.array([0, 1, 0, 1, 2, 1, 2, 0, 2]) + + with warnings.catch_warnings(record=True): + from sklearn.cross_validation import StratifiedKFold as OldSKF + + cv = OldSKF(y_multiclass, n_folds=3) + wrapped_old_skf = _CVIterableWrapper(cv) + + # Check if split works correctly + np.testing.assert_equal(list(cv), list(wrapped_old_skf.split())) + + # Check if get_n_splits works correctly + assert_equal(len(cv), wrapped_old_skf.get_n_splits()) + + +def test_label_kfold(): + rng = np.random.RandomState(0) + + # Parameters of the test + n_labels = 15 + n_samples = 1000 + n_folds = 5 + + X = y = 
np.ones(n_samples) + + # Construct the test data + tolerance = 0.05 * n_samples # 5 percent error allowed + labels = rng.randint(0, n_labels, n_samples) + + ideal_n_labels_per_fold = n_samples // n_folds + + len(np.unique(labels)) + # Get the test fold indices from the test set indices of each fold + folds = np.zeros(n_samples) + lkf = LabelKFold(n_folds=n_folds) + for i, (_, test) in enumerate(lkf.split(X, y, labels)): + folds[test] = i + + # Check that folds have approximately the same size + assert_equal(len(folds), len(labels)) + for i in np.unique(folds): + assert_greater_equal(tolerance, + abs(sum(folds == i) - ideal_n_labels_per_fold)) + + # Check that each label appears only in 1 fold + for label in np.unique(labels): + assert_equal(len(np.unique(folds[labels == label])), 1) + + # Check that no label is on both sides of the split + labels = np.asarray(labels, dtype=object) + for train, test in lkf.split(X, y, labels): + assert_equal(len(np.intersect1d(labels[train], labels[test])), 0) + + # Construct the test data + labels = ['Albert', 'Jean', 'Bertrand', 'Michel', 'Jean', + 'Francis', 'Robert', 'Michel', 'Rachel', 'Lois', + 'Michelle', 'Bernard', 'Marion', 'Laura', 'Jean', + 'Rachel', 'Franck', 'John', 'Gael', 'Anna', 'Alix', + 'Robert', 'Marion', 'David', 'Tony', 'Abel', 'Becky', + 'Madmood', 'Cary', 'Mary', 'Alexandre', 'David', 'Francis', + 'Barack', 'Abdoul', 'Rasha', 'Xi', 'Silvia'] + + n_labels = len(np.unique(labels)) + n_samples = len(labels) + n_folds = 5 + tolerance = 0.05 * n_samples # 5 percent error allowed + ideal_n_labels_per_fold = n_samples // n_folds + + X = y = np.ones(n_samples) + + # Get the test fold indices from the test set indices of each fold + folds = np.zeros(n_samples) + for i, (_, test) in enumerate(lkf.split(X, y, labels)): + folds[test] = i + + # Check that folds have approximately the same size + assert_equal(len(folds), len(labels)) + for i in np.unique(folds): + assert_greater_equal(tolerance, + abs(sum(folds == i) - ideal_n_labels_per_fold)) + + # Check that each label appears only in 1 fold + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + for label in np.unique(labels): + assert_equal(len(np.unique(folds[labels == label])), 1) + + # Check that no label is on both sides of the split + labels = np.asarray(labels, dtype=object) + for train, test in lkf.split(X, y, labels): + assert_equal(len(np.intersect1d(labels[train], labels[test])), 0) + + # Should fail if there are more folds than labels + labels = np.array([1, 1, 1, 2, 2]) + X = y = np.ones(len(labels)) + assert_raises_regexp(ValueError, "Cannot have number of folds.*greater", + next, LabelKFold(n_folds=3).split(X, y, labels)) + + +def test_nested_cv(): + # Test if nested cross validation works with different combinations of cv + rng = np.random.RandomState(0) + + X, y = make_classification(n_samples=15, n_classes=2, random_state=0) + labels = rng.randint(0, 5, 15) + + cvs = [LeaveOneLabelOut(), LeaveOneOut(), LabelKFold(), StratifiedKFold(), + StratifiedShuffleSplit(n_iter=10, random_state=0)] + + for inner_cv, outer_cv in combinations_with_replacement(cvs, 2): + gs = GridSearchCV(LinearSVC(random_state=0), param_grid={'C': [1, 10]}, + cv=inner_cv) + cross_val_score(gs, X=X, y=y, labels=labels, cv=outer_cv, + fit_params={'labels': labels}) + + +def test_build_repr(): + class MockSplitter: + def __init__(self, a, b=0, c=None): + self.a = a + self.b = b + self.c = c + + def __repr__(self): + return _build_repr(self) + + 
assert_equal(repr(MockSplitter(5, 6)), "MockSplitter(a=5, b=6, c=None)") diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py new file mode 100644 index 0000000000000..138d78ea192d1 --- /dev/null +++ b/sklearn/model_selection/tests/test_validation.py @@ -0,0 +1,733 @@ +"""Test the validation module""" +from __future__ import division + +import sys +import warnings + +import numpy as np +from scipy.sparse import coo_matrix, csr_matrix + +from sklearn.utils.testing import assert_true +from sklearn.utils.testing import assert_false +from sklearn.utils.testing import assert_equal +from sklearn.utils.testing import assert_almost_equal +from sklearn.utils.testing import assert_raises +from sklearn.utils.testing import assert_raise_message +from sklearn.utils.testing import assert_greater +from sklearn.utils.testing import assert_less +from sklearn.utils.testing import assert_array_almost_equal +from sklearn.utils.testing import assert_array_equal +from sklearn.utils.testing import assert_warns +from sklearn.utils.mocking import CheckingClassifier, MockDataFrame + +from sklearn.model_selection import cross_val_score +from sklearn.model_selection import cross_val_predict +from sklearn.model_selection import permutation_test_score +from sklearn.model_selection import KFold +from sklearn.model_selection import StratifiedKFold +from sklearn.model_selection import LeaveOneOut +from sklearn.model_selection import LeaveOneLabelOut +from sklearn.model_selection import LeavePLabelOut +from sklearn.model_selection import LabelKFold +from sklearn.model_selection import LabelShuffleSplit +from sklearn.model_selection import learning_curve +from sklearn.model_selection import validation_curve +from sklearn.model_selection._validation import _check_is_permutation + +from sklearn.datasets import make_regression +from sklearn.datasets import load_boston +from sklearn.datasets import load_iris +from sklearn.metrics import explained_variance_score +from sklearn.metrics import make_scorer +from sklearn.metrics import precision_score + +from sklearn.linear_model import Ridge +from sklearn.linear_model import PassiveAggressiveClassifier +from sklearn.neighbors import KNeighborsClassifier +from sklearn.svm import SVC +from sklearn.cluster import KMeans + +from sklearn.preprocessing import Imputer +from sklearn.pipeline import Pipeline + +from sklearn.externals.six.moves import cStringIO as StringIO +from sklearn.base import BaseEstimator +from sklearn.multiclass import OneVsRestClassifier +from sklearn.datasets import make_classification +from sklearn.datasets import make_multilabel_classification + +from test_split import MockClassifier + + +class MockImprovingEstimator(BaseEstimator): + """Dummy classifier to test the learning curve""" + def __init__(self, n_max_train_sizes): + self.n_max_train_sizes = n_max_train_sizes + self.train_sizes = 0 + self.X_subset = None + + def fit(self, X_subset, y_subset=None): + self.X_subset = X_subset + self.train_sizes = X_subset.shape[0] + return self + + def predict(self, X): + raise NotImplementedError + + def score(self, X=None, Y=None): + # training score becomes worse (2 -> 1), test error better (0 -> 1) + if self._is_training_data(X): + return 2. 
- float(self.train_sizes) / self.n_max_train_sizes + else: + return float(self.train_sizes) / self.n_max_train_sizes + + def _is_training_data(self, X): + return X is self.X_subset + + +class MockIncrementalImprovingEstimator(MockImprovingEstimator): + """Dummy classifier that provides partial_fit""" + def __init__(self, n_max_train_sizes): + super(MockIncrementalImprovingEstimator, + self).__init__(n_max_train_sizes) + self.x = None + + def _is_training_data(self, X): + return self.x in X + + def partial_fit(self, X, y=None, **params): + self.train_sizes += X.shape[0] + self.x = X[0] + + +class MockEstimatorWithParameter(BaseEstimator): + """Dummy classifier to test the validation curve""" + def __init__(self, param=0.5): + self.X_subset = None + self.param = param + + def fit(self, X_subset, y_subset): + self.X_subset = X_subset + self.train_sizes = X_subset.shape[0] + return self + + def predict(self, X): + raise NotImplementedError + + def score(self, X=None, y=None): + return self.param if self._is_training_data(X) else 1 - self.param + + def _is_training_data(self, X): + return X is self.X_subset + + +# XXX: use 2D array, since 1D X is being detected as a single sample in +# check_consistent_length +X = np.ones((10, 2)) +X_sparse = coo_matrix(X) +y = np.arange(10) // 2 + + +def test_cross_val_score(): + clf = MockClassifier() + + for a in range(-10, 10): + clf.a = a + # Smoke test + scores = cross_val_score(clf, X, y) + assert_array_equal(scores, clf.score(X, y)) + + # test with multioutput y + scores = cross_val_score(clf, X_sparse, X) + assert_array_equal(scores, clf.score(X_sparse, X)) + + scores = cross_val_score(clf, X_sparse, y) + assert_array_equal(scores, clf.score(X_sparse, y)) + + # test with multioutput y + scores = cross_val_score(clf, X_sparse, X) + assert_array_equal(scores, clf.score(X_sparse, X)) + + # test with X and y as list + list_check = lambda x: isinstance(x, list) + clf = CheckingClassifier(check_X=list_check) + scores = cross_val_score(clf, X.tolist(), y.tolist()) + + clf = CheckingClassifier(check_y=list_check) + scores = cross_val_score(clf, X, y.tolist()) + + assert_raises(ValueError, cross_val_score, clf, X, y, + scoring="sklearn") + + # test with 3d X and + X_3d = X[:, :, np.newaxis] + clf = MockClassifier(allow_nd=True) + scores = cross_val_score(clf, X_3d, y) + + clf = MockClassifier(allow_nd=False) + assert_raises(ValueError, cross_val_score, clf, X_3d, y) + + +def test_cross_val_score_predict_labels(): + # Check if ValueError (when labels is None) propagates to cross_val_score + # and cross_val_predict + # And also check if labels is correctly passed to the cv object + X, y = make_classification(n_samples=20, n_classes=2, random_state=0) + + clf = SVC(kernel="linear") + + label_cvs = [LeaveOneLabelOut(), LeavePLabelOut(2), LabelKFold(), + LabelShuffleSplit()] + for cv in label_cvs: + assert_raise_message(ValueError, + "The labels parameter should not be None", + cross_val_score, estimator=clf, X=X, y=y, cv=cv) + assert_raise_message(ValueError, + "The labels parameter should not be None", + cross_val_predict, estimator=clf, X=X, y=y, cv=cv) + + +def test_cross_val_score_pandas(): + # check cross_val_score doesn't destroy pandas dataframe + types = [(MockDataFrame, MockDataFrame)] + try: + from pandas import Series, DataFrame + types.append((Series, DataFrame)) + except ImportError: + pass + for TargetType, InputFeatureType in types: + # X dataframe, y series + X_df, y_ser = InputFeatureType(X), TargetType(y) + check_df = lambda x: isinstance(x, 
InputFeatureType) + check_series = lambda x: isinstance(x, TargetType) + clf = CheckingClassifier(check_X=check_df, check_y=check_series) + cross_val_score(clf, X_df, y_ser) + + +def test_cross_val_score_mask(): + # test that cross_val_score works with boolean masks + svm = SVC(kernel="linear") + iris = load_iris() + X, y = iris.data, iris.target + kfold = KFold(5) + scores_indices = cross_val_score(svm, X, y, cv=kfold) + kfold = KFold(5) + cv_masks = [] + for train, test in kfold.split(X, y): + mask_train = np.zeros(len(y), dtype=np.bool) + mask_test = np.zeros(len(y), dtype=np.bool) + mask_train[train] = 1 + mask_test[test] = 1 + cv_masks.append((train, test)) + scores_masks = cross_val_score(svm, X, y, cv=cv_masks) + assert_array_equal(scores_indices, scores_masks) + + +def test_cross_val_score_precomputed(): + # test for svm with precomputed kernel + svm = SVC(kernel="precomputed") + iris = load_iris() + X, y = iris.data, iris.target + linear_kernel = np.dot(X, X.T) + score_precomputed = cross_val_score(svm, linear_kernel, y) + svm = SVC(kernel="linear") + score_linear = cross_val_score(svm, X, y) + assert_array_equal(score_precomputed, score_linear) + + # Error raised for non-square X + svm = SVC(kernel="precomputed") + assert_raises(ValueError, cross_val_score, svm, X, y) + + # test error is raised when the precomputed kernel is not array-like + # or sparse + assert_raises(ValueError, cross_val_score, svm, + linear_kernel.tolist(), y) + + +def test_cross_val_score_fit_params(): + clf = MockClassifier() + n_samples = X.shape[0] + n_classes = len(np.unique(y)) + + W_sparse = coo_matrix((np.array([1]), (np.array([1]), np.array([0]))), + shape=(10, 1)) + P_sparse = coo_matrix(np.eye(5)) + + DUMMY_INT = 42 + DUMMY_STR = '42' + DUMMY_OBJ = object() + + def assert_fit_params(clf): + # Function to test that the values are passed correctly to the + # classifier arguments for non-array type + + assert_equal(clf.dummy_int, DUMMY_INT) + assert_equal(clf.dummy_str, DUMMY_STR) + assert_equal(clf.dummy_obj, DUMMY_OBJ) + + fit_params = {'sample_weight': np.ones(n_samples), + 'class_prior': np.ones(n_classes) / n_classes, + 'sparse_sample_weight': W_sparse, + 'sparse_param': P_sparse, + 'dummy_int': DUMMY_INT, + 'dummy_str': DUMMY_STR, + 'dummy_obj': DUMMY_OBJ, + 'callback': assert_fit_params} + cross_val_score(clf, X, y, fit_params=fit_params) + + +def test_cross_val_score_score_func(): + clf = MockClassifier() + _score_func_args = [] + + def score_func(y_test, y_predict): + _score_func_args.append((y_test, y_predict)) + return 1.0 + + with warnings.catch_warnings(record=True): + scoring = make_scorer(score_func) + score = cross_val_score(clf, X, y, scoring=scoring) + assert_array_equal(score, [1.0, 1.0, 1.0]) + assert len(_score_func_args) == 3 + + +def test_cross_val_score_errors(): + class BrokenEstimator: + pass + + assert_raises(TypeError, cross_val_score, BrokenEstimator(), X) + + +def test_cross_val_score_with_score_func_classification(): + iris = load_iris() + clf = SVC(kernel='linear') + + # Default score (should be the accuracy score) + scores = cross_val_score(clf, iris.data, iris.target, cv=5) + assert_array_almost_equal(scores, [0.97, 1., 0.97, 0.97, 1.], 2) + + # Correct classification score (aka. 
zero / one score) - should be the + # same as the default estimator score + zo_scores = cross_val_score(clf, iris.data, iris.target, + scoring="accuracy", cv=5) + assert_array_almost_equal(zo_scores, [0.97, 1., 0.97, 0.97, 1.], 2) + + # F1 score (class are balanced so f1_score should be equal to zero/one + # score + f1_scores = cross_val_score(clf, iris.data, iris.target, + scoring="f1_weighted", cv=5) + assert_array_almost_equal(f1_scores, [0.97, 1., 0.97, 0.97, 1.], 2) + + +def test_cross_val_score_with_score_func_regression(): + X, y = make_regression(n_samples=30, n_features=20, n_informative=5, + random_state=0) + reg = Ridge() + + # Default score of the Ridge regression estimator + scores = cross_val_score(reg, X, y, cv=5) + assert_array_almost_equal(scores, [0.94, 0.97, 0.97, 0.99, 0.92], 2) + + # R2 score (aka. determination coefficient) - should be the + # same as the default estimator score + r2_scores = cross_val_score(reg, X, y, scoring="r2", cv=5) + assert_array_almost_equal(r2_scores, [0.94, 0.97, 0.97, 0.99, 0.92], 2) + + # Mean squared error; this is a loss function, so "scores" are negative + mse_scores = cross_val_score(reg, X, y, cv=5, scoring="mean_squared_error") + expected_mse = np.array([-763.07, -553.16, -274.38, -273.26, -1681.99]) + assert_array_almost_equal(mse_scores, expected_mse, 2) + + # Explained variance + scoring = make_scorer(explained_variance_score) + ev_scores = cross_val_score(reg, X, y, cv=5, scoring=scoring) + assert_array_almost_equal(ev_scores, [0.94, 0.97, 0.97, 0.99, 0.92], 2) + + +def test_permutation_score(): + iris = load_iris() + X = iris.data + X_sparse = coo_matrix(X) + y = iris.target + svm = SVC(kernel='linear') + cv = StratifiedKFold(2) + + score, scores, pvalue = permutation_test_score( + svm, X, y, n_permutations=30, cv=cv, scoring="accuracy") + assert_greater(score, 0.9) + assert_almost_equal(pvalue, 0.0, 1) + + score_label, _, pvalue_label = permutation_test_score( + svm, X, y, n_permutations=30, cv=cv, scoring="accuracy", + labels=np.ones(y.size), random_state=0) + assert_true(score_label == score) + assert_true(pvalue_label == pvalue) + + # check that we obtain the same results with a sparse representation + svm_sparse = SVC(kernel='linear') + cv_sparse = StratifiedKFold(2) + score_label, _, pvalue_label = permutation_test_score( + svm_sparse, X_sparse, y, n_permutations=30, cv=cv_sparse, + scoring="accuracy", labels=np.ones(y.size), random_state=0) + + assert_true(score_label == score) + assert_true(pvalue_label == pvalue) + + # test with custom scoring object + def custom_score(y_true, y_pred): + return (((y_true == y_pred).sum() - (y_true != y_pred).sum()) + / y_true.shape[0]) + + scorer = make_scorer(custom_score) + score, _, pvalue = permutation_test_score( + svm, X, y, n_permutations=100, scoring=scorer, cv=cv, random_state=0) + assert_almost_equal(score, .93, 2) + assert_almost_equal(pvalue, 0.01, 3) + + # set random y + y = np.mod(np.arange(len(y)), 3) + + score, scores, pvalue = permutation_test_score( + svm, X, y, n_permutations=30, cv=cv, scoring="accuracy") + + assert_less(score, 0.5) + assert_greater(pvalue, 0.2) + + +def test_permutation_test_score_allow_nans(): + # Check that permutation_test_score allows input data with NaNs + X = np.arange(200, dtype=np.float64).reshape(10, -1) + X[2, :] = np.nan + y = np.repeat([0, 1], X.shape[0] / 2) + p = Pipeline([ + ('imputer', Imputer(strategy='mean', missing_values='NaN')), + ('classifier', MockClassifier()), + ]) + permutation_test_score(p, X, y, cv=5) + + +def 
test_cross_val_score_allow_nans(): + # Check that cross_val_score allows input data with NaNs + X = np.arange(200, dtype=np.float64).reshape(10, -1) + X[2, :] = np.nan + y = np.repeat([0, 1], X.shape[0] / 2) + p = Pipeline([ + ('imputer', Imputer(strategy='mean', missing_values='NaN')), + ('classifier', MockClassifier()), + ]) + cross_val_score(p, X, y, cv=5) + + +def test_cross_val_score_multilabel(): + X = np.array([[-3, 4], [2, 4], [3, 3], [0, 2], [-3, 1], + [-2, 1], [0, 0], [-2, -1], [-1, -2], [1, -2]]) + y = np.array([[1, 1], [0, 1], [0, 1], [0, 1], [1, 1], + [0, 1], [1, 0], [1, 1], [1, 0], [0, 0]]) + clf = KNeighborsClassifier(n_neighbors=1) + scoring_micro = make_scorer(precision_score, average='micro') + scoring_macro = make_scorer(precision_score, average='macro') + scoring_samples = make_scorer(precision_score, average='samples') + score_micro = cross_val_score(clf, X, y, scoring=scoring_micro, cv=5) + score_macro = cross_val_score(clf, X, y, scoring=scoring_macro, cv=5) + score_samples = cross_val_score(clf, X, y, scoring=scoring_samples, cv=5) + assert_almost_equal(score_micro, [1, 1 / 2, 3 / 4, 1 / 2, 1 / 3]) + assert_almost_equal(score_macro, [1, 1 / 2, 3 / 4, 1 / 2, 1 / 4]) + assert_almost_equal(score_samples, [1, 1 / 2, 3 / 4, 1 / 2, 1 / 4]) + + +def test_cross_val_predict(): + boston = load_boston() + X, y = boston.data, boston.target + cv = KFold() + + est = Ridge() + + # Naive loop (should be same as cross_val_predict): + preds2 = np.zeros_like(y) + for train, test in cv.split(X, y): + est.fit(X[train], y[train]) + preds2[test] = est.predict(X[test]) + + preds = cross_val_predict(est, X, y, cv=cv) + assert_array_almost_equal(preds, preds2) + + preds = cross_val_predict(est, X, y) + assert_equal(len(preds), len(y)) + + cv = LeaveOneOut() + preds = cross_val_predict(est, X, y, cv=cv) + assert_equal(len(preds), len(y)) + + Xsp = X.copy() + Xsp *= (Xsp > np.median(Xsp)) + Xsp = coo_matrix(Xsp) + preds = cross_val_predict(est, Xsp, y) + assert_array_almost_equal(len(preds), len(y)) + + preds = cross_val_predict(KMeans(), X) + assert_equal(len(preds), len(y)) + + class BadCV(): + def split(self, X, y=None, labels=None): + for i in range(4): + yield np.array([0, 1, 2, 3]), np.array([4, 5, 6, 7, 8]) + + assert_raises(ValueError, cross_val_predict, est, X, y, cv=BadCV()) + + +def test_cross_val_predict_input_types(): + clf = Ridge() + # Smoke test + predictions = cross_val_predict(clf, X, y) + assert_equal(predictions.shape, (10,)) + + # test with multioutput y + predictions = cross_val_predict(clf, X_sparse, X) + assert_equal(predictions.shape, (10, 2)) + + predictions = cross_val_predict(clf, X_sparse, y) + assert_array_equal(predictions.shape, (10,)) + + # test with multioutput y + predictions = cross_val_predict(clf, X_sparse, X) + assert_array_equal(predictions.shape, (10, 2)) + + # test with X and y as list + list_check = lambda x: isinstance(x, list) + clf = CheckingClassifier(check_X=list_check) + predictions = cross_val_predict(clf, X.tolist(), y.tolist()) + + clf = CheckingClassifier(check_y=list_check) + predictions = cross_val_predict(clf, X, y.tolist()) + + # test with 3d X and + X_3d = X[:, :, np.newaxis] + check_3d = lambda x: x.ndim == 3 + clf = CheckingClassifier(check_X=check_3d) + predictions = cross_val_predict(clf, X_3d, y) + assert_array_equal(predictions.shape, (10,)) + + +def test_cross_val_predict_pandas(): + # check cross_val_score doesn't destroy pandas dataframe + types = [(MockDataFrame, MockDataFrame)] + try: + from pandas import Series, DataFrame + 
types.append((Series, DataFrame)) + except ImportError: + pass + for TargetType, InputFeatureType in types: + # X dataframe, y series + X_df, y_ser = InputFeatureType(X), TargetType(y) + check_df = lambda x: isinstance(x, InputFeatureType) + check_series = lambda x: isinstance(x, TargetType) + clf = CheckingClassifier(check_X=check_df, check_y=check_series) + cross_val_predict(clf, X_df, y_ser) + + +def test_cross_val_score_sparse_fit_params(): + iris = load_iris() + X, y = iris.data, iris.target + clf = MockClassifier() + fit_params = {'sparse_sample_weight': coo_matrix(np.eye(X.shape[0]))} + a = cross_val_score(clf, X, y, fit_params=fit_params) + assert_array_equal(a, np.ones(3)) + + +def test_learning_curve(): + X, y = make_classification(n_samples=30, n_features=1, n_informative=1, + n_redundant=0, n_classes=2, + n_clusters_per_class=1, random_state=0) + estimator = MockImprovingEstimator(20) + with warnings.catch_warnings(record=True) as w: + train_sizes, train_scores, test_scores = learning_curve( + estimator, X, y, cv=3, train_sizes=np.linspace(0.1, 1.0, 10)) + if len(w) > 0: + raise RuntimeError("Unexpected warning: %r" % w[0].message) + assert_equal(train_scores.shape, (10, 3)) + assert_equal(test_scores.shape, (10, 3)) + assert_array_equal(train_sizes, np.linspace(2, 20, 10)) + assert_array_almost_equal(train_scores.mean(axis=1), + np.linspace(1.9, 1.0, 10)) + assert_array_almost_equal(test_scores.mean(axis=1), + np.linspace(0.1, 1.0, 10)) + + +def test_learning_curve_unsupervised(): + X, _ = make_classification(n_samples=30, n_features=1, n_informative=1, + n_redundant=0, n_classes=2, + n_clusters_per_class=1, random_state=0) + estimator = MockImprovingEstimator(20) + train_sizes, train_scores, test_scores = learning_curve( + estimator, X, y=None, cv=3, train_sizes=np.linspace(0.1, 1.0, 10)) + assert_array_equal(train_sizes, np.linspace(2, 20, 10)) + assert_array_almost_equal(train_scores.mean(axis=1), + np.linspace(1.9, 1.0, 10)) + assert_array_almost_equal(test_scores.mean(axis=1), + np.linspace(0.1, 1.0, 10)) + + +def test_learning_curve_verbose(): + X, y = make_classification(n_samples=30, n_features=1, n_informative=1, + n_redundant=0, n_classes=2, + n_clusters_per_class=1, random_state=0) + estimator = MockImprovingEstimator(20) + + old_stdout = sys.stdout + sys.stdout = StringIO() + try: + train_sizes, train_scores, test_scores = \ + learning_curve(estimator, X, y, cv=3, verbose=1) + finally: + out = sys.stdout.getvalue() + sys.stdout.close() + sys.stdout = old_stdout + + assert("[learning_curve]" in out) + + +def test_learning_curve_incremental_learning_not_possible(): + X, y = make_classification(n_samples=2, n_features=1, n_informative=1, + n_redundant=0, n_classes=2, + n_clusters_per_class=1, random_state=0) + # The mockup does not have partial_fit() + estimator = MockImprovingEstimator(1) + assert_raises(ValueError, learning_curve, estimator, X, y, + exploit_incremental_learning=True) + + +def test_learning_curve_incremental_learning(): + X, y = make_classification(n_samples=30, n_features=1, n_informative=1, + n_redundant=0, n_classes=2, + n_clusters_per_class=1, random_state=0) + estimator = MockIncrementalImprovingEstimator(20) + train_sizes, train_scores, test_scores = learning_curve( + estimator, X, y, cv=3, exploit_incremental_learning=True, + train_sizes=np.linspace(0.1, 1.0, 10)) + assert_array_equal(train_sizes, np.linspace(2, 20, 10)) + assert_array_almost_equal(train_scores.mean(axis=1), + np.linspace(1.9, 1.0, 10)) + 
assert_array_almost_equal(test_scores.mean(axis=1),
+                              np.linspace(0.1, 1.0, 10))
+
+
+def test_learning_curve_incremental_learning_unsupervised():
+    X, _ = make_classification(n_samples=30, n_features=1, n_informative=1,
+                               n_redundant=0, n_classes=2,
+                               n_clusters_per_class=1, random_state=0)
+    estimator = MockIncrementalImprovingEstimator(20)
+    train_sizes, train_scores, test_scores = learning_curve(
+        estimator, X, y=None, cv=3, exploit_incremental_learning=True,
+        train_sizes=np.linspace(0.1, 1.0, 10))
+    assert_array_equal(train_sizes, np.linspace(2, 20, 10))
+    assert_array_almost_equal(train_scores.mean(axis=1),
+                              np.linspace(1.9, 1.0, 10))
+    assert_array_almost_equal(test_scores.mean(axis=1),
+                              np.linspace(0.1, 1.0, 10))
+
+
+def test_learning_curve_batch_and_incremental_learning_are_equal():
+    X, y = make_classification(n_samples=30, n_features=1, n_informative=1,
+                               n_redundant=0, n_classes=2,
+                               n_clusters_per_class=1, random_state=0)
+    train_sizes = np.linspace(0.2, 1.0, 5)
+    estimator = PassiveAggressiveClassifier(n_iter=1, shuffle=False)
+
+    train_sizes_inc, train_scores_inc, test_scores_inc = \
+        learning_curve(
+            estimator, X, y, train_sizes=train_sizes,
+            cv=3, exploit_incremental_learning=True)
+    train_sizes_batch, train_scores_batch, test_scores_batch = \
+        learning_curve(
+            estimator, X, y, cv=3, train_sizes=train_sizes,
+            exploit_incremental_learning=False)
+
+    assert_array_equal(train_sizes_inc, train_sizes_batch)
+    assert_array_almost_equal(train_scores_inc.mean(axis=1),
+                              train_scores_batch.mean(axis=1))
+    assert_array_almost_equal(test_scores_inc.mean(axis=1),
+                              test_scores_batch.mean(axis=1))
+
+
+def test_learning_curve_n_sample_range_out_of_bounds():
+    X, y = make_classification(n_samples=30, n_features=1, n_informative=1,
+                               n_redundant=0, n_classes=2,
+                               n_clusters_per_class=1, random_state=0)
+    estimator = MockImprovingEstimator(20)
+    assert_raises(ValueError, learning_curve, estimator, X, y, cv=3,
+                  train_sizes=[0, 1])
+    assert_raises(ValueError, learning_curve, estimator, X, y, cv=3,
+                  train_sizes=[0.0, 1.0])
+    assert_raises(ValueError, learning_curve, estimator, X, y, cv=3,
+                  train_sizes=[0.1, 1.1])
+    assert_raises(ValueError, learning_curve, estimator, X, y, cv=3,
+                  train_sizes=[0, 20])
+    assert_raises(ValueError, learning_curve, estimator, X, y, cv=3,
+                  train_sizes=[1, 21])
+
+
+def test_learning_curve_remove_duplicate_sample_sizes():
+    X, y = make_classification(n_samples=3, n_features=1, n_informative=1,
+                               n_redundant=0, n_classes=2,
+                               n_clusters_per_class=1, random_state=0)
+    estimator = MockImprovingEstimator(2)
+    train_sizes, _, _ = assert_warns(
+        RuntimeWarning, learning_curve, estimator, X, y, cv=3,
+        train_sizes=np.linspace(0.33, 1.0, 3))
+    assert_array_equal(train_sizes, [1, 2])
+
+
+def test_learning_curve_with_boolean_indices():
+    X, y = make_classification(n_samples=30, n_features=1, n_informative=1,
+                               n_redundant=0, n_classes=2,
+                               n_clusters_per_class=1, random_state=0)
+    estimator = MockImprovingEstimator(20)
+    cv = KFold(n_folds=3)
+    train_sizes, train_scores, test_scores = learning_curve(
+        estimator, X, y, cv=cv, train_sizes=np.linspace(0.1, 1.0, 10))
+    assert_array_equal(train_sizes, np.linspace(2, 20, 10))
+    assert_array_almost_equal(train_scores.mean(axis=1),
+                              np.linspace(1.9, 1.0, 10))
+    assert_array_almost_equal(test_scores.mean(axis=1),
+                              np.linspace(0.1, 1.0, 10))
+
+
+def test_validation_curve():
+    X, y = make_classification(n_samples=2, n_features=1, n_informative=1,
+                               n_redundant=0, n_classes=2,
+                               n_clusters_per_class=1, random_state=0)
+    param_range = np.linspace(0, 1, 10)
+    with warnings.catch_warnings(record=True) as w:
+        train_scores, test_scores = validation_curve(
+            MockEstimatorWithParameter(), X, y, param_name="param",
+            param_range=param_range, cv=2
+        )
+    if len(w) > 0:
+        raise RuntimeError("Unexpected warning: %r" % w[0].message)
+
+    assert_array_almost_equal(train_scores.mean(axis=1), param_range)
+    assert_array_almost_equal(test_scores.mean(axis=1), 1 - param_range)
+
+
+def test_check_is_permutation():
+    p = np.arange(100)
+    assert_true(_check_is_permutation(p, 100))
+    assert_false(_check_is_permutation(np.delete(p, 23), 100))
+
+    p[0] = 23
+    assert_false(_check_is_permutation(p, 100))
+
+
+def test_cross_val_predict_sparse_prediction():
+    # check that cross_val_predict gives same result for sparse and dense input
+    X, y = make_multilabel_classification(n_classes=2, n_labels=1,
+                                          allow_unlabeled=False,
+                                          return_indicator=True,
+                                          random_state=1)
+    X_sparse = csr_matrix(X)
+    y_sparse = csr_matrix(y)
+    classif = OneVsRestClassifier(SVC(kernel='linear'))
+    preds = cross_val_predict(classif, X, y, cv=10)
+    preds_sparse = cross_val_predict(classif, X_sparse, y_sparse, cv=10)
+    preds_sparse = preds_sparse.toarray()
+    assert_array_almost_equal(preds_sparse, preds)
diff --git a/sklearn/neighbors/tests/test_kde.py b/sklearn/neighbors/tests/test_kde.py
index 69bac3d7442b9..3078a3c05df39 100644
--- a/sklearn/neighbors/tests/test_kde.py
+++ b/sklearn/neighbors/tests/test_kde.py
@@ -5,7 +5,7 @@
 from sklearn.neighbors.ball_tree import kernel_norm
 from sklearn.pipeline import make_pipeline
 from sklearn.datasets import make_blobs
-from sklearn.grid_search import GridSearchCV
+from sklearn.model_selection import GridSearchCV
 from sklearn.preprocessing import StandardScaler
diff --git a/sklearn/neighbors/tests/test_neighbors.py b/sklearn/neighbors/tests/test_neighbors.py
index d7549dc30687e..35eb1edf6fd0a 100644
--- a/sklearn/neighbors/tests/test_neighbors.py
+++ b/sklearn/neighbors/tests/test_neighbors.py
@@ -6,7 +6,8 @@
                           dok_matrix, lil_matrix)
 from sklearn import metrics
-from sklearn.cross_validation import train_test_split, cross_val_score
+from sklearn.model_selection import train_test_split
+from sklearn.model_selection import cross_val_score
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_raises
diff --git a/sklearn/preprocessing/tests/test_imputation.py b/sklearn/preprocessing/tests/test_imputation.py
index 84df143db0d2c..028f5603d1bbd 100644
--- a/sklearn/preprocessing/tests/test_imputation.py
+++ b/sklearn/preprocessing/tests/test_imputation.py
@@ -10,7 +10,7 @@
 from sklearn.preprocessing.imputation import Imputer
 from sklearn.pipeline import Pipeline
-from sklearn import grid_search
+from sklearn.model_selection import GridSearchCV
 from sklearn import tree
 from sklearn.random_projection import sparse_random_matrix
@@ -269,7 +269,7 @@ def test_imputation_pipeline_grid_search():
     l = 100
     X = sparse_random_matrix(l, l, density=0.10)
     Y = sparse_random_matrix(l, 1, density=0.10).toarray()
-    gs = grid_search.GridSearchCV(pipeline, parameters)
+    gs = GridSearchCV(pipeline, parameters)
     gs.fit(X, Y)
diff --git a/sklearn/setup.py b/sklearn/setup.py
index 6634e20ba33a4..6179b01744def 100644
--- a/sklearn/setup.py
+++ b/sklearn/setup.py
@@ -49,6 +49,8 @@ def configuration(parent_package='', top_path=None):
     config.add_subpackage('metrics/tests')
     config.add_subpackage('metrics/cluster')
     config.add_subpackage('metrics/cluster/tests')
+    config.add_subpackage('model_selection')
+    config.add_subpackage('model_selection/tests')
 
     # add cython extension module for isotonic regression
     config.add_extension(
diff --git a/sklearn/svm/tests/test_svm.py b/sklearn/svm/tests/test_svm.py
index 49fc79a9945ca..ff61b73fd9ab6 100644
--- a/sklearn/svm/tests/test_svm.py
+++ b/sklearn/svm/tests/test_svm.py
@@ -13,7 +13,7 @@
 from nose.tools import assert_raises, assert_true, assert_equal, assert_false
 from sklearn import svm, linear_model, datasets, metrics, base
-from sklearn.cross_validation import train_test_split
+from sklearn.model_selection import train_test_split
 from sklearn.datasets import make_classification, make_blobs
 from sklearn.metrics import f1_score
 from sklearn.metrics.pairwise import rbf_kernel
diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py
index e2a06e58d9eea..59ee9620396ce 100644
--- a/sklearn/tests/test_base.py
+++ b/sklearn/tests/test_base.py
@@ -14,7 +14,7 @@
 from sklearn.base import BaseEstimator, clone, is_classifier
 from sklearn.svm import SVC
 from sklearn.pipeline import Pipeline
-from sklearn.grid_search import GridSearchCV
+from sklearn.model_selection import GridSearchCV
 from sklearn.utils import deprecated
diff --git a/sklearn/tests/test_cross_validation.py b/sklearn/tests/test_cross_validation.py
index 6d3db0d120300..44b6bb8341bb4 100644
--- a/sklearn/tests/test_cross_validation.py
+++ b/sklearn/tests/test_cross_validation.py
@@ -22,7 +22,10 @@
 from sklearn.utils.testing import ignore_warnings
 from sklearn.utils.mocking import CheckingClassifier, MockDataFrame
-from sklearn import cross_validation as cval
+with warnings.catch_warnings():
+    warnings.simplefilter('ignore')
+    from sklearn import cross_validation as cval
+
 from sklearn.datasets import make_regression
 from sklearn.datasets import load_boston
 from sklearn.datasets import load_digits
diff --git a/sklearn/tests/test_grid_search.py b/sklearn/tests/test_grid_search.py
index 3b3caf9d72d2f..13b9086310595 100644
--- a/sklearn/tests/test_grid_search.py
+++ b/sklearn/tests/test_grid_search.py
@@ -8,6 +8,7 @@
 from sklearn.externals.six.moves import xrange
 from itertools import chain, product
 import pickle
+import warnings
 import sys
 
 import numpy as np
@@ -33,9 +34,6 @@
 from sklearn.datasets import make_classification
 from sklearn.datasets import make_blobs
 from sklearn.datasets import make_multilabel_classification
-from sklearn.grid_search import GridSearchCV, RandomizedSearchCV
-from sklearn.grid_search import ParameterGrid, ParameterSampler
-from sklearn.exceptions import ChangedBehaviorWarning
 from sklearn.svm import LinearSVC, SVC
 from sklearn.tree import DecisionTreeRegressor
 from sklearn.tree import DecisionTreeClassifier
@@ -44,8 +42,16 @@
 from sklearn.metrics import f1_score
 from sklearn.metrics import make_scorer
 from sklearn.metrics import roc_auc_score
-from sklearn.cross_validation import KFold, StratifiedKFold
+
+from sklearn.exceptions import ChangedBehaviorWarning
 from sklearn.exceptions import FitFailedWarning
+
+with warnings.catch_warnings():
+    warnings.simplefilter('ignore')
+    from sklearn.grid_search import (GridSearchCV, RandomizedSearchCV,
+                                     ParameterGrid, ParameterSampler)
+    from sklearn.cross_validation import KFold, StratifiedKFold
+
 from sklearn.preprocessing import Imputer
 from sklearn.pipeline import Pipeline
diff --git a/sklearn/tests/test_learning_curve.py b/sklearn/tests/test_learning_curve.py
index afe8ce189d6d9..2ade20d43855c 100644
--- a/sklearn/tests/test_learning_curve.py
+++ b/sklearn/tests/test_learning_curve.py
@@ -7,14 +7,18 @@
 import numpy as np
 import warnings
 from sklearn.base import BaseEstimator
-from sklearn.learning_curve import learning_curve, validation_curve
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_warns
 from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_array_almost_equal
 from sklearn.datasets import make_classification
-from sklearn.cross_validation import KFold
+
+with warnings.catch_warnings():
+    warnings.simplefilter('ignore')
+    from sklearn.learning_curve import learning_curve, validation_curve
+    from sklearn.cross_validation import KFold
+
 from sklearn.linear_model import PassiveAggressiveClassifier
diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py
index adc56a4fa749d..4c6ace3d3aeb8 100644
--- a/sklearn/tests/test_metaestimators.py
+++ b/sklearn/tests/test_metaestimators.py
@@ -9,7 +9,7 @@
 from sklearn.datasets import make_classification
 from sklearn.utils.testing import assert_true, assert_false, assert_raises
 from sklearn.pipeline import Pipeline
-from sklearn.grid_search import GridSearchCV, RandomizedSearchCV
+from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
 from sklearn.feature_selection import RFE, RFECV
 from sklearn.ensemble import BaggingClassifier
diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py
index 16588fe4010c5..792490ef3e4dd 100644
--- a/sklearn/tests/test_multiclass.py
+++ b/sklearn/tests/test_multiclass.py
@@ -21,7 +21,7 @@
 from sklearn.linear_model import (LinearRegression, Lasso, ElasticNet, Ridge,
                                   Perceptron, LogisticRegression)
 from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
-from sklearn.grid_search import GridSearchCV
+from sklearn.model_selection import GridSearchCV
 from sklearn.pipeline import Pipeline
 from sklearn import svm
 from sklearn import datasets
diff --git a/sklearn/tests/test_naive_bayes.py b/sklearn/tests/test_naive_bayes.py
index e5c028d0cb3aa..c0086ff189116 100644
--- a/sklearn/tests/test_naive_bayes.py
+++ b/sklearn/tests/test_naive_bayes.py
@@ -4,7 +4,9 @@
 import scipy.sparse
 
 from sklearn.datasets import load_digits, load_iris
-from sklearn.cross_validation import cross_val_score, train_test_split
+
+from sklearn.model_selection import train_test_split
+from sklearn.model_selection import cross_val_score
 from sklearn.externals.six.moves import zip
 from sklearn.utils.testing import assert_almost_equal
diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py
index 028d2626bc23c..dac669bcee8f0 100644
--- a/sklearn/tree/tree.py
+++ b/sklearn/tree/tree.py
@@ -634,7 +634,7 @@ class DecisionTreeClassifier(BaseDecisionTree, ClassifierMixin):
     Examples
     --------
     >>> from sklearn.datasets import load_iris
-    >>> from sklearn.cross_validation import cross_val_score
+    >>> from sklearn.model_selection import cross_val_score
     >>> from sklearn.tree import DecisionTreeClassifier
     >>> clf = DecisionTreeClassifier(random_state=0)
     >>> iris = load_iris()
@@ -854,7 +854,7 @@ class DecisionTreeRegressor(BaseDecisionTree, RegressorMixin):
     Examples
     --------
     >>> from sklearn.datasets import load_boston
-    >>> from sklearn.cross_validation import cross_val_score
+    >>> from sklearn.model_selection import cross_val_score
    >>> from sklearn.tree import DecisionTreeRegressor
     >>> boston = load_boston()
     >>> regressor = DecisionTreeRegressor(random_state=0)
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 9b4e1ba782466..c15be3de46bbc 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -43,7 +43,7 @@
 from sklearn.decomposition import NMF, ProjectedGradientNMF
 from sklearn.exceptions import ConvergenceWarning
 from sklearn.exceptions import DataConversionWarning
-from sklearn.cross_validation import train_test_split
+from sklearn.model_selection import train_test_split
 from sklearn.utils import shuffle
 from sklearn.utils.fixes import signature
diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py
index 30bb277da04c6..e1a98c4ab0226 100644
--- a/sklearn/utils/testing.py
+++ b/sklearn/utils/testing.py
@@ -20,6 +20,7 @@
 import scipy as sp
 import scipy.io
 from functools import wraps
+from operator import itemgetter
 try:
     # Python 2
     from urllib2 import urlopen
@@ -604,7 +605,7 @@ def is_abstract(c):
     path = sklearn.__path__
     for importer, modname, ispkg in pkgutil.walk_packages(
             path=path, prefix='sklearn.', onerror=lambda x: None):
-        if ".tests." in modname:
+        if (".tests." in modname):
             continue
         module = __import__(modname, fromlist="dummy")
         classes = inspect.getmembers(module, inspect.isclass)
@@ -648,7 +649,9 @@ def is_abstract(c):
                              " %s." % repr(type_filter))
 
     # drop duplicates, sort for reproducibility
-    return sorted(set(estimators))
+    # itemgetter is used to ensure the sort does not extend to the 2nd item of
+    # the tuple
+    return sorted(set(estimators), key=itemgetter(0))
 
 
 def set_random_state(estimator, random_state=0):
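
The import changes above all repoint call sites at sklearn.model_selection, and the new tests build cross-validation objects without passing them the data (e.g. KFold(n_folds=3)). A minimal standalone sketch of that usage pattern — illustrative only, not part of the patch; the LinearSVC estimator and the synthetic data are placeholders, and the n_folds constructor argument simply mirrors the tests above:

    from sklearn.datasets import make_classification
    from sklearn.model_selection import GridSearchCV, KFold, cross_val_score
    from sklearn.svm import LinearSVC

    X, y = make_classification(n_samples=30, random_state=0)

    # The splitter is constructed without data; train/test index arrays are
    # produced later by split(X, y), once the data is available.
    cv = KFold(n_folds=3)
    for train, test in cv.split(X, y):
        LinearSVC().fit(X[train], y[train]).score(X[test], y[test])

    # The same splitter object can be passed as cv to the refactored helpers,
    # e.g. cross_val_score or GridSearchCV.
    scores = cross_val_score(LinearSVC(), X, y, cv=cv)
    search = GridSearchCV(LinearSVC(), {'C': [0.1, 1.0]}, cv=cv).fit(X, y)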
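
The learning-curve and validation-curve tests above follow the same shape; here is a short sketch of calling the refactored helpers directly. It is illustrative only: it assumes the sklearn.model_selection versions of learning_curve and validation_curve, and LinearSVC with synthetic data stands in for a real estimator and dataset.

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.model_selection import KFold, learning_curve, validation_curve
    from sklearn.svm import LinearSVC

    X, y = make_classification(n_samples=60, random_state=0)

    # learning_curve accepts a splitter instance for cv, as in the
    # boolean-indices test above.
    sizes, train_scores, test_scores = learning_curve(
        LinearSVC(), X, y, cv=KFold(n_folds=3),
        train_sizes=np.linspace(0.3, 1.0, 3))

    # validation_curve sweeps a single hyper-parameter over param_range.
    train_scores, test_scores = validation_curve(
        LinearSVC(), X, y, param_name="C",
        param_range=np.logspace(-3, 1, 5), cv=2)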