Skip to content

TST Download datasets before running pytest-xdist #19118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jan 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# doc/modules/clustering.rst and use sklearn from the local folder rather than
# the one from site-packages.

import os
import platform
import sys

Expand All @@ -17,18 +16,12 @@
from sklearn._min_dependencies import PYTEST_MIN_VERSION
from sklearn.utils.fixes import np_version, parse_version


# Abort early: the test configuration relies on features that are only
# available from PYTEST_MIN_VERSION onwards.
if parse_version(pytest.__version__) < parse_version(PYTEST_MIN_VERSION):
    raise ImportError(
        'Your version of pytest is too old, you should have '
        'at least pytest >= {} installed.'.format(PYTEST_MIN_VERSION)
    )


def pytest_addoption(parser):
    """Register the ``--skip-network`` command-line flag.

    Parameters
    ----------
    parser : pytest parser
        Object whose ``addoption`` method is called to declare the flag.
    """
    option_settings = {
        "action": "store_true",
        "default": False,
        "help": "skip network tests",
    }
    parser.addoption("--skip-network", **option_settings)


def pytest_collection_modifyitems(config, items):
for item in items:
# FeatureHasher is not compatible with PyPy
Expand All @@ -50,15 +43,6 @@ def pytest_collection_modifyitems(config, items):
)
item.add_marker(marker)

# Skip tests which require internet if the flag is provided
if (config.getoption("--skip-network")
or int(os.environ.get("SKLEARN_SKIP_NETWORK_TESTS", "0"))):
skip_network = pytest.mark.skip(
reason="test requires internet connectivity")
for item in items:
if "network" in item.keywords:
item.add_marker(skip_network)

# numpy changed the str/repr formatting of numpy arrays in 1.14. We want to
# run doctests only for numpy >= 1.14.
skip_doctests = False
Expand Down
3 changes: 2 additions & 1 deletion doc/computing/parallelism.rst
Original file line number Diff line number Diff line change
Expand Up @@ -212,4 +212,5 @@ These environment variables should be set before importing scikit-learn.
:SKLEARN_SKIP_NETWORK_TESTS:

When this environment variable is set to a non zero value, the tests
that need network access are skipped.
that need network access are skipped. When this environment variable is
not set, network tests are skipped as well.
85 changes: 85 additions & 0 deletions sklearn/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,94 @@
import os
from os import environ
from functools import wraps

import pytest
from threadpoolctl import threadpool_limits

from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
from sklearn.datasets import fetch_20newsgroups
from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.datasets import fetch_california_housing
from sklearn.datasets import fetch_covtype
from sklearn.datasets import fetch_kddcup99
from sklearn.datasets import fetch_olivetti_faces
from sklearn.datasets import fetch_rcv1


# Maps each dataset fixture name to the underlying fetcher function, so that
# a fixture name requested by a test can be resolved to the dataset to fetch.
dataset_fetchers = {
    'fetch_20newsgroups_fxt': fetch_20newsgroups,
    'fetch_20newsgroups_vectorized_fxt': fetch_20newsgroups_vectorized,
    'fetch_california_housing_fxt': fetch_california_housing,
    'fetch_covtype_fxt': fetch_covtype,
    'fetch_kddcup99_fxt': fetch_kddcup99,
    'fetch_olivetti_faces_fxt': fetch_olivetti_faces,
    'fetch_rcv1_fxt': fetch_rcv1,
}


def _fetch_fixture(f):
    """Build a pytest fixture that hands tests a skip-aware fetcher.

    Parameters
    ----------
    f : callable
        Dataset fetcher accepting a ``download_if_missing`` keyword.

    Returns
    -------
    fixture
        The result of ``pytest.fixture`` applied to a zero-argument callable
        returning the wrapped fetcher. The wrapper always overrides
        ``download_if_missing`` and calls ``pytest.skip`` when ``f`` raises
        ``IOError``.
    """
    # The environment is read once, when the fixture is built; downloads are
    # only allowed when the variable is exactly '0'.
    skip_setting = environ.get('SKLEARN_SKIP_NETWORK_TESTS', '1')
    allow_download = skip_setting == '0'

    @wraps(f)
    def wrapped(*args, **kwargs):
        kwargs['download_if_missing'] = allow_download
        try:
            dataset = f(*args, **kwargs)
        except IOError:
            pytest.skip("test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0")
        else:
            return dataset

    return pytest.fixture(lambda: wrapped)


# Module-level fixtures for fetching data: one per dataset fetcher, each
# created by _fetch_fixture.
fetch_20newsgroups_fxt = _fetch_fixture(fetch_20newsgroups)
fetch_20newsgroups_vectorized_fxt = \
    _fetch_fixture(fetch_20newsgroups_vectorized)
fetch_california_housing_fxt = _fetch_fixture(fetch_california_housing)
fetch_covtype_fxt = _fetch_fixture(fetch_covtype)
fetch_kddcup99_fxt = _fetch_fixture(fetch_kddcup99)
fetch_olivetti_faces_fxt = _fetch_fixture(fetch_olivetti_faces)
fetch_rcv1_fxt = _fetch_fixture(fetch_rcv1)


