ENH Makes fetch_* more threadsafe #18000

Closed

thomasjpfan
Member

thomasjpfan commented Jul 26, 2020

Reference Issues/PRs

Related to #17996

What does this implement/fix? Explain your changes.

Implements the suggestion by @ogrisel: #17996 (comment)

@ogrisel
Member

ogrisel commented Jul 26, 2020

I concurrently did the following:

diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py
index dcbfde1c77..b02bb91c75 100644
--- a/sklearn/datasets/_base.py
+++ b/sklearn/datasets/_base.py
@@ -13,6 +13,7 @@ from collections import namedtuple
 from os import environ, listdir, makedirs
 from os.path import dirname, exists, expanduser, isdir, join, splitext
 import hashlib
+import uuid
 
 from ..utils import Bunch
 from ..utils import check_random_state
@@ -1161,6 +1162,9 @@ def _fetch_remote(remote, dirname=None):
     filename and ensure its integrity based on the SHA256 Checksum of the
     downloaded file.
 
+    The download is concurrent safe thanks to the use of a unique temporary
+    file in the same folder with a final atomic rename.
+
     Parameters
     ----------
     remote : RemoteFileMetadata
@@ -1178,11 +1182,18 @@ def _fetch_remote(remote, dirname=None):
 
     file_path = (remote.filename if dirname is None
                  else join(dirname, remote.filename))
-    urlretrieve(remote.url, file_path)
-    checksum = _sha256(file_path)
+
+    # Download to a temporary file to avoid corruption by concurrent
+    # downloads.
+    tmp_path = file_path + f".{uuid.uuid4().hex}.tmp"
+    urlretrieve(remote.url, tmp_path)
+    checksum = _sha256(tmp_path)
     if remote.checksum != checksum:
         raise IOError("{} has an SHA256 checksum ({}) "
                       "differing from expected ({}), "
-                      "file may be corrupted.".format(file_path, checksum,
+                      "file may be corrupted.".format(tmp_path, checksum,
                                                       remote.checksum))
+
+    # Atomic rename of successful download.
+    os.replace(tmp_path, file_path)
     return file_path

But this is not valid because the 20newsgroups fetcher downloads into a temporary folder with a fixed name, which gets deleted afterwards.

And furthermore the .tar.gz extractall operation is not concurrent-safe... and I see no easy way to do it atomically.
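
For reference, the pattern in the diff above (write under a unique temporary name in the same folder, then rename) can be tried in isolation with a minimal sketch like the one below. The names `sha256_of`, `atomic_fetch` and `extract_then_rename` are hypothetical and not part of scikit-learn. `fetch` stands in for the download call, and the directory variant is only an assumption about how the same idea might carry over to `extractall`. It does not deal with the fixed-name temporary folder that the 20newsgroups fetcher deletes afterwards.

```python
import hashlib
import os
import shutil
import tarfile
import tempfile
import uuid


def sha256_of(path, chunk_size=8192):
    """Return the hex SHA256 digest of the file at `path`."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def atomic_fetch(fetch, file_path, expected_checksum):
    """Hypothetical helper: fetch to a unique temp file, verify, then rename.

    `fetch` is any callable that writes the payload to the path it receives,
    for example a wrapper around urllib.request.urlretrieve.
    """
    # Unique name in the same folder: concurrent callers never share a
    # partial file, and the final rename stays on one filesystem.
    tmp_path = f"{file_path}.{uuid.uuid4().hex}.tmp"
    try:
        fetch(tmp_path)
        checksum = sha256_of(tmp_path)
        if checksum != expected_checksum:
            raise IOError(
                f"{tmp_path} has an SHA256 checksum ({checksum}) differing "
                f"from expected ({expected_checksum}), file may be corrupted."
            )
        # Atomic rename of the verified download.
        os.replace(tmp_path, file_path)
    finally:
        # Only still present if the fetch or the checksum failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return file_path


def extract_then_rename(archive_path, target_dir):
    """Hypothetical helper: extract into a unique sibling folder, then rename."""
    if os.path.isdir(target_dir):
        return target_dir  # Already published by an earlier call.
    tmp_dir = f"{target_dir}.{uuid.uuid4().hex}.tmp"
    os.makedirs(tmp_dir)
    try:
        with tarfile.open(archive_path, "r:gz") as archive:
            # Use the safer "data" extraction filter where tarfile has it.
            kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
            archive.extractall(path=tmp_dir, **kwargs)
        try:
            os.rename(tmp_dir, target_dir)
        except OSError:
            # Assumed cause: another caller published target_dir first.
            if not os.path.isdir(target_dir):
                raise
    finally:
        # No-op after a successful rename, cleanup otherwise.
        shutil.rmtree(tmp_dir, ignore_errors=True)
    return target_dir
```

With this shape, a reader of `file_path` or `target_dir` sees either nothing or a complete result. Callers that lose the race simply discard their own temporary copy.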

@thomasjpfan
Member Author

thomasjpfan commented Jul 26, 2020

> I concurrently did the following:

LOL (not even a pun in this case)

I was thinking about this when working on #17553 as well. In that case, I had custom logic to figure out when to prefetch the dataset before any of the tests ran.


2 participants