Skip to content

Commit 74d1307

Browse files
fkdosilovic, thomasjpfan, and lesteve
authored
ENH Add retry mechanism to fetch_xx functions. (#28160)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
1 parent e2f9530 commit 74d1307

File tree

11 files changed

+308
-29
lines changed

11 files changed

+308
-29
lines changed

doc/whats_new/v1.5.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,23 @@ Changelog
9696
only `inverse_func` is provided without `func` (that would default to identity) being
9797
explicitly set as well. :pr:`28483` by :user:`Stefanie Senger <StefanieSenger>`.
9898

99+
:mod:`sklearn.datasets`
100+
.......................
101+
102+
- |Enhancement| Adds optional arguments `n_retries` and `delay` to functions
103+
:func:`datasets.fetch_20newsgroups`,
104+
:func:`datasets.fetch_20newsgroups_vectorized`,
105+
:func:`datasets.fetch_california_housing`,
106+
:func:`datasets.fetch_covtype`,
107+
:func:`datasets.fetch_kddcup99`,
108+
:func:`datasets.fetch_lfw_pairs`,
109+
:func:`datasets.fetch_lfw_people`,
110+
:func:`datasets.fetch_olivetti_faces`,
111+
:func:`datasets.fetch_rcv1`,
112+
and :func:`datasets.fetch_species_distributions`.
113+
By default, the functions will retry up to 3 times in case of network failures.
114+
:pr:`28160` by :user:`Zhehao Liu <MaxwellLZH>` and :user:`Filip Karlo Došilović <fkdosilovic>`.
115+
99116
:mod:`sklearn.dummy`
100117
....................
101118

sklearn/datasets/_base.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111
import hashlib
1212
import os
1313
import shutil
14+
import time
15+
import warnings
1416
from collections import namedtuple
1517
from importlib import resources
1618
from numbers import Integral
1719
from os import environ, listdir, makedirs
1820
from os.path import expanduser, isdir, join, splitext
1921
from pathlib import Path
22+
from urllib.error import URLError
2023
from urllib.request import urlretrieve
2124

2225
import numpy as np
@@ -1408,7 +1411,7 @@ def _sha256(path):
14081411
return sha256hash.hexdigest()
14091412

14101413

1411-
def _fetch_remote(remote, dirname=None):
1414+
def _fetch_remote(remote, dirname=None, n_retries=3, delay=1):
14121415
"""Helper function to download a remote dataset into path
14131416
14141417
Fetch a dataset pointed by remote's url, save into path using remote's
@@ -1424,14 +1427,35 @@ def _fetch_remote(remote, dirname=None):
14241427
dirname : str
14251428
Directory to save the file to.
14261429
1430+
n_retries : int, default=3
1431+
Number of retries when HTTP errors are encountered.
1432+
1433+
.. versionadded:: 1.5
1434+
1435+
delay : int, default=1
1436+
Number of seconds between retries.
1437+
1438+
.. versionadded:: 1.5
1439+
14271440
Returns
14281441
-------
14291442
file_path: str
14301443
Full path of the created file.
14311444
"""
14321445

14331446
file_path = remote.filename if dirname is None else join(dirname, remote.filename)
1434-
urlretrieve(remote.url, file_path)
1447+
while True:
1448+
try:
1449+
urlretrieve(remote.url, file_path)
1450+
break
1451+
except (URLError, TimeoutError):
1452+
if n_retries == 0:
1453+
# If no more retries are left, re-raise the caught exception.
1454+
raise
1455+
warnings.warn(f"Retry downloading from url: {remote.url}")
1456+
n_retries -= 1
1457+
time.sleep(delay)
1458+
14351459
checksum = _sha256(file_path)
14361460
if remote.checksum != checksum:
14371461
raise OSError(

sklearn/datasets/_california_housing.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,15 @@
2323

2424
import logging
2525
import tarfile
26+
from numbers import Integral, Real
2627
from os import PathLike, makedirs, remove
2728
from os.path import exists
2829

2930
import joblib
3031
import numpy as np
3132

3233
from ..utils import Bunch
33-
from ..utils._param_validation import validate_params
34+
from ..utils._param_validation import Interval, validate_params
3435
from . import get_data_home
3536
from ._base import (
3637
RemoteFileMetadata,
@@ -57,11 +58,19 @@
5758
"download_if_missing": ["boolean"],
5859
"return_X_y": ["boolean"],
5960
"as_frame": ["boolean"],
61+
"n_retries": [Interval(Integral, 1, None, closed="left")],
62+
"delay": [Interval(Real, 0.0, None, closed="neither")],
6063
},
6164
prefer_skip_nested_validation=True,
6265
)
6366
def fetch_california_housing(
64-
*, data_home=None, download_if_missing=True, return_X_y=False, as_frame=False
67+
*,
68+
data_home=None,
69+
download_if_missing=True,
70+
return_X_y=False,
71+
as_frame=False,
72+
n_retries=3,
73+
delay=1.0,
6574
):
6675
"""Load the California housing dataset (regression).
6776
@@ -97,6 +106,16 @@ def fetch_california_housing(
97106
98107
.. versionadded:: 0.23
99108
109+
n_retries : int, default=3
110+
Number of retries when HTTP errors are encountered.
111+
112+
.. versionadded:: 1.5
113+
114+
delay : float, default=1.0
115+
Number of seconds between retries.
116+
117+
.. versionadded:: 1.5
118+
100119
Returns
101120
-------
102121
dataset : :class:`~sklearn.utils.Bunch`
@@ -154,7 +173,12 @@ def fetch_california_housing(
154173
"Downloading Cal. housing from {} to {}".format(ARCHIVE.url, data_home)
155174
)
156175

157-
archive_path = _fetch_remote(ARCHIVE, dirname=data_home)
176+
archive_path = _fetch_remote(
177+
ARCHIVE,
178+
dirname=data_home,
179+
n_retries=n_retries,
180+
delay=delay,
181+
)
158182

159183
with tarfile.open(mode="r:gz", name=archive_path) as f:
160184
cal_housing = np.loadtxt(

sklearn/datasets/_covtype.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717
import logging
1818
import os
1919
from gzip import GzipFile
20+
from numbers import Integral, Real
2021
from os.path import exists, join
2122
from tempfile import TemporaryDirectory
2223

2324
import joblib
2425
import numpy as np
2526

2627
from ..utils import Bunch, check_random_state
27-
from ..utils._param_validation import validate_params
28+
from ..utils._param_validation import Interval, validate_params
2829
from . import get_data_home
2930
from ._base import (
3031
RemoteFileMetadata,
@@ -71,6 +72,8 @@
7172
"shuffle": ["boolean"],
7273
"return_X_y": ["boolean"],
7374
"as_frame": ["boolean"],
75+
"n_retries": [Interval(Integral, 1, None, closed="left")],
76+
"delay": [Interval(Real, 0.0, None, closed="neither")],
7477
},
7578
prefer_skip_nested_validation=True,
7679
)
@@ -82,6 +85,8 @@ def fetch_covtype(
8285
shuffle=False,
8386
return_X_y=False,
8487
as_frame=False,
88+
n_retries=3,
89+
delay=1.0,
8590
):
8691
"""Load the covertype dataset (classification).
8792
@@ -129,6 +134,16 @@ def fetch_covtype(
129134
130135
.. versionadded:: 0.24
131136
137+
n_retries : int, default=3
138+
Number of retries when HTTP errors are encountered.
139+
140+
.. versionadded:: 1.5
141+
142+
delay : float, default=1.0
143+
Number of seconds between retries.
144+
145+
.. versionadded:: 1.5
146+
132147
Returns
133148
-------
134149
dataset : :class:`~sklearn.utils.Bunch`
@@ -183,7 +198,9 @@ def fetch_covtype(
183198
# os.rename to atomically move the data files to their target location.
184199
with TemporaryDirectory(dir=covtype_dir) as temp_dir:
185200
logger.info(f"Downloading {ARCHIVE.url}")
186-
archive_path = _fetch_remote(ARCHIVE, dirname=temp_dir)
201+
archive_path = _fetch_remote(
202+
ARCHIVE, dirname=temp_dir, n_retries=n_retries, delay=delay
203+
)
187204
Xy = np.genfromtxt(GzipFile(filename=archive_path), delimiter=",")
188205

189206
X = Xy[:, :-1]

sklearn/datasets/_kddcup99.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
import logging
1313
import os
1414
from gzip import GzipFile
15+
from numbers import Integral, Real
1516
from os.path import exists, join
1617

1718
import joblib
1819
import numpy as np
1920

2021
from ..utils import Bunch, check_random_state
2122
from ..utils import shuffle as shuffle_method
22-
from ..utils._param_validation import StrOptions, validate_params
23+
from ..utils._param_validation import Interval, StrOptions, validate_params
2324
from . import get_data_home
2425
from ._base import (
2526
RemoteFileMetadata,
@@ -57,6 +58,8 @@
5758
"download_if_missing": ["boolean"],
5859
"return_X_y": ["boolean"],
5960
"as_frame": ["boolean"],
61+
"n_retries": [Interval(Integral, 1, None, closed="left")],
62+
"delay": [Interval(Real, 0.0, None, closed="neither")],
6063
},
6164
prefer_skip_nested_validation=True,
6265
)
@@ -70,6 +73,8 @@ def fetch_kddcup99(
7073
download_if_missing=True,
7174
return_X_y=False,
7275
as_frame=False,
76+
n_retries=3,
77+
delay=1.0,
7378
):
7479
"""Load the kddcup99 dataset (classification).
7580
@@ -127,6 +132,16 @@ def fetch_kddcup99(
127132
128133
.. versionadded:: 0.24
129134
135+
n_retries : int, default=3
136+
Number of retries when HTTP errors are encountered.
137+
138+
.. versionadded:: 1.5
139+
140+
delay : float, default=1.0
141+
Number of seconds between retries.
142+
143+
.. versionadded:: 1.5
144+
130145
Returns
131146
-------
132147
data : :class:`~sklearn.utils.Bunch`
@@ -160,6 +175,8 @@ def fetch_kddcup99(
160175
data_home=data_home,
161176
percent10=percent10,
162177
download_if_missing=download_if_missing,
178+
n_retries=n_retries,
179+
delay=delay,
163180
)
164181

165182
data = kddcup99.data
@@ -243,7 +260,9 @@ def fetch_kddcup99(
243260
)
244261

245262

246-
def _fetch_brute_kddcup99(data_home=None, download_if_missing=True, percent10=True):
263+
def _fetch_brute_kddcup99(
264+
data_home=None, download_if_missing=True, percent10=True, n_retries=3, delay=1.0
265+
):
247266
"""Load the kddcup99 dataset, downloading it if necessary.
248267
249268
Parameters
@@ -259,6 +278,12 @@ def _fetch_brute_kddcup99(data_home=None, download_if_missing=True, percent10=Tr
259278
percent10 : bool, default=True
260279
Whether to load only 10 percent of the data.
261280
281+
n_retries : int, default=3
282+
Number of retries when HTTP errors are encountered.
283+
284+
delay : float, default=1.0
285+
Number of seconds between retries.
286+
262287
Returns
263288
-------
264289
dataset : :class:`~sklearn.utils.Bunch`
@@ -354,7 +379,7 @@ def _fetch_brute_kddcup99(data_home=None, download_if_missing=True, percent10=Tr
354379
elif download_if_missing:
355380
_mkdirp(kddcup_dir)
356381
logger.info("Downloading %s" % archive.url)
357-
_fetch_remote(archive, dirname=kddcup_dir)
382+
_fetch_remote(archive, dirname=kddcup_dir, n_retries=n_retries, delay=delay)
358383
DT = np.dtype(dt)
359384
logger.debug("extracting archive")
360385
archive_path = join(kddcup_dir, archive.filename)

0 commit comments

Comments
 (0)