def pytest_collection_modifyitems(config, items):
    """Skip or pre-fetch dataset-dependent tests once collection is done.

    Items requesting one of the fixtures named in ``dataset_fetchers`` are
    marked as skipped unless ``SKLEARN_SKIP_NETWORK_TESTS`` equals ``'0'``;
    in that case the corresponding datasets are fetched here instead.

    Parameters
    ----------
    config : pytest config
    items : list of collected items
    """
    network_enabled = environ.get('SKLEARN_SKIP_NETWORK_TESTS', '1') == '0'
    skip_network = pytest.mark.skip(
        reason="test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0")

    # Datasets are fetched at collection time so that parallel pytest-xdist
    # runs do not hit thread unsafe download behavior later on.
    known_fixtures = set(dataset_fetchers)
    pending_downloads = set()

    for item in items:
        if not hasattr(item, "fixturenames"):
            continue
        requested = set(item.fixturenames) & known_fixtures
        if not requested:
            continue

        if not network_enabled:
            # network tests are skipped
            item.add_marker(skip_network)
        else:
            pending_downloads |= requested

    # Only the first pytest-xdist worker ("gw0") performs the downloads; the
    # same default is used when pytest-xdist is absent, so the data is still
    # fetched before tests run.
    worker_id = environ.get("PYTEST_XDIST_WORKER", "gw0")
    if network_enabled and worker_id == "gw0":
        for fixture_name in pending_downloads:
            dataset_fetchers[fixture_name]()


@pytest.fixture(scope='function')
Expand Down
60 changes: 0 additions & 60 deletions sklearn/datasets/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,7 @@
""" Network tests are only run, if data is already locally available,
or if download is specifically requested by environment variable."""
import builtins
from functools import wraps
from os import environ
import pytest
from sklearn.datasets import fetch_20newsgroups
from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.datasets import fetch_california_housing
from sklearn.datasets import fetch_covtype
from sklearn.datasets import fetch_kddcup99
from sklearn.datasets import fetch_olivetti_faces
from sklearn.datasets import fetch_rcv1


def _wrapped_fetch(f, dataset_name):
    """Wrap fetcher ``f`` so that a missing dataset skips the test.

    Parameters
    ----------
    f : callable
        Dataset fetcher accepting a ``download_if_missing`` keyword.
    dataset_name : str
        Name inserted into the skip message.

    Returns
    -------
    callable
        Wrapper that always overrides ``download_if_missing`` and calls
        ``pytest.skip`` when ``f`` raises ``IOError``.
    """
    # Read once at wrap time; downloads are only allowed when the variable
    # is exactly '0'.
    skip_setting = environ.get('SKLEARN_SKIP_NETWORK_TESTS', '1')
    allow_download = skip_setting == '0'

    @wraps(f)
    def wrapped(*args, **kwargs):
        kwargs['download_if_missing'] = allow_download
        try:
            dataset = f(*args, **kwargs)
        except IOError:
            pytest.skip("Download {} to run this test".format(dataset_name))
        else:
            return dataset

    return wrapped


@pytest.fixture
def fetch_20newsgroups_fxt():
    """Return ``fetch_20newsgroups`` wrapped by ``_wrapped_fetch``."""
    fetcher = _wrapped_fetch(fetch_20newsgroups, dataset_name='20newsgroups')
    return fetcher


@pytest.fixture
def fetch_20newsgroups_vectorized_fxt():
    """Return ``fetch_20newsgroups_vectorized`` wrapped by ``_wrapped_fetch``.
    """
    fetcher = _wrapped_fetch(
        fetch_20newsgroups_vectorized,
        dataset_name='20newsgroups_vectorized',
    )
    return fetcher


@pytest.fixture
def fetch_california_housing_fxt():
    """Return ``fetch_california_housing`` wrapped by ``_wrapped_fetch``."""
    fetcher = _wrapped_fetch(
        fetch_california_housing,
        dataset_name='california_housing',
    )
    return fetcher


@pytest.fixture
def fetch_covtype_fxt():
    """Return ``fetch_covtype`` wrapped by ``_wrapped_fetch``."""
    fetcher = _wrapped_fetch(fetch_covtype, dataset_name='covtype')
    return fetcher


@pytest.fixture
def fetch_kddcup99_fxt():
    """Return ``fetch_kddcup99`` wrapped by ``_wrapped_fetch``."""
    fetcher = _wrapped_fetch(fetch_kddcup99, dataset_name='kddcup99')
    return fetcher


@pytest.fixture
def fetch_olivetti_faces_fxt():
    """Return ``fetch_olivetti_faces`` wrapped by ``_wrapped_fetch``."""
    fetcher = _wrapped_fetch(
        fetch_olivetti_faces,
        dataset_name='olivetti_faces',
    )
    return fetcher


@pytest.fixture
def fetch_rcv1_fxt():
    """Return ``fetch_rcv1`` wrapped by ``_wrapped_fetch``."""
    fetcher = _wrapped_fetch(fetch_rcv1, dataset_name='rcv1')
    return fetcher


@pytest.fixture
Expand Down
7 changes: 3 additions & 4 deletions sklearn/ensemble/tests/test_gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from sklearn import datasets
from sklearn.base import clone
from sklearn.datasets import (make_classification, fetch_california_housing,
from sklearn.datasets import (make_classification,
make_regression)
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import GradientBoostingRegressor
Expand Down Expand Up @@ -345,16 +345,15 @@ def test_max_feature_regression():
assert deviance < 0.5, "GB failed with deviance %.4f" % deviance


@pytest.mark.network
def test_feature_importance_regression():
def test_feature_importance_regression(fetch_california_housing_fxt):
"""Test that Gini importance is calculated correctly.

This test follows the example from [1]_ (pg. 373).

.. [1] Friedman, J., Hastie, T., & Tibshirani, R. (2001). The elements
of statistical learning. New York: Springer series in statistics.
"""
california = fetch_california_housing()
california = fetch_california_housing_fxt()
X, y = california.data, california.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

Expand Down