Skip to content

Commit 94a7394

Browse files
committed
Improve download fn
1 parent 76a8fb1 commit 94a7394

File tree

1 file changed

+23
-9
lines changed

1 file changed

+23
-9
lines changed

supar/utils/fn.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import pickle
66
import sys
7+
import tarfile
78
import unicodedata
89
import urllib
910
import zipfile
@@ -248,24 +249,37 @@ def pad(
248249
return out_tensor
249250

250251

251-
def download(url: str, reload: bool = False) -> str:
252-
path = os.path.join(CACHE, os.path.basename(urllib.parse.urlparse(url).path))
253-
os.makedirs(os.path.dirname(path), exist_ok=True)
254-
if reload:
255-
os.remove(path) if os.path.exists(path) else None
252+
def download(url: str, path: Optional[str] = None, reload: bool = False, clean: bool = False) -> str:
253+
filename = os.path.basename(urllib.parse.urlparse(url).path)
254+
if path is None:
255+
path = CACHE
256+
os.makedirs(path, exist_ok=True)
257+
path = os.path.join(path, filename)
258+
if reload and os.path.exists(path):
259+
os.remove(path)
256260
if not os.path.exists(path):
257261
sys.stderr.write(f"Downloading: {url} to {path}\n")
258262
try:
259263
torch.hub.download_url_to_file(url, path, progress=True)
260264
except (ValueError, urllib.error.URLError):
261265
raise RuntimeError(f"File {url} unavailable. Please try other sources.")
266+
return extract(path, reload, clean)
267+
268+
269+
def extract(path: str, reload: bool = False, clean: bool = False) -> str:
262270
if zipfile.is_zipfile(path):
263271
with zipfile.ZipFile(path) as f:
264-
members = f.infolist()
265-
path = os.path.join(os.path.dirname(path), members[0].filename)
266-
if reload or not os.path.exists(path):
272+
extracted = os.path.join(os.path.dirname(path), f.infolist()[0].filename)
273+
if reload or not os.path.exists(extracted):
274+
f.extractall(os.path.dirname(path))
275+
if tarfile.is_tarfile(path):
276+
with tarfile.open(path) as f:
277+
extracted = os.path.join(os.path.dirname(path), f.getnames()[0])
278+
if reload or not os.path.exists(extracted):
267279
f.extractall(os.path.dirname(path))
268-
return path
280+
if clean:
281+
os.remove(path)
282+
return extracted
269283

270284

271285
def binarize(data: Iterable, fbin: str = None) -> None:

0 commit comments

Comments
 (0)