Skip to content

Commit 9ea723f

Browse files
lesteve authored and glemaitre committed
[MRG+1] Read-only memmap input data in common tests (#10663)
1 parent 27a804d commit 9ea723f

File tree

6 files changed

+167
-28
lines changed

6 files changed

+167
-28
lines changed

doc/whats_new/v0.20.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,12 @@ Linear, kernelized and related models
224224
underlying implementation is not random.
225225
:issue:`9497` by :user:`Albert Thomas <albertcthomas>`.
226226

227+
Utils
228+
229+
- Avoid copying the data in :func:`utils.check_array` when the input data is a
230+
memmap (and ``copy=False``). :issue:`10663` by :user:`Arthur Mensch
231+
<arthurmensch>` and :user:`Loïc Estève <lesteve>`.
232+
227233
Miscellaneous
228234

229235
- Add ``filename`` attribute to datasets that have a CSV file.
@@ -541,3 +547,7 @@ Changes to estimator checks
541547

542548
- Add invariance tests for clustering metrics. :issue:`8102` by :user:`Ankita
543549
Sinha <anki08>` and :user:`Guillaume Lemaitre <glemaitre>`.
550+
551+
- Add tests in :func:`estimator_checks.check_estimator` to check that an
552+
estimator can handle read-only memmap input data. :issue:`10663` by
553+
:user:`Arthur Mensch <arthurmensch>` and :user:`Loïc Estève <lesteve>`.

sklearn/utils/estimator_checks.py

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
import traceback
77
import pickle
88
from copy import deepcopy
9+
import struct
10+
from functools import partial
11+
912
import numpy as np
1013
from scipy import sparse
1114
from scipy.stats import rankdata
12-
import struct
1315

1416
from sklearn.externals.six.moves import zip
1517
from sklearn.externals.joblib import hash, Memory
@@ -33,6 +35,7 @@
3335
from sklearn.utils.testing import SkipTest
3436
from sklearn.utils.testing import ignore_warnings
3537
from sklearn.utils.testing import assert_dict_equal
38+
from sklearn.utils.testing import create_memmap_backed_data
3639
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
3740

3841

@@ -84,6 +87,7 @@ def _yield_non_meta_checks(name, estimator):
8487
yield check_sample_weights_pandas_series
8588
yield check_sample_weights_list
8689
yield check_estimators_fit_returns_self
90+
yield partial(check_estimators_fit_returns_self, readonly_memmap=True)
8791
yield check_complex_data
8892

8993
# Check that all estimator yield informative messages when
@@ -123,6 +127,7 @@ def _yield_classifier_checks(name, classifier):
123127
yield check_estimators_partial_fit_n_features
124128
# basic consistency testing
125129
yield check_classifiers_train
130+
yield partial(check_classifiers_train, readonly_memmap=True)
126131
yield check_classifiers_regression_target
127132
if (name not in ["MultinomialNB", "ComplementNB", "LabelPropagation",
128133
"LabelSpreading"] and
@@ -171,6 +176,7 @@ def _yield_regressor_checks(name, regressor):
171176
# TODO: test with multiple responses
172177
# basic testing
173178
yield check_regressors_train
179+
yield partial(check_regressors_train, readonly_memmap=True)
174180
yield check_regressor_data_not_an_array
175181
yield check_estimators_partial_fit_n_features
176182
yield check_regressors_no_decision_function
@@ -196,6 +202,7 @@ def _yield_transformer_checks(name, transformer):
196202
'FunctionTransformer', 'Normalizer']:
197203
# basic tests
198204
yield check_transformer_general
205+
yield partial(check_transformer_general, readonly_memmap=True)
199206
yield check_transformers_unfitted
200207
# Dependent on external solvers and hence accessing the iter
201208
# param is non-trivial.
@@ -211,6 +218,7 @@ def _yield_clustering_checks(name, clusterer):
211218
# this is clustering on the features
212219
# let's not test that here.
213220
yield check_clustering
221+
yield partial(check_clustering, readonly_memmap=True)
214222
yield check_estimators_partial_fit_n_features
215223
yield check_non_transformer_estimators_n_iter
216224

@@ -223,6 +231,7 @@ def _yield_outliers_checks(name, estimator):
223231
# checks for estimators that can be used on a test set
224232
if hasattr(estimator, 'predict'):
225233
yield check_outliers_train
234+
yield partial(check_outliers_train, readonly_memmap=True)
226235
# test outlier detectors can handle non-array data
227236
yield check_classifier_data_not_an_array
228237
# test if NotFittedError is raised
@@ -799,14 +808,18 @@ def check_fit1d(name, estimator_orig):
799808

800809

801810
@ignore_warnings(category=(DeprecationWarning, FutureWarning))
802-
def check_transformer_general(name, transformer):
811+
def check_transformer_general(name, transformer, readonly_memmap=False):
803812
X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
804813
random_state=0, n_features=2, cluster_std=0.1)
805814
X = StandardScaler().fit_transform(X)
806815
X -= X.min()
807816
if name == 'PowerTransformer':
808817
# Box-Cox requires positive, non-zero data
809818
X += 1
819+
820+
if readonly_memmap:
821+
X, y = create_memmap_backed_data([X, y])
822+
810823
_check_transformer(name, transformer, X, y)
811824
_check_transformer(name, transformer, X.tolist(), y.tolist())
812825

