
FIX introduce a refresh_cache param to fetch_... functions. #14197

Merged: 21 commits, Jul 12, 2019
12 changes: 12 additions & 0 deletions doc/whats_new/v0.21.rst
@@ -12,6 +12,18 @@ Version 0.21.3
 Changelog
 ---------
 
+:mod:`sklearn.datasets`
+.......................
+
+- |Fix| :func:`datasets.fetch_california_housing`,
+  :func:`datasets.fetch_covtype`,
+  :func:`datasets.fetch_kddcup99`, :func:`datasets.fetch_olivetti_faces`,
+  :func:`datasets.fetch_rcv1`, and :func:`datasets.fetch_species_distributions`
+  now re-persist the previously cached data using the new ``joblib`` if the
+  cached data was persisted using the deprecated ``sklearn.externals.joblib``.
+  This fallback behavior is deprecated and will be removed in v0.23.
+  :pr:`14197` by `Adrin Jalali`_.
+
 :mod:`sklearn.impute`
 .....................
 
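
For context, a minimal sketch of the user-visible behavior the entry above describes (hypothetical cache state; not part of this diff). A fetch call transparently re-dumps an old-style cache with the new joblib; only a failed re-dump surfaces a DeprecationWarning.

```python
# Hypothetical illustration, assuming ~/scikit_learn_data already holds a
# cache written by scikit-learn <= 0.20 via sklearn.externals.joblib.
import warnings

from sklearn.datasets import fetch_california_housing

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    housing = fetch_california_housing(download_if_missing=False)

# Silent if the cache was rewritten with the new joblib; a
# DeprecationWarning here means the stale files should be deleted.
print([str(w.message) for w in caught])
```
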
29 changes: 29 additions & 0 deletions sklearn/datasets/base.py
@@ -10,6 +10,7 @@
 import csv
 import sys
 import shutil
+import warnings
 from collections import namedtuple
 from os import environ, listdir, makedirs
 from os.path import dirname, exists, expanduser, isdir, join, splitext
@@ -919,3 +920,31 @@ def _fetch_remote(remote, dirname=None):
                       "file may be corrupted.".format(file_path, checksum,
                                                       remote.checksum))
     return file_path
+
+
+def _refresh_cache(files, compress):
+    # TODO: REMOVE in v0.23
+    import joblib
+    msg = "sklearn.externals.joblib is deprecated in 0.21"
+    with warnings.catch_warnings(record=True) as warns:
+        data = tuple([joblib.load(f) for f in files])
+
+    refresh_needed = any([str(x.message).startswith(msg) for x in warns])
+
+    other_warns = [w for w in warns if not str(w.message).startswith(msg)]
+    for w in other_warns:
+        warnings.warn(message=w.message, category=w.category)
+
+    if refresh_needed:
+        try:
+            for value, path in zip(data, files):
+                joblib.dump(value, path, compress=compress)
+        except IOError:
+            message = ("This dataset will stop being loadable in scikit-learn "
+                       "version 0.23 because it references a deprecated "
+                       "import path. Consider removing the following files "
+                       "and allowing them to be cached anew:\n%s"
+                       % ("\n".join(files)))
+            warnings.warn(message=message, category=DeprecationWarning)
+
+    return data[0] if len(data) == 1 else data
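
As a quick reference for the call sites below, a minimal usage sketch of the helper (the file paths are hypothetical placeholders): every file is loaded, the set is re-dumped only when the deprecated-joblib warning was seen, and a single-element result is unwrapped from its tuple.

```python
# minimal sketch; 'samples.pkl' etc. are hypothetical placeholder paths
from sklearn.datasets.base import _refresh_cache

# several cached files -> a tuple, in the order the paths were given
X, y = _refresh_cache(['samples.pkl', 'targets.pkl'], compress=9)

# a single cached file -> the bare value, unwrapped from the 1-tuple
faces = _refresh_cache(['faces.pkl'], compress=6)
```
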
5 changes: 4 additions & 1 deletion sklearn/datasets/california_housing.py
@@ -34,6 +34,7 @@
 from .base import _fetch_remote
 from .base import _pkl_filepath
 from .base import RemoteFileMetadata
+from .base import _refresh_cache
 from ..utils import Bunch
 
 # The original data can be found at:
@@ -129,7 +130,9 @@ def fetch_california_housing(data_home=None, download_if_missing=True,
         remove(archive_path)
 
     else:
-        cal_housing = joblib.load(filepath)
+        cal_housing = _refresh_cache([filepath], 6)
+        # TODO: Revert to the following line in v0.23
+        # cal_housing = joblib.load(filepath)
 
     feature_names = ["MedInc", "HouseAge", "AveRooms", "AveBedrms",
                      "Population", "AveOccup", "Latitude", "Longitude"]

7 changes: 5 additions & 2 deletions sklearn/datasets/covtype.py
@@ -25,6 +25,7 @@
 from .base import get_data_home
 from .base import _fetch_remote
 from .base import RemoteFileMetadata
+from .base import _refresh_cache
 from ..utils import Bunch
 from .base import _pkl_filepath
 from ..utils import check_random_state
@@ -125,8 +126,10 @@ def fetch_covtype(data_home=None, download_if_missing=True,
     try:
         X, y
     except NameError:
-        X = joblib.load(samples_path)
-        y = joblib.load(targets_path)
+        X, y = _refresh_cache([samples_path, targets_path], 9)
+        # TODO: Revert to the following two lines in v0.23
+        # X = joblib.load(samples_path)
+        # y = joblib.load(targets_path)
 
     if shuffle:
         ind = np.arange(X.shape[0])

7 changes: 5 additions & 2 deletions sklearn/datasets/kddcup99.py
@@ -20,6 +20,7 @@
 from .base import _fetch_remote
 from .base import get_data_home
 from .base import RemoteFileMetadata
+from .base import _refresh_cache
 from ..utils import Bunch
 from ..utils import check_random_state
 from ..utils import shuffle as shuffle_method
@@ -292,8 +293,10 @@ def _fetch_brute_kddcup99(data_home=None,
     try:
         X, y
     except NameError:
-        X = joblib.load(samples_path)
-        y = joblib.load(targets_path)
+        X, y = _refresh_cache([samples_path, targets_path], 0)
+        # TODO: Revert to the following two lines in v0.23
+        # X = joblib.load(samples_path)
+        # y = joblib.load(targets_path)
 
     return Bunch(data=X, target=y)
 
5 changes: 4 additions & 1 deletion sklearn/datasets/olivetti_faces.py
@@ -24,6 +24,7 @@
 from .base import _fetch_remote
 from .base import RemoteFileMetadata
 from .base import _pkl_filepath
+from .base import _refresh_cache
 from ..utils import check_random_state, Bunch
 
 # The original data can be found at:
@@ -107,7 +108,9 @@ def fetch_olivetti_faces(data_home=None, shuffle=False, random_state=0,
         joblib.dump(faces, filepath, compress=6)
         del mfile
     else:
-        faces = joblib.load(filepath)
+        faces = _refresh_cache([filepath], 6)
+        # TODO: Revert to the following line in v0.23
+        # faces = joblib.load(filepath)
 
     # We want floating point data, but float32 is enough (there is only
     # one byte of precision in the original uint8s anyway)

13 changes: 9 additions & 4 deletions sklearn/datasets/rcv1.py
@@ -22,6 +22,7 @@
 from .base import _pkl_filepath
 from .base import _fetch_remote
 from .base import RemoteFileMetadata
+from .base import _refresh_cache
 from .svmlight_format import load_svmlight_files
 from ..utils import shuffle as shuffle_
 from ..utils import Bunch
@@ -189,8 +190,10 @@ def fetch_rcv1(data_home=None, subset='all', download_if_missing=True,
             f.close()
             remove(f.name)
     else:
-        X = joblib.load(samples_path)
-        sample_id = joblib.load(sample_id_path)
+        X, sample_id = _refresh_cache([samples_path, sample_id_path], 9)
+        # TODO: Revert to the following two lines in v0.23
+        # X = joblib.load(samples_path)
+        # sample_id = joblib.load(sample_id_path)
 
     # load target (y), categories, and sample_id_bis
     if download_if_missing and (not exists(sample_topics_path) or
@@ -243,8 +246,10 @@ def fetch_rcv1(data_home=None, subset='all', download_if_missing=True,
         joblib.dump(y, sample_topics_path, compress=9)
         joblib.dump(categories, topics_path, compress=9)
     else:
-        y = joblib.load(sample_topics_path)
-        categories = joblib.load(topics_path)
+        y, categories = _refresh_cache([sample_topics_path, topics_path], 9)
+        # TODO: Revert to the following two lines in v0.23
+        # y = joblib.load(sample_topics_path)
+        # categories = joblib.load(topics_path)
 
     if subset == 'all':
         pass

5 changes: 4 additions & 1 deletion sklearn/datasets/species_distributions.py
@@ -51,6 +51,7 @@
 from .base import RemoteFileMetadata
 from ..utils import Bunch
 from .base import _pkl_filepath
+from .base import _refresh_cache
 
 # The original data can be found at:
 # https://biodiversityinformatics.amnh.org/open_source/maxent/samples.zip
@@ -259,6 +260,8 @@ def fetch_species_distributions(data_home=None,
                                     **extra_params)
         joblib.dump(bunch, archive_path, compress=9)
     else:
-        bunch = joblib.load(archive_path)
+        bunch = _refresh_cache([archive_path], 9)
+        # TODO: Revert to the following line in v0.23
+        # bunch = joblib.load(archive_path)
 
     return bunch

54 changes: 54 additions & 0 deletions sklearn/datasets/tests/test_base.py
@@ -8,6 +8,7 @@
 from functools import partial
 
 import pytest
+import joblib
 
 import numpy as np
 from sklearn.datasets import get_data_home
@@ -23,6 +24,7 @@
 from sklearn.datasets import load_boston
 from sklearn.datasets import load_wine
 from sklearn.datasets.base import Bunch
+from sklearn.datasets.base import _refresh_cache
 from sklearn.datasets.tests.test_common import check_return_X_y
 
 from sklearn.externals._pilutil import pillow_installed
@@ -276,3 +278,55 @@ def test_bunch_dir():
     # check that dir (important for autocomplete) shows attributes
     data = load_iris()
     assert "data" in dir(data)
+
+
+def test_refresh_cache(monkeypatch):
+    # uses pytest's monkeypatch fixture
+    # https://docs.pytest.org/en/latest/monkeypatch.html
+
+    def _load_warn(*args, **kwargs):
+        # raise the warning from "externals.joblib.__init__.py"
+        # this is raised when a file persisted by the old joblib is loaded now
+        msg = ("sklearn.externals.joblib is deprecated in 0.21 and will be "
+               "removed in 0.23. Please import this functionality directly "
+               "from joblib, which can be installed with: pip install joblib. "
+               "If this warning is raised when loading pickled models, you "
+               "may need to re-serialize those models with scikit-learn "
+               "0.21+.")
+        warnings.warn(msg, DeprecationWarning)
+        return 0
+
+    def _load_warn_unrelated(*args, **kwargs):
+        warnings.warn("unrelated warning", DeprecationWarning)
+        return 0
+
+    def _dump_safe(*args, **kwargs):
+        pass
+
+    def _dump_raise(*args, **kwargs):
+        # this happens if the file is read-only and joblib.dump fails to
+        # write to it
+        raise IOError()
+
+    # check that the dataset-specific warning is raised if load raises the
+    # joblib warning and dump fails to re-persist with the new joblib
+    monkeypatch.setattr(joblib, "load", _load_warn)
+    monkeypatch.setattr(joblib, "dump", _dump_raise)
+    msg = "This dataset will stop being loadable in scikit-learn"
+    with pytest.warns(DeprecationWarning, match=msg):
+        _refresh_cache('test', 0)
+
+    # make sure no warning is raised if load raises the warning, but dump
+    # manages to dump the new data
+    monkeypatch.setattr(joblib, "load", _load_warn)
+    monkeypatch.setattr(joblib, "dump", _dump_safe)
+    with pytest.warns(None) as warns:
+        _refresh_cache('test', 0)
+    assert len(warns) == 0
+
+    # check that an unrelated warning is still passed through and not
+    # suppressed by _refresh_cache
+    monkeypatch.setattr(joblib, "load", _load_warn_unrelated)
+    monkeypatch.setattr(joblib, "dump", _dump_safe)
+    with pytest.warns(DeprecationWarning, match="unrelated warning"):
+        _refresh_cache('test', 0)
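
Outside of pytest, the same contract can be spot-checked with a recorded-warnings block; a sketch under the assumption that an old-style cache pickle exists at the placeholder path:

```python
# ad-hoc check mirroring the test above; CACHE_FILE is a hypothetical
# placeholder and must point at a pickle written by sklearn.externals.joblib
import warnings

from sklearn.datasets.base import _refresh_cache

CACHE_FILE = "/tmp/old_style_cache.pkl"  # placeholder path

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _refresh_cache([CACHE_FILE], compress=0)

# silent when the re-dump succeeded; a DeprecationWarning otherwise
for w in caught:
    print(w.category.__name__, ":", str(w.message))
```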