diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 5a28b31b33c2f..46a6c8c119f8e 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -26,6 +26,7 @@ from sklearn.linear_model.base import LinearClassifierMixin from sklearn.utils.estimator_checks import ( _yield_all_checks, + _clear_temp_memory, CROSS_DECOMPOSITION, check_parameters_default_constructible, check_class_weight_balanced_linear_classifier, @@ -73,6 +74,8 @@ def test_non_meta_estimators(): yield check, name, Estimator else: yield check, name, Estimator + _clear_temp_memory(warn=False) + def test_configure(): # Smoke test the 'configure' step of setup, this tests all the diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 4e396db755d39..8c37075c63afb 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -1,10 +1,21 @@ from __future__ import print_function +import os import types import warnings import sys import traceback import pickle +import tempfile +import shutil +# WindowsError only exist in Windows +from nose.tools import assert_false + +try: + WindowsError +except NameError: + WindowsError = None + from copy import deepcopy import numpy as np @@ -52,6 +63,7 @@ BOSTON = None +_TEMP_READONLY_MEMMAP_MEMORY = None CROSS_DECOMPOSITION = ['PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD'] MULTI_OUTPUT = ['CCA', 'DecisionTreeRegressor', 'ElasticNet', 'ExtraTreeRegressor', 'ExtraTreesRegressor', 'GaussianProcess', @@ -79,6 +91,7 @@ def _yield_non_meta_checks(name, Estimator): yield check_fit_score_takes_y yield check_dtype_object yield check_estimators_fit_returns_self + yield check_estimators_fit_returns_self_readonly # Check that all estimator yield informative messages when # trained on empty datasets @@ -118,6 +131,7 @@ def _yield_classifier_checks(name, Classifier): # basic consistency testing yield check_classifiers_train yield check_classifiers_regression_target + yield check_classifiers_train_readonly if (name not in ["MultinomialNB", "LabelPropagation", "LabelSpreading"] # TODO some complication with -1 label and name not in ["DecisionTreeClassifier", @@ -137,6 +151,7 @@ def _yield_regressor_checks(name, Regressor): # TODO: test with multiple responses # basic testing yield check_regressors_train + yield check_regressors_train_readonly yield check_regressor_data_not_an_array yield check_estimators_partial_fit_n_features yield check_regressors_no_decision_function @@ -160,6 +175,7 @@ def _yield_transformer_checks(name, Transformer): 'FunctionTransformer', 'Normalizer']: # basic tests yield check_transformer_general + yield check_transformer_general_readonly yield check_transformers_unfitted @@ -169,6 +185,7 @@ def _yield_clustering_checks(name, Clusterer): # this is clustering on the features # let's not test that here. yield check_clustering + yield check_clustering_readonly yield check_estimators_partial_fit_n_features @@ -216,16 +233,99 @@ def check_estimator(Estimator): check(name, Estimator) -def _boston_subset(n_samples=200): +def _boston_subset(n_samples=200, scale_y=False, convert_y_2d=False): + """Utility function used to cache boston subset into a global variable""" global BOSTON if BOSTON is None: boston = load_boston() X, y = boston.data, boston.target X, y = shuffle(X, y, random_state=0) - X, y = X[:n_samples], y[:n_samples] X = StandardScaler().fit_transform(X) BOSTON = X, y - return BOSTON + else: + X, y = BOSTON + X, y = X[:n_samples], y[:n_samples] + if scale_y: + y = StandardScaler().fit_transform(y) + if convert_y_2d: + y = y[:, np.newaxis] + return X, y + + +def _readonly_boston_subset(*args, **kwargs): + """Utility function used to return a r-o memmap, without recreating + a new memory map at each call""" + _init_temp_memory() + f = _TEMP_READONLY_MEMMAP_MEMORY.cache(_boston_subset) + return f(*args, **kwargs) + + +def _boston_subset_with_mode(*args, **kwargs): + """Factorisation function used in checks""" + readonly = kwargs.pop('readonly', None) + if readonly: + return _readonly_boston_subset(*args, **kwargs) + else: + return _boston_subset(*args, **kwargs) + + +def _make_blobs(*args, **kwargs): + """Utility function used to ensure that + we have only positive value for X""" + positive = kwargs.pop('positive', False) + scale = kwargs.pop('scale', False) + shuffle_flag = kwargs.pop('shuffle', False) + X, y = make_blobs(*args, **kwargs) + if scale: + X = StandardScaler().fit_transform(X) + if positive: + X -= X.min() + if shuffle_flag: + X, y = shuffle(X, y, random_state=7) + return X, y + + +def _readonly_make_blobs(*args, **kwargs): + """Utility function used to return a r-o memmap, without recreating + a new memory map at each call""" + _init_temp_memory() + f = _TEMP_READONLY_MEMMAP_MEMORY.cache(_make_blobs) + return f(*args, **kwargs) + + +def _make_blobs_with_mode(*args, **kwargs): + """Factorisation function used in checks""" + readonly = kwargs.pop('readonly', None) + if readonly: + return _readonly_make_blobs(*args, **kwargs) + else: + return _make_blobs(*args, **kwargs) + + +def _init_temp_memory(mmap_mode='r'): + """Utility function used to initialize a temp folder""" + global _TEMP_READONLY_MEMMAP_MEMORY + if _TEMP_READONLY_MEMMAP_MEMORY is None: + temp_folder = tempfile.mkdtemp(prefix='sklearn_checks_temp_') + _TEMP_READONLY_MEMMAP_MEMORY = Memory(cachedir=temp_folder, + mmap_mode=mmap_mode, verbose=0) + # Cannot use atexit as it is called everytime a test end, + # thus forcing us to regenerate cache at every check + # atexit.register(_clear_temp_memory(warn=True)) + + +def _clear_temp_memory(warn=False): + """Utility function used to delete the local temp folder""" + global _TEMP_READONLY_MEMMAP_MEMORY + if _TEMP_READONLY_MEMMAP_MEMORY is not None: + # Recovering temp_folder + cachedir = os.path.dirname(_TEMP_READONLY_MEMMAP_MEMORY.cachedir) + _TEMP_READONLY_MEMMAP_MEMORY = None + try: + shutil.rmtree(cachedir) + except WindowsError: + if warn: + warnings.warn("Could not delete temporary folder %s" % cachedir) def set_testing_parameters(estimator): @@ -286,7 +386,7 @@ def set_testing_parameters(estimator): class NotAnArray(object): - " An object that is convertable to an array" + """An object that is convertable to an array""" def __init__(self, data): self.data = data @@ -490,10 +590,23 @@ def check_transformer_general(name, Transformer): random_state=0, n_features=2, cluster_std=0.1) X = StandardScaler().fit_transform(X) X -= X.min() + + +def check_transformer_general(name, Transformer, readonly=False): + X, y = _make_blobs_with_mode(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]], + random_state=0, n_features=2, cluster_std=0.1, + readonly=readonly, positive=True, scale=True) + if readonly: + assert_false(X.flags['WRITEABLE']) _check_transformer(name, Transformer, X, y) _check_transformer(name, Transformer, X.tolist(), y.tolist()) +def check_transformer_general_readonly(name, Transformer): + check_transformer_general(name, Transformer, readonly=True) + + + def check_transformer_data_not_an_array(name, Transformer): X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]], random_state=0, n_features=2, cluster_std=0.1) @@ -508,7 +621,7 @@ def check_transformer_data_not_an_array(name, Transformer): def check_transformers_unfitted(name, Transformer): X, y = _boston_subset() - + with warnings.catch_warnings(record=True): transformer = Transformer() @@ -822,10 +935,10 @@ def check_estimators_partial_fit_n_features(name, Alg): assert_raises(ValueError, alg.partial_fit, X[:, :-1], y) -def check_clustering(name, Alg): - X, y = make_blobs(n_samples=50, random_state=1) - X, y = shuffle(X, y, random_state=7) - X = StandardScaler().fit_transform(X) +def check_clustering(name, Alg, readonly=False): + X, y = _make_blobs_with_mode(n_samples=50, random_state=1, + scale=True, + readonly=readonly, shuffle=True) n_samples, n_features = X.shape # catch deprecation and neighbors warnings with warnings.catch_warnings(record=True): @@ -839,6 +952,8 @@ def check_clustering(name, Alg): alg.set_params(max_iter=100) # fit + if readonly: + assert_false(X.flags['WRITEABLE']) alg.fit(X) # with lists alg.fit(X.tolist()) @@ -856,6 +971,10 @@ def check_clustering(name, Alg): assert_array_equal(pred, pred2) +def check_clustering_readonly(name, Alg): + check_clustering(name, Alg, readonly=True) + + def check_clusterer_compute_labels_predict(name, Clusterer): """Check that predict is invariant of compute_labels""" X, y = make_blobs(n_samples=20, random_state=0) @@ -907,13 +1026,20 @@ def check_classifiers_one_label(name, Classifier): @ignore_warnings # Warnings are raised by decision function -def check_classifiers_train(name, Classifier): - X_m, y_m = make_blobs(n_samples=300, random_state=0) - X_m, y_m = shuffle(X_m, y_m, random_state=7) - X_m = StandardScaler().fit_transform(X_m) - # generate binary problem from multi-class one - y_b = y_m[y_m != 2] - X_b = X_m[y_m != 2] +def check_classifiers_train(name, Classifier, readonly=False): + if name in ['BernoulliNB', 'MultinomialNB']: + positive = True + else: + positive = False + X_m, y_m = _make_blobs_with_mode(n_samples=300, + random_state=0, shuffle=True, + readonly=readonly, scale=True, + positive=positive) + # generate binary problem + X_b, y_b = _make_blobs_with_mode(n_samples=300, + random_state=0, shuffle=True, + readonly=readonly, scale=True, + positive=positive, centers=2) for (X, y) in [(X_m, y_m), (X_b, y_b)]: # catch deprecation warnings classes = np.unique(y) @@ -921,14 +1047,15 @@ def check_classifiers_train(name, Classifier): n_samples, n_features = X.shape with warnings.catch_warnings(record=True): classifier = Classifier() - if name in ['BernoulliNB', 'MultinomialNB']: - X -= X.min() set_testing_parameters(classifier) set_random_state(classifier) # raises error on malformed input for fit assert_raises(ValueError, classifier.fit, X, y[:-1]) # fit + if readonly: + assert_false(X.flags['WRITEABLE']) + assert_false(y.flags['WRITEABLE']) classifier.fit(X, y) # with lists classifier.fit(X.tolist(), y.tolist()) @@ -977,22 +1104,33 @@ def check_classifiers_train(name, Classifier): assert_raises(ValueError, classifier.predict_proba, X.T) -def check_estimators_fit_returns_self(name, Estimator): +def check_classifiers_train_readonly(name, Classifier): + check_classifiers_train(name, Classifier, readonly=True) + + +def check_estimators_fit_returns_self(name, Estimator, readonly=False): """Check if self is returned when calling fit""" - X, y = make_blobs(random_state=0, n_samples=9, n_features=4) + X, y = _make_blobs_with_mode(random_state=0, n_samples=9, n_features=4, + readonly=readonly, positive=True) y = multioutput_estimator_convert_y_2d(name, y) - # some want non-negative input - X -= X.min() estimator = Estimator() set_testing_parameters(estimator) set_random_state(estimator) + if readonly: + assert_false(X.flags['WRITEABLE']) + assert_false(y.flags['WRITEABLE']) assert_true(estimator.fit(X, y) is estimator) @ignore_warnings +def check_estimators_fit_returns_self_readonly(name, Estimator): + """Check if Estimator.fit does not fail on read only mem-mapped data""" + check_estimators_fit_returns_self(name, Estimator, readonly=True) + + def check_estimators_unfitted(name, Estimator): """Check that predict raises an exception in an unfitted estimator. @@ -1126,11 +1264,19 @@ def check_regressors_int(name, Regressor): assert_array_almost_equal(pred1, pred2, 2, name) -def check_regressors_train(name, Regressor): - X, y = _boston_subset() - y = StandardScaler().fit_transform(y.reshape(-1, 1)) # X is already scaled - y = y.ravel() - y = multioutput_estimator_convert_y_2d(name, y) +def check_regressors_train_readonly(name, Regressors): + check_regressors_train(name, Regressors, readonly=True) + + +def check_regressors_train(name, Regressor, readonly=False): + # Reproduce multioutput_convert_y_2d for read only boston subset + if name in (['MultiTaskElasticNetCV', 'MultiTaskLassoCV', + 'MultiTaskLasso', 'MultiTaskElasticNet']): + convert_y_2d = True + else: + convert_y_2d = False + X, y = _boston_subset_with_mode(readonly=readonly, scale_y=True, + convert_y_2d=convert_y_2d) rnd = np.random.RandomState(0) # catch deprecation warnings with warnings.catch_warnings(record=True): @@ -1150,7 +1296,12 @@ def check_regressors_train(name, Regressor): y_ = y_.T else: y_ = y + if readonly: + assert_false(X.flags['WRITEABLE']) + assert_false(y_.flags['WRITEABLE']) set_random_state(regressor) + if readonly: + assert_false(X.flags['WRITEABLE']) regressor.fit(X, y_) regressor.fit(X.tolist(), y_.tolist()) y_pred = regressor.predict(X) @@ -1163,6 +1314,32 @@ def check_regressors_train(name, Regressor): assert_greater(regressor.score(X, y_), 0.5) +def check_regressors_pickle(name, Regressor): + X, y = _boston_subset(scale_y=True) + y = multioutput_estimator_convert_y_2d(name, y) + rnd = np.random.RandomState(0) + # catch deprecation warnings + with warnings.catch_warnings(record=True): + regressor = Regressor() + set_testing_parameters(regressor) + if not hasattr(regressor, 'alphas') and hasattr(regressor, 'alpha'): + # linear regressors need to set alpha, but not generalized CV ones + regressor.alpha = 0.01 + + if name in CROSS_DECOMPOSITION: + y_ = np.vstack([y, 2 * y + rnd.randint(2, size=len(y))]) + y_ = y_.T + else: + y_ = y + regressor.fit(X, y_) + y_pred = regressor.predict(X) + # store old predictions + pickled_regressor = pickle.dumps(regressor) + unpickled_regressor = pickle.loads(pickled_regressor) + pickled_y_pred = unpickled_regressor.predict(X) + assert_array_almost_equal(pickled_y_pred, y_pred) + + @ignore_warnings def check_regressors_no_decision_function(name, Regressor): # checks whether regressors have decision_function or predict_proba diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index d577864fb709a..85479008c503b 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -10,7 +10,7 @@ import scipy.sparse as sp from nose.tools import assert_raises, assert_true, assert_false, assert_equal -from sklearn.utils.testing import assert_raises_regexp +from sklearn.utils.testing import assert_raises_regexp, TempMemmap from sklearn.utils.testing import assert_no_warnings from sklearn.utils.testing import assert_warns_message from sklearn.utils.testing import assert_warns @@ -442,6 +442,17 @@ def test_check_is_fitted(): assert_equal(None, check_is_fitted(ard, "coef_")) assert_equal(None, check_is_fitted(svr, "support_")) +def test_check_array_memmap(): + X = np.ones((4, 4)) + # Let memmap passed + with TempMemmap(X, mmap_mode='r') as X: + Z = check_array(X, copy=False) + assert_true(np.may_share_memory(X, Z)) + assert_false(Z.flags['WRITEABLE']) + Z = check_array(X, copy=True) + assert_false(np.may_share_memory(X, Z)) + assert_true(Z.flags['WRITEABLE']) + def test_check_consistent_length(): check_consistent_length([1], [2], [3], [4], [5]) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index e62f3b4ba6d47..4fa9bfdbc471c 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -344,6 +344,8 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None, if isinstance(accept_sparse, str): accept_sparse = [accept_sparse] + # store reference to original array to check if copy is needed when function returns + array_orig = array # store whether originally we wanted numeric dtype dtype_numeric = dtype == "numeric" @@ -381,7 +383,10 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None, array = _ensure_sparse_format(array, accept_sparse, dtype, copy, force_all_finite) else: - array = np.array(array, dtype=dtype, order=order, copy=copy) + # Do not physically copy memory map : if type(array) == np.memmap: + # type(array) == np.ndarray + # array.base is array_orig + array = np.asarray(array, dtype=dtype, order=order) if ensure_2d: if array.ndim == 1: @@ -396,10 +401,8 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None, "X.reshape(1, -1) if it contains a single sample.", DeprecationWarning) array = np.atleast_2d(array) - # To ensure that array flags are maintained - array = np.array(array, dtype=dtype, order=order, copy=copy) - # make sure we acually converted to numeric: + # make sure we actually converted to numeric: if dtype_numeric and array.dtype.kind == "O": array = array.astype(np.float64) if not allow_nd and array.ndim >= 3: @@ -429,6 +432,10 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None, msg = ("Data with input dtype %s was converted to %s%s." % (dtype_orig, array.dtype, context)) warnings.warn(msg, DataConversionWarning_) + + if copy and np.may_share_memory(array,array_orig): + array = np.array(array, dtype=dtype, order=order) + return array