@@ -1165,11 +1178,17 @@ def check_estimators_partial_fit_n_features(name, estimator_orig):
11651178

11661179

11671180
@ignore_warnings(category=(DeprecationWarning, FutureWarning))
1168-
def check_clustering(name, clusterer_orig):
1181+
def check_clustering(name, clusterer_orig, readonly_memmap=False):
11691182
clusterer = clone(clusterer_orig)
11701183
X, y = make_blobs(n_samples=50, random_state=1)
11711184
X, y = shuffle(X, y, random_state=7)
11721185
X = StandardScaler().fit_transform(X)
1186+
rng = np.random.RandomState(7)
1187+
X_noise = np.concatenate([X, rng.uniform(low=-3, high=3, size=(5, 2))])
1188+
1189+
if readonly_memmap:
1190+
X, y, X_noise = create_memmap_backed_data([X, y, X_noise])
1191+
11731192
n_samples, n_features = X.shape
11741193
# catch deprecation and neighbors warnings
11751194
if hasattr(clusterer, "n_clusters"):
@@ -1201,8 +1220,6 @@ def check_clustering(name, clusterer_orig):
12011220
assert_in(pred2.dtype, [np.dtype('int32'), np.dtype('int64')])
12021221

12031222
# Add noise to X to test the possible values of the labels
1204-
rng = np.random.RandomState(7)
1205-
X_noise = np.concatenate([X, rng.uniform(low=-3, high=3, size=(5, 2))])
12061223
labels = clusterer.fit_predict(X_noise)
12071224

12081225
# There should be at least one sample in every cluster. Equivalently
@@ -1273,20 +1290,26 @@ def check_classifiers_one_label(name, classifier_orig):
12731290

12741291

