diff --git a/conftest.py b/conftest.py index 5c48de4ac36a3..aec49c03ae13d 100644 --- a/conftest.py +++ b/conftest.py @@ -5,7 +5,6 @@ # doc/modules/clustering.rst and use sklearn from the local folder rather than # the one from site-packages. -import os import platform import sys @@ -17,18 +16,12 @@ from sklearn._min_dependencies import PYTEST_MIN_VERSION from sklearn.utils.fixes import np_version, parse_version - if parse_version(pytest.__version__) < parse_version(PYTEST_MIN_VERSION): raise ImportError('Your version of pytest is too old, you should have ' 'at least pytest >= {} installed.' .format(PYTEST_MIN_VERSION)) -def pytest_addoption(parser): - parser.addoption("--skip-network", action="store_true", default=False, - help="skip network tests") - - def pytest_collection_modifyitems(config, items): for item in items: # FeatureHasher is not compatible with PyPy @@ -50,15 +43,6 @@ def pytest_collection_modifyitems(config, items): ) item.add_marker(marker) - # Skip tests which require internet if the flag is provided - if (config.getoption("--skip-network") - or int(os.environ.get("SKLEARN_SKIP_NETWORK_TESTS", "0"))): - skip_network = pytest.mark.skip( - reason="test requires internet connectivity") - for item in items: - if "network" in item.keywords: - item.add_marker(skip_network) - # numpy changed the str/repr formatting of numpy arrays in 1.14. We want to # run doctests only for numpy >= 1.14. skip_doctests = False diff --git a/doc/computing/parallelism.rst b/doc/computing/parallelism.rst index 3dce5ef66bb1d..8605650e8eec5 100644 --- a/doc/computing/parallelism.rst +++ b/doc/computing/parallelism.rst @@ -212,4 +212,5 @@ These environment variables should be set before importing scikit-learn. :SKLEARN_SKIP_NETWORK_TESTS: When this environment variable is set to a non zero value, the tests - that need network access are skipped. + that need network access are skipped. When this environment variable is + not set, network tests are also skipped. 
diff --git a/sklearn/conftest.py b/sklearn/conftest.py index 8a98921342efa..2978115e3091c 100644 --- a/sklearn/conftest.py +++ b/sklearn/conftest.py @@ -1,9 +1,94 @@ import os +from os import environ +from functools import wraps import pytest from threadpoolctl import threadpool_limits from sklearn.utils._openmp_helpers import _openmp_effective_n_threads +from sklearn.datasets import fetch_20newsgroups +from sklearn.datasets import fetch_20newsgroups_vectorized +from sklearn.datasets import fetch_california_housing +from sklearn.datasets import fetch_covtype +from sklearn.datasets import fetch_kddcup99 +from sklearn.datasets import fetch_olivetti_faces +from sklearn.datasets import fetch_rcv1 + + +dataset_fetchers = { + 'fetch_20newsgroups_fxt': fetch_20newsgroups, + 'fetch_20newsgroups_vectorized_fxt': fetch_20newsgroups_vectorized, + 'fetch_california_housing_fxt': fetch_california_housing, + 'fetch_covtype_fxt': fetch_covtype, + 'fetch_kddcup99_fxt': fetch_kddcup99, + 'fetch_olivetti_faces_fxt': fetch_olivetti_faces, + 'fetch_rcv1_fxt': fetch_rcv1, +} + + +def _fetch_fixture(f): + """Fetch dataset (download if missing and requested by environment).""" + download_if_missing = environ.get('SKLEARN_SKIP_NETWORK_TESTS', '1') == '0' + + @wraps(f) + def wrapped(*args, **kwargs): + kwargs['download_if_missing'] = download_if_missing + try: + return f(*args, **kwargs) + except IOError: + pytest.skip("test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0") + return pytest.fixture(lambda: wrapped) + + +# Adds fixtures for fetching data +fetch_20newsgroups_fxt = _fetch_fixture(fetch_20newsgroups) +fetch_20newsgroups_vectorized_fxt = \ + _fetch_fixture(fetch_20newsgroups_vectorized) +fetch_california_housing_fxt = _fetch_fixture(fetch_california_housing) +fetch_covtype_fxt = _fetch_fixture(fetch_covtype) +fetch_kddcup99_fxt = _fetch_fixture(fetch_kddcup99) +fetch_olivetti_faces_fxt = _fetch_fixture(fetch_olivetti_faces) +fetch_rcv1_fxt = _fetch_fixture(fetch_rcv1) + + +def 
pytest_collection_modifyitems(config, items): + """Called after collect is completed. + + Parameters + ---------- + config : pytest config + items : list of collected items + """ + run_network_tests = environ.get('SKLEARN_SKIP_NETWORK_TESTS', '1') == '0' + skip_network = pytest.mark.skip( + reason="test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0") + + # download datasets during collection to avoid thread unsafe behavior + # when running pytest in parallel with pytest-xdist + dataset_features_set = set(dataset_fetchers) + datasets_to_download = set() + + for item in items: + if not hasattr(item, "fixturenames"): + continue + item_fixtures = set(item.fixturenames) + dataset_to_fetch = item_fixtures & dataset_features_set + if not dataset_to_fetch: + continue + + if run_network_tests: + datasets_to_download |= dataset_to_fetch + else: + # network tests are skipped + item.add_marker(skip_network) + + # Only download datasets on the first worker spawned by pytest-xdist + # to avoid thread unsafe behavior. If pytest-xdist is not used, we still + # download before tests run. 
+ worker_id = environ.get("PYTEST_XDIST_WORKER", "gw0") + if worker_id == "gw0" and run_network_tests: + for name in datasets_to_download: + dataset_fetchers[name]() @pytest.fixture(scope='function') diff --git a/sklearn/datasets/tests/conftest.py b/sklearn/datasets/tests/conftest.py index 4612cd5deb4bc..cf356d6ca3b10 100644 --- a/sklearn/datasets/tests/conftest.py +++ b/sklearn/datasets/tests/conftest.py @@ -1,67 +1,7 @@ """ Network tests are only run, if data is already locally available, or if download is specifically requested by environment variable.""" import builtins -from functools import wraps -from os import environ import pytest -from sklearn.datasets import fetch_20newsgroups -from sklearn.datasets import fetch_20newsgroups_vectorized -from sklearn.datasets import fetch_california_housing -from sklearn.datasets import fetch_covtype -from sklearn.datasets import fetch_kddcup99 -from sklearn.datasets import fetch_olivetti_faces -from sklearn.datasets import fetch_rcv1 - - -def _wrapped_fetch(f, dataset_name): - """ Fetch dataset (download if missing and requested by environment) """ - download_if_missing = environ.get('SKLEARN_SKIP_NETWORK_TESTS', '1') == '0' - - @wraps(f) - def wrapped(*args, **kwargs): - kwargs['download_if_missing'] = download_if_missing - try: - return f(*args, **kwargs) - except IOError: - pytest.skip("Download {} to run this test".format(dataset_name)) - return wrapped - - -@pytest.fixture -def fetch_20newsgroups_fxt(): - return _wrapped_fetch(fetch_20newsgroups, dataset_name='20newsgroups') - - -@pytest.fixture -def fetch_20newsgroups_vectorized_fxt(): - return _wrapped_fetch(fetch_20newsgroups_vectorized, - dataset_name='20newsgroups_vectorized') - - -@pytest.fixture -def fetch_california_housing_fxt(): - return _wrapped_fetch(fetch_california_housing, - dataset_name='california_housing') - - -@pytest.fixture -def fetch_covtype_fxt(): - return _wrapped_fetch(fetch_covtype, dataset_name='covtype') - - -@pytest.fixture -def 
fetch_kddcup99_fxt(): - return _wrapped_fetch(fetch_kddcup99, dataset_name='kddcup99') - - -@pytest.fixture -def fetch_olivetti_faces_fxt(): - return _wrapped_fetch(fetch_olivetti_faces, dataset_name='olivetti_faces') - - -@pytest.fixture -def fetch_rcv1_fxt(): - return _wrapped_fetch(fetch_rcv1, dataset_name='rcv1') @pytest.fixture diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 256b79db4865c..498e5bf38a675 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -13,7 +13,7 @@ from sklearn import datasets from sklearn.base import clone -from sklearn.datasets import (make_classification, fetch_california_housing, +from sklearn.datasets import (make_classification, make_regression) from sklearn.ensemble import GradientBoostingClassifier from sklearn.ensemble import GradientBoostingRegressor @@ -345,8 +345,7 @@ def test_max_feature_regression(): assert deviance < 0.5, "GB failed with deviance %.4f" % deviance -@pytest.mark.network -def test_feature_importance_regression(): +def test_feature_importance_regression(fetch_california_housing_fxt): """Test that Gini importance is calculated correctly. This test follows the example from [1]_ (pg. 373). @@ -354,7 +353,7 @@ def test_feature_importance_regression(): .. [1] Friedman, J., Hastie, T., & Tibshirani, R. (2001). The elements of statistical learning. New York: Springer series in statistics. """ - california = fetch_california_housing() + california = fetch_california_housing_fxt() X, y = california.data, california.target X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)