|
4 | 4 | import os
|
5 | 5 | import pickle
|
6 | 6 | import sys
|
| 7 | +import tarfile |
7 | 8 | import unicodedata
|
8 | 9 | import urllib
|
9 | 10 | import zipfile
|
@@ -248,24 +249,37 @@ def pad(
|
248 | 249 | return out_tensor
|
249 | 250 |
|
250 | 251 |
|
251 |
| -def download(url: str, reload: bool = False) -> str: |
252 |
| - path = os.path.join(CACHE, os.path.basename(urllib.parse.urlparse(url).path)) |
253 |
| - os.makedirs(os.path.dirname(path), exist_ok=True) |
254 |
| - if reload: |
255 |
| - os.remove(path) if os.path.exists(path) else None |
def download(url: str, path: Optional[str] = None, reload: bool = False, clean: bool = False) -> str:
    """Download ``url`` into a directory and hand the file to ``extract``.

    Args:
        url: Location of the file to fetch. The local file name is the
            basename of the URL's path component.
        path: Directory in which to store the file. Falls back to the
            module-level ``CACHE`` when ``None``. Created if missing.
        reload: If ``True``, delete any existing local copy first so the file
            is fetched again. Also forwarded to ``extract``.
        clean: Forwarded to ``extract``.

    Returns:
        Whatever ``extract`` returns for the downloaded file.

    Raises:
        RuntimeError: If the download fails with ``ValueError`` or
            ``urllib.error.URLError``; the original error is chained as the
            cause.
    """
    # A bare `import urllib` does not guarantee these submodules are loaded,
    # so import them explicitly before touching `urllib.parse`/`urllib.error`.
    import urllib.error
    import urllib.parse

    filename = os.path.basename(urllib.parse.urlparse(url).path)
    if path is None:
        path = CACHE
    os.makedirs(path, exist_ok=True)
    path = os.path.join(path, filename)
    if reload and os.path.exists(path):
        os.remove(path)
    # Only hit the network when there is no local copy (or it was just removed).
    if not os.path.exists(path):
        sys.stderr.write(f"Downloading: {url} to {path}\n")
        try:
            torch.hub.download_url_to_file(url, path, progress=True)
        except (ValueError, urllib.error.URLError) as err:
            # Chain the cause so the underlying network/URL error stays visible.
            raise RuntimeError(f"File {url} unavailable. Please try other sources.") from err
    return extract(path, reload, clean)
| 267 | + |
| 268 | + |
def extract(path: str, reload: bool = False, clean: bool = False) -> str:
    """Unpack a zip or tar archive into its own directory.

    Args:
        path: File to inspect. Zip and tar archives are unpacked into
            ``os.path.dirname(path)``; any other file is left alone.
        reload: If ``True``, unpack even when the first archive member already
            exists on disk.
        clean: If ``True``, delete the archive after it has been handled.
            Files that are not archives (or archives with no members) are
            never deleted, since they are the only copy of the data.

    Returns:
        The path of the first archive member joined onto the archive's
        directory, or ``path`` itself when nothing was extracted.
    """
    target_dir = os.path.dirname(path)
    # Default to the input so that non-archive files are returned unchanged
    # instead of failing on an unbound local.
    extracted = path
    if zipfile.is_zipfile(path):
        with zipfile.ZipFile(path) as f:
            names = f.namelist()
            if names:
                extracted = os.path.join(target_dir, names[0])
                # The first member doubles as the "already unpacked" marker.
                if reload or not os.path.exists(extracted):
                    f.extractall(target_dir)
    elif tarfile.is_tarfile(path):
        with tarfile.open(path) as f:
            names = f.getnames()
            if names:
                extracted = os.path.join(target_dir, names[0])
                if reload or not os.path.exists(extracted):
                    # NOTE(review): archives may come from the network; consider
                    # passing an extraction filter where the running Python
                    # supports it to guard against path traversal.
                    f.extractall(target_dir)
    # Only drop the archive when something else now stands in for it.
    if clean and extracted != path:
        os.remove(path)
    return extracted
269 | 283 |
|
270 | 284 |
|
271 | 285 | def binarize(data: Iterable, fbin: str = None) -> None:
|
|
0 commit comments