12751292
@ignore_warnings # Warnings are raised by decision function
1276-
def check_classifiers_train(name, classifier_orig):
1293+
def check_classifiers_train(name, classifier_orig, readonly_memmap=False):
12771294
X_m, y_m = make_blobs(n_samples=300, random_state=0)
12781295
X_m, y_m = shuffle(X_m, y_m, random_state=7)
12791296
X_m = StandardScaler().fit_transform(X_m)
12801297
# generate binary problem from multi-class one
12811298
y_b = y_m[y_m != 2]
12821299
X_b = X_m[y_m != 2]
1300+
1301+
if name in ['BernoulliNB', 'MultinomialNB', 'ComplementNB']:
1302+
X_m -= X_m.min()
1303+
X_b -= X_b.min()
1304+
1305+
if readonly_memmap:
1306+
X_m, y_m, X_b, y_b = create_memmap_backed_data([X_m, y_m, X_b, y_b])
1307+
12831308
for (X, y) in [(X_m, y_m), (X_b, y_b)]:
12841309
classes = np.unique(y)
12851310
n_classes = len(classes)
12861311
n_samples, n_features = X.shape
12871312
classifier = clone(classifier_orig)
1288-
if name in ['BernoulliNB', 'MultinomialNB', 'ComplementNB']:
1289-
X -= X.min()
12901313
X = pairwise_estimator_convert_X(X, classifier_orig)
12911314
set_random_state(classifier)
12921315
# raises error on malformed input for fit
@@ -1382,9 +1405,13 @@ def check_classifiers_train(name, classifier_orig):
13821405
assert_array_equal(np.argsort(y_log_prob), np.argsort(y_prob))
13831406

13841407

1385-
def check_outliers_train(name, estimator_orig):
1408+
def check_outliers_train(name, estimator_orig, readonly_memmap=True):
13861409
X, _ = make_blobs(n_samples=300, random_state=0)
13871410
X = shuffle(X, random_state=7)
1411+
1412+
if readonly_memmap:
1413+
X = create_memmap_backed_data(X)
1414+
13881415
n_samples, n_features = X.shape
13891416
estimator = clone(estimator_orig)
13901417
set_random_state(estimator)
@@ -1444,7 +1471,8 @@ def check_outliers_train(name, estimator_orig):
14441471

14451472

14461473
@ignore_warnings(category=(DeprecationWarning, FutureWarning))
1447-
def check_estimators_fit_returns_self(name, estimator_orig):
1474+
def check_estimators_fit_returns_self(name, estimator_orig,
1475+
readonly_memmap=False):
14481476
"""Check if self is returned when calling fit"""
14491477
X, y = make_blobs(random_state=0, n_samples=9, n_features=4)
14501478
# some want non-negative input
@@ -1457,8 +1485,10 @@ def check_estimators_fit_returns_self(name, estimator_orig):
14571485
estimator = clone(estimator_orig)
14581486
y = multioutput_estimator_convert_y_2d(estimator, y)
14591487

1460-
set_random_state(estimator)
1488+
if readonly_memmap:
1489+
X, y = create_memmap_backed_data([X, y])
14611490

1491+
set_random_state(estimator)
14621492
assert_true(estimator.fit(X, y) is estimator)
14631493

14641494

@@ -1637,14 +1667,23 @@ def check_regressors_int(name, regressor_orig):
16371667

16381668

16391669
@ignore_warnings(category=(DeprecationWarning, FutureWarning))
1640-
def check_regressors_train(name, regressor_orig):
1670+
def check_regressors_train(name, regressor_orig, readonly_memmap=False):
16411671
X, y = _boston_subset()
16421672
X = pairwise_estimator_convert_X(X, regressor_orig)
16431673
y = StandardScaler().fit_transform(y.reshape(-1, 1)) # X is already scaled
16441674
y = y.ravel()
16451675
regressor = clone(regressor_orig)
16461676
y = multioutput_estimator_convert_y_2d(regressor, y)
1647-
rnd = np.random.RandomState(0)
1677+
if name in CROSS_DECOMPOSITION:
1678+
rnd = np.random.RandomState(0)
1679+
y_ = np.vstack([y, 2 * y + rnd.randint(2, size=len(y))])
1680+
y_ = y_.T
1681+
else:
1682+
y_ = y
1683+
1684+
if readonly_memmap:
1685+
X, y, y_ = create_memmap_backed_data([X, y, y_])
1686+
16481687
if not hasattr(regressor, 'alphas') and hasattr(regressor, 'alpha'):
16491688
# linear regressors need to set alpha, but not generalized CV ones
16501689
regressor.alpha = 0.01
@@ -1659,11 +1698,6 @@ def check_regressors_train(name, regressor_orig):
16591698
"labels. Perhaps use check_X_y in fit.".format(name)):
16601699
regressor.fit(X, y[:-1])
16611700
# fit
1662-
if name in CROSS_DECOMPOSITION:
1663-
y_ = np.vstack([y, 2 * y + rnd.randint(2, size=len(y))])
1664-
y_ = y_.T
1665-
else:
1666-
y_ = y
16671701
set_random_state(regressor)
16681702
regressor.fit(X, y_)
16691703
regressor.fit(X.tolist(), y_.tolist())

sklearn/utils/testing.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import warnings
1717
import sys
1818
import struct
19+
import functools
1920

2021
import scipy as sp
2122
import scipy.io
@@ -766,21 +767,29 @@ def _delete_folder(folder_path, warn=False):
766767

767768
class TempMemmap(object):
768769
def __init__(self, data, mmap_mode='r'):
769-
self.temp_folder = tempfile.mkdtemp(prefix='sklearn_testing_')
770770
self.mmap_mode = mmap_mode
771771
self.data = data
772772

773773
def __enter__(self):
774-
fpath = op.join(self.temp_folder, 'data.pkl')
775-
joblib.dump(self.data, fpath)
776-
data_read_only = joblib.load(fpath, mmap_mode=self.mmap_mode)
777-
atexit.register(lambda: _delete_folder(self.temp_folder, warn=True))
774+
data_read_only, self.temp_folder = create_memmap_backed_data(
775+
self.data, mmap_mode=self.mmap_mode, return_folder=True)
778776
return data_read_only
779777

780778
def __exit__(self, exc_type, exc_val, exc_tb):
781779
_delete_folder(self.temp_folder)
782780

783781

782+
def create_memmap_backed_data(data, mmap_mode='r', return_folder=False):
    """Dump ``data`` to a new temporary folder and reload it with joblib.

    Parameters
    ----------
    data : object
        Object passed to ``joblib.dump``.
    mmap_mode : str, default 'r'
        Forwarded to ``joblib.load`` as ``mmap_mode``.
    return_folder : bool, default False
        When true, also return the path of the temporary folder.

    Returns
    -------
    The object returned by ``joblib.load``, or the tuple
    ``(loaded_object, temp_folder)`` when ``return_folder`` is true.
    """
    folder = tempfile.mkdtemp(prefix='sklearn_testing_')
    # The cleanup callback is registered before anything is written into the
    # folder; a functools.partial is used so that the registered callable
    # exposes ``.func`` and carries the folder path with it.
    cleanup = functools.partial(_delete_folder, folder, warn=True)
    atexit.register(cleanup)

    pickle_path = op.join(folder, 'data.pkl')
    joblib.dump(data, pickle_path)
    reloaded = joblib.load(pickle_path, mmap_mode=mmap_mode)

    if return_folder:
        return reloaded, folder
    return reloaded
791+
792+
784793
# Utils to test docstrings
785794

786795

sklearn/utils/tests/test_testing.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import warnings
22
import unittest
33
import sys
4+
import os
5+
import atexit
6+
47
import numpy as np
8+
59
from scipy import sparse
610

711
from sklearn.utils.deprecation import deprecated
812
from sklearn.utils.metaestimators import if_delegate_has_method
913
from sklearn.utils.testing import (
10-
assert_true,
1114
assert_raises,
1215
assert_less,
1316
assert_greater,
@@ -21,7 +24,10 @@
2124
ignore_warnings,
2225
check_docstring_parameters,
2326
assert_allclose_dense_sparse,
24-
assert_raises_regex)
27+
assert_raises_regex,
28+
TempMemmap,
29+
create_memmap_backed_data,
30+
_delete_folder)
2531

