diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst
index 482fa41f8f422..29648d8a5ed93 100644
--- a/doc/whats_new/v0.20.rst
+++ b/doc/whats_new/v0.20.rst
@@ -224,6 +224,12 @@ Linear, kernelized and related models
   underlying implementation is not random.
   :issue:`9497` by :user:`Albert Thomas `.
 
+Utils
+
+- Avoid copying the data in :func:`utils.check_array` when the input data is a
+  memmap (and ``copy=False``). :issue:`10663` by :user:`Arthur Mensch
+  ` and :user:`Loïc Estève `.
+
 Miscellaneous
 
 - Add ``filename`` attribute to datasets that have a CSV file.
@@ -541,3 +547,7 @@ Changes to estimator checks
 
 - Add invariance tests for clustering metrics. :issue:`8102` by :user:`Ankita
   Sinha ` and :user:`Guillaume Lemaitre `.
+
+- Add tests in :func:`estimator_checks.check_estimator` to check that an
+  estimator can handle read-only memmap input data. :issue:`10663` by
+  :user:`Arthur Mensch ` and :user:`Loïc Estève `.
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 7de1175618f4c..29a5b83f1bdea 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -6,10 +6,12 @@
 import traceback
 import pickle
 from copy import deepcopy
+import struct
+from functools import partial
+
 import numpy as np
 from scipy import sparse
 from scipy.stats import rankdata
-import struct
 
 from sklearn.externals.six.moves import zip
 from sklearn.externals.joblib import hash, Memory
@@ -33,6 +35,7 @@
 from sklearn.utils.testing import SkipTest
 from sklearn.utils.testing import ignore_warnings
 from sklearn.utils.testing import assert_dict_equal
+from sklearn.utils.testing import create_memmap_backed_data
 
 from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
 
@@ -84,6 +87,7 @@ def _yield_non_meta_checks(name, estimator):
     yield check_sample_weights_pandas_series
     yield check_sample_weights_list
     yield check_estimators_fit_returns_self
+    yield partial(check_estimators_fit_returns_self, readonly_memmap=True)
     yield check_complex_data
 
     # Check that all estimator yield informative messages when
@@ -123,6 +127,7 @@ def _yield_classifier_checks(name, classifier):
     yield check_estimators_partial_fit_n_features
     # basic consistency testing
     yield check_classifiers_train
+    yield partial(check_classifiers_train, readonly_memmap=True)
     yield check_classifiers_regression_target
     if (name not in ["MultinomialNB", "ComplementNB", "LabelPropagation",
                      "LabelSpreading"] and
@@ -171,6 +176,7 @@ def _yield_regressor_checks(name, regressor):
     # TODO: test with multiple responses
     # basic testing
     yield check_regressors_train
+    yield partial(check_regressors_train, readonly_memmap=True)
     yield check_regressor_data_not_an_array
     yield check_estimators_partial_fit_n_features
     yield check_regressors_no_decision_function
@@ -196,6 +202,7 @@ def _yield_transformer_checks(name, transformer):
                 'FunctionTransformer', 'Normalizer']:
         # basic tests
         yield check_transformer_general
+        yield partial(check_transformer_general, readonly_memmap=True)
         yield check_transformers_unfitted
     # Dependent on external solvers and hence accessing the iter
     # param is non-trivial.
@@ -211,6 +218,7 @@ def _yield_clustering_checks(name, clusterer):
         # this is clustering on the features
         # let's not test that here.
         yield check_clustering
+        yield partial(check_clustering, readonly_memmap=True)
         yield check_estimators_partial_fit_n_features
     yield check_non_transformer_estimators_n_iter
@@ -223,6 +231,7 @@ def _yield_outliers_checks(name, estimator):
     # checks for estimators that can be used on a test set
     if hasattr(estimator, 'predict'):
         yield check_outliers_train
+        yield partial(check_outliers_train, readonly_memmap=True)
     # test outlier detectors can handle non-array data
     yield check_classifier_data_not_an_array
     # test if NotFittedError is raised
@@ -799,7 +808,7 @@ def check_fit1d(name, estimator_orig):
 
 
 @ignore_warnings(category=(DeprecationWarning, FutureWarning))
-def check_transformer_general(name, transformer):
+def check_transformer_general(name, transformer, readonly_memmap=False):
     X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
                       random_state=0, n_features=2, cluster_std=0.1)
     X = StandardScaler().fit_transform(X)
@@ -807,6 +816,10 @@ def check_transformer_general(name, transformer):
     if name == 'PowerTransformer':
         # Box-Cox requires positive, non-zero data
         X += 1
+
+    if readonly_memmap:
+        X, y = create_memmap_backed_data([X, y])
+
     _check_transformer(name, transformer, X, y)
     _check_transformer(name, transformer, X.tolist(), y.tolist())
@@ -1165,11 +1178,17 @@ def check_estimators_partial_fit_n_features(name, estimator_orig):
 
 
 @ignore_warnings(category=(DeprecationWarning, FutureWarning))
-def check_clustering(name, clusterer_orig):
+def check_clustering(name, clusterer_orig, readonly_memmap=False):
     clusterer = clone(clusterer_orig)
     X, y = make_blobs(n_samples=50, random_state=1)
     X, y = shuffle(X, y, random_state=7)
     X = StandardScaler().fit_transform(X)
+    rng = np.random.RandomState(7)
+    X_noise = np.concatenate([X, rng.uniform(low=-3, high=3, size=(5, 2))])
+
+    if readonly_memmap:
+        X, y, X_noise = create_memmap_backed_data([X, y, X_noise])
+
     n_samples, n_features = X.shape
     # catch deprecation and neighbors warnings
     if hasattr(clusterer, "n_clusters"):
@@ -1201,8 +1220,6 @@ def check_clustering(name, clusterer_orig):
     assert_in(pred2.dtype, [np.dtype('int32'), np.dtype('int64')])
 
     # Add noise to X to test the possible values of the labels
-    rng = np.random.RandomState(7)
-    X_noise = np.concatenate([X, rng.uniform(low=-3, high=3, size=(5, 2))])
     labels = clusterer.fit_predict(X_noise)
 
     # There should be at least one sample in every cluster. Equivalently
@@ -1273,20 +1290,26 @@ def check_classifiers_one_label(name, classifier_orig):
 
 
 @ignore_warnings  # Warnings are raised by decision function
-def check_classifiers_train(name, classifier_orig):
+def check_classifiers_train(name, classifier_orig, readonly_memmap=False):
     X_m, y_m = make_blobs(n_samples=300, random_state=0)
     X_m, y_m = shuffle(X_m, y_m, random_state=7)
     X_m = StandardScaler().fit_transform(X_m)
     # generate binary problem from multi-class one
     y_b = y_m[y_m != 2]
     X_b = X_m[y_m != 2]
+
+    if name in ['BernoulliNB', 'MultinomialNB', 'ComplementNB']:
+        X_m -= X_m.min()
+        X_b -= X_b.min()
+
+    if readonly_memmap:
+        X_m, y_m, X_b, y_b = create_memmap_backed_data([X_m, y_m, X_b, y_b])
+
     for (X, y) in [(X_m, y_m), (X_b, y_b)]:
         classes = np.unique(y)
         n_classes = len(classes)
         n_samples, n_features = X.shape
         classifier = clone(classifier_orig)
-        if name in ['BernoulliNB', 'MultinomialNB', 'ComplementNB']:
-            X -= X.min()
         X = pairwise_estimator_convert_X(X, classifier_orig)
         set_random_state(classifier)
         # raises error on malformed input for fit
@@ -1382,9 +1405,13 @@ def check_classifiers_train(name, classifier_orig):
             assert_array_equal(np.argsort(y_log_prob), np.argsort(y_prob))
 
 
-def check_outliers_train(name, estimator_orig):
+def check_outliers_train(name, estimator_orig, readonly_memmap=True):
     X, _ = make_blobs(n_samples=300, random_state=0)
     X = shuffle(X, random_state=7)
+
+    if readonly_memmap:
+        X = create_memmap_backed_data(X)
+
     n_samples, n_features = X.shape
     estimator = clone(estimator_orig)
     set_random_state(estimator)
@@ -1444,7 +1471,8 @@ def check_outliers_train(name, estimator_orig):
 
 
 @ignore_warnings(category=(DeprecationWarning, FutureWarning))
-def check_estimators_fit_returns_self(name, estimator_orig):
+def check_estimators_fit_returns_self(name, estimator_orig,
+                                      readonly_memmap=False):
     """Check if self is returned when calling fit"""
     X, y = make_blobs(random_state=0, n_samples=9, n_features=4)
     # some want non-negative input
@@ -1457,8 +1485,10 @@ def check_estimators_fit_returns_self(name, estimator_orig):
     estimator = clone(estimator_orig)
     y = multioutput_estimator_convert_y_2d(estimator, y)
 
-    set_random_state(estimator)
+    if readonly_memmap:
+        X, y = create_memmap_backed_data([X, y])
 
+    set_random_state(estimator)
     assert_true(estimator.fit(X, y) is estimator)
@@ -1637,14 +1667,23 @@ def check_regressors_int(name, regressor_orig):
 
 
 @ignore_warnings(category=(DeprecationWarning, FutureWarning))
-def check_regressors_train(name, regressor_orig):
+def check_regressors_train(name, regressor_orig, readonly_memmap=False):
     X, y = _boston_subset()
     X = pairwise_estimator_convert_X(X, regressor_orig)
     y = StandardScaler().fit_transform(y.reshape(-1, 1))  # X is already scaled
     y = y.ravel()
     regressor = clone(regressor_orig)
     y = multioutput_estimator_convert_y_2d(regressor, y)
-    rnd = np.random.RandomState(0)
+    if name in CROSS_DECOMPOSITION:
+        rnd = np.random.RandomState(0)
+        y_ = np.vstack([y, 2 * y + rnd.randint(2, size=len(y))])
+        y_ = y_.T
+    else:
+        y_ = y
+
+    if readonly_memmap:
+        X, y, y_ = create_memmap_backed_data([X, y, y_])
+
     if not hasattr(regressor, 'alphas') and hasattr(regressor, 'alpha'):
         # linear regressors need to set alpha, but not generalized CV ones
         regressor.alpha = 0.01
@@ -1659,11 +1698,6 @@ def check_regressors_train(name, regressor_orig):
                        "labels. Perhaps use check_X_y in fit.".format(name)):
         regressor.fit(X, y[:-1])
     # fit
-    if name in CROSS_DECOMPOSITION:
-        y_ = np.vstack([y, 2 * y + rnd.randint(2, size=len(y))])
-        y_ = y_.T
-    else:
-        y_ = y
     set_random_state(regressor)
     regressor.fit(X, y_)
     regressor.fit(X.tolist(), y_.tolist())
diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py
index 94972354e2751..d358c9b7e69d8 100644
--- a/sklearn/utils/testing.py
+++ b/sklearn/utils/testing.py
@@ -16,6 +16,7 @@
 import warnings
 import sys
 import struct
+import functools
 
 import scipy as sp
 import scipy.io
@@ -766,21 +767,29 @@ def _delete_folder(folder_path, warn=False):
 
 class TempMemmap(object):
     def __init__(self, data, mmap_mode='r'):
-        self.temp_folder = tempfile.mkdtemp(prefix='sklearn_testing_')
         self.mmap_mode = mmap_mode
         self.data = data
 
     def __enter__(self):
-        fpath = op.join(self.temp_folder, 'data.pkl')
-        joblib.dump(self.data, fpath)
-        data_read_only = joblib.load(fpath, mmap_mode=self.mmap_mode)
-        atexit.register(lambda: _delete_folder(self.temp_folder, warn=True))
+        data_read_only, self.temp_folder = create_memmap_backed_data(
+            self.data, mmap_mode=self.mmap_mode, return_folder=True)
         return data_read_only
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         _delete_folder(self.temp_folder)
 
 
+def create_memmap_backed_data(data, mmap_mode='r', return_folder=False):
+    temp_folder = tempfile.mkdtemp(prefix='sklearn_testing_')
+    atexit.register(functools.partial(_delete_folder, temp_folder, warn=True))
+    filename = op.join(temp_folder, 'data.pkl')
+    joblib.dump(data, filename)
+    memmap_backed_data = joblib.load(filename, mmap_mode=mmap_mode)
+    result = (memmap_backed_data if not return_folder
+              else (memmap_backed_data, temp_folder))
+    return result
+
+
 # Utils to test docstrings
diff --git a/sklearn/utils/tests/test_testing.py b/sklearn/utils/tests/test_testing.py
index 0aca27861e0bb..6b55431d21d7d 100644
--- a/sklearn/utils/tests/test_testing.py
+++ b/sklearn/utils/tests/test_testing.py
@@ -1,13 +1,16 @@
 import warnings
 import unittest
 import sys
+import os
+import atexit
+
 import numpy as np
+
 from scipy import sparse
 
 from sklearn.utils.deprecation import deprecated
 from sklearn.utils.metaestimators import if_delegate_has_method
 from sklearn.utils.testing import (
-    assert_true,
     assert_raises,
     assert_less,
     assert_greater,
@@ -21,7 +24,10 @@
     ignore_warnings,
     check_docstring_parameters,
     assert_allclose_dense_sparse,
-    assert_raises_regex)
+    assert_raises_regex,
+    TempMemmap,
+    create_memmap_backed_data,
+    _delete_folder)
 
 from sklearn.utils.testing import SkipTest
 from sklearn.tree import DecisionTreeClassifier
@@ -478,3 +484,67 @@ def test_check_docstring_parameters():
         incorrect = check_docstring_parameters(f)
         assert len(incorrect) >= 1
         assert mess in incorrect[0], '"%s" not in "%s"' % (mess, incorrect[0])
+
+
+class RegistrationCounter(object):
+    def __init__(self):
+        self.nb_calls = 0
+
+    def __call__(self, to_register_func):
+        self.nb_calls += 1
+        assert to_register_func.func is _delete_folder
+
+
+def check_memmap(input_array, mmap_data, mmap_mode='r'):
+    assert isinstance(mmap_data, np.memmap)
+    writeable = mmap_mode != 'r'
+    assert mmap_data.flags.writeable is writeable
+    np.testing.assert_array_equal(input_array, mmap_data)
+
+
+def test_tempmemmap(monkeypatch):
+    registration_counter = RegistrationCounter()
+    monkeypatch.setattr(atexit, 'register', registration_counter)
+
+    input_array = np.ones(3)
+    with TempMemmap(input_array) as data:
+        check_memmap(input_array, data)
+        temp_folder = os.path.dirname(data.filename)
+    if os.name != 'nt':
+        assert not os.path.exists(temp_folder)
+    assert registration_counter.nb_calls == 1
+
+    mmap_mode = 'r+'
+    with TempMemmap(input_array, mmap_mode=mmap_mode) as data:
+        check_memmap(input_array, data, mmap_mode=mmap_mode)
+        temp_folder = os.path.dirname(data.filename)
+    if os.name != 'nt':
+        assert not os.path.exists(temp_folder)
+    assert registration_counter.nb_calls == 2
+
+
+def test_create_memmap_backed_data(monkeypatch):
+    registration_counter = RegistrationCounter()
+    monkeypatch.setattr(atexit, 'register', registration_counter)
+
+    input_array = np.ones(3)
+    data = create_memmap_backed_data(input_array)
+    check_memmap(input_array, data)
+    assert registration_counter.nb_calls == 1
+
+    data, folder = create_memmap_backed_data(input_array,
+                                             return_folder=True)
+    check_memmap(input_array, data)
+    assert folder == os.path.dirname(data.filename)
+    assert registration_counter.nb_calls == 2
+
+    mmap_mode = 'r+'
+    data = create_memmap_backed_data(input_array, mmap_mode=mmap_mode)
+    check_memmap(input_array, data, mmap_mode)
+    assert registration_counter.nb_calls == 3
+
+    input_list = [input_array, input_array + 1, input_array + 2]
+    mmap_data_list = create_memmap_backed_data(input_list)
+    for input_array, data in zip(input_list, mmap_data_list):
+        check_memmap(input_array, data)
+    assert registration_counter.nb_calls == 4
diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py
index a3a4175d7eff4..076e6d88440f1 100644
--- a/sklearn/utils/tests/test_validation.py
+++ b/sklearn/utils/tests/test_validation.py
@@ -43,6 +43,7 @@
 from sklearn.exceptions import DataConversionWarning
 
 from sklearn.utils.testing import assert_raise_message
+from sklearn.utils.testing import TempMemmap
 
 
 def test_as_float_array():
@@ -690,3 +691,12 @@ def test_check_memory():
                         " have the same interface as "
                         "sklearn.externals.joblib.Memory. Got memory='{}' "
                        "instead.".format(dummy), check_memory, dummy)
+
+
+@pytest.mark.parametrize('copy', [True, False])
+def test_check_array_memmap(copy):
+    X = np.ones((4, 4))
+    with TempMemmap(X, mmap_mode='r') as X_memmap:
+        X_checked = check_array(X_memmap, copy=copy)
+        assert np.may_share_memory(X_memmap, X_checked) == (not copy)
+        assert X_checked.flags['WRITEABLE'] == copy
diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py
index 70e968ee6d36b..f4ad3b0223b14 100644
--- a/sklearn/utils/validation.py
+++ b/sklearn/utils/validation.py
@@ -437,6 +437,10 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None,
                       " instead.", DeprecationWarning)
         accept_sparse = False
 
+    # store reference to original array to check if copy is needed when
+    # function returns
+    array_orig = array
+
     # store whether originally we wanted numeric dtype
     dtype_numeric = isinstance(dtype, six.string_types) and dtype == "numeric"
 
@@ -487,7 +491,7 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None,
         with warnings.catch_warnings():
             try:
                 warnings.simplefilter('error', ComplexWarning)
-                array = np.array(array, dtype=dtype, order=order, copy=copy)
+                array = np.asarray(array, dtype=dtype, order=order)
             except ComplexWarning:
                 raise ValueError("Complex data not supported\n"
                                  "{}\n".format(array))
@@ -513,8 +517,6 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None,
                     "Reshape your data either using array.reshape(-1, 1) if "
                     "your data has a single feature or array.reshape(1, -1) "
                     "if it contains a single sample.".format(array))
-            # To ensure that array flags are maintained
-            array = np.array(array, dtype=dtype, order=order, copy=copy)
 
         # in the future np.flexible dtypes will be handled like object dtypes
         if dtype_numeric and np.issubdtype(array.dtype, np.flexible):
@@ -556,6 +558,10 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None,
             msg = ("Data with input dtype %s was converted to %s%s."
                    % (dtype_orig, array.dtype, context))
             warnings.warn(msg, DataConversionWarning)
+
+    if copy and np.may_share_memory(array, array_orig):
+        array = np.array(array, dtype=dtype, order=order)
+
     return array
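
Usage sketch: how the new ``create_memmap_backed_data`` helper is exercised by
the ``readonly_memmap=True`` checks above. This is an illustrative snippet
assuming the patch is applied; the choice of ``SGDClassifier`` is arbitrary::

    import numpy as np

    from sklearn.linear_model import SGDClassifier
    from sklearn.utils.testing import create_memmap_backed_data

    X = np.random.RandomState(0).randn(100, 5)
    y = (X[:, 0] > 0).astype(int)

    # The helper dumps the data to a temporary folder and reloads it with
    # joblib's mmap_mode='r', so the arrays come back as read-only np.memmap
    # instances (cleanup of the folder is registered with atexit).
    X_mm, y_mm = create_memmap_backed_data([X, y])
    assert isinstance(X_mm, np.memmap) and not X_mm.flags.writeable

    # An estimator that handles read-only input fits without raising, which
    # is what the new checks assert for every estimator.
    SGDClassifier(max_iter=5, tol=None, random_state=0).fit(X_mm, y_mm)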
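
Behaviour sketch for the ``check_array`` change: with ``copy=False`` a memmap
input now passes through un-copied, and with ``copy=True`` the copy happens
once, at the end, and only if the result still shares memory with the input.
A minimal illustration, mirroring the new test::

    import numpy as np

    from sklearn.utils import check_array
    from sklearn.utils.testing import TempMemmap

    with TempMemmap(np.ones((4, 4))) as X_memmap:  # read-only memmap
        # copy=False: the validated output shares memory with the input
        X_nocopy = check_array(X_memmap, copy=False)
        assert np.may_share_memory(X_memmap, X_nocopy)

        # copy=True: a single, writeable copy is made before returning
        X_copy = check_array(X_memmap, copy=True)
        assert not np.may_share_memory(X_memmap, X_copy)
        assert X_copy.flags['WRITEABLE']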