2632
from sklearn.utils.testing import SkipTest
2733
from sklearn.tree import DecisionTreeClassifier
@@ -478,3 +484,67 @@ def test_check_docstring_parameters():
478484
incorrect = check_docstring_parameters(f)
479485
assert len(incorrect) >= 1
480486
assert mess in incorrect[0], '"%s" not in "%s"' % (mess, incorrect[0])
487+
488+
489+
class RegistrationCounter(object):
    """Callable used in place of ``atexit.register`` to count registrations.

    ``nb_calls`` holds the number of times the instance has been called.
    """

    def __init__(self):
        self.nb_calls = 0

    def __call__(self, to_register_func):
        # Count first, then verify: the registered callable is expected to
        # expose ``.func`` (as functools.partial does) pointing at
        # ``_delete_folder``.
        self.nb_calls = self.nb_calls + 1
        assert to_register_func.func is _delete_folder
496+
497+
498+
def check_memmap(input_array, mmap_data, mmap_mode='r'):
    """Assert that ``mmap_data`` is a memmap mirroring ``input_array``.

    Checks, in order, that ``mmap_data`` is an ``np.memmap`` instance, that
    its ``writeable`` flag is False exactly when ``mmap_mode`` is ``'r'``,
    and that its contents equal ``input_array``.
    """
    assert isinstance(mmap_data, np.memmap)
    # Any mode other than plain read-only is expected to yield writeable data.
    expect_writeable = not (mmap_mode == 'r')
    assert mmap_data.flags.writeable is expect_writeable
    np.testing.assert_array_equal(input_array, mmap_data)
503+
504+
505+
def test_tempmemmap(monkeypatch):
    """Exercise ``TempMemmap`` with its default mode and with ``'r+'``.

    ``atexit.register`` is replaced by a counter so the test can verify that
    each use of the context manager registers exactly one cleanup callback.
    """
    counter = RegistrationCounter()
    monkeypatch.setattr(atexit, 'register', counter)

    input_array = np.ones(3)

    # First pass relies on the default mmap_mode of both TempMemmap and
    # check_memmap; second pass requests a writeable mapping explicitly.
    mode_kwargs = [{}, {'mmap_mode': 'r+'}]
    for expected_calls, kwargs in enumerate(mode_kwargs, 1):
        with TempMemmap(input_array, **kwargs) as data:
            check_memmap(input_array, data, **kwargs)
            temp_folder = os.path.dirname(data.filename)
        # The folder-removal check is skipped on Windows (os.name == 'nt').
        if os.name != 'nt':
            assert not os.path.exists(temp_folder)
        assert counter.nb_calls == expected_calls
524+
525+
526+
def test_create_memmap_backed_data(monkeypatch):
    """Exercise ``create_memmap_backed_data`` across its argument variants.

    ``atexit.register`` is replaced by a counter so that every call can be
    checked to register exactly one cleanup callback.
    """
    counter = RegistrationCounter()
    monkeypatch.setattr(atexit, 'register', counter)

    ones = np.ones(3)

    # Default arguments: read-only data, no folder returned.
    mmapped = create_memmap_backed_data(ones)
    check_memmap(ones, mmapped)
    assert counter.nb_calls == 1

    # return_folder=True yields the folder that holds the backing file.
    mmapped, folder = create_memmap_backed_data(ones, return_folder=True)
    check_memmap(ones, mmapped)
    assert folder == os.path.dirname(mmapped.filename)
    assert counter.nb_calls == 2

    # Explicit writeable mode.
    writeable_mode = 'r+'
    mmapped = create_memmap_backed_data(ones, mmap_mode=writeable_mode)
    check_memmap(ones, mmapped, writeable_mode)
    assert counter.nb_calls == 3

    # A list of arrays is dumped in a single call (one registration) and
    # every element comes back memmap-backed.
    arrays = [ones, ones + 1, ones + 2]
    mmapped_arrays = create_memmap_backed_data(arrays)
    for expected, actual in zip(arrays, mmapped_arrays):
        check_memmap(expected, actual)
    assert counter.nb_calls == 4

0 commit comments

Comments
 (0)