Skip to content

Commit 34f1de6

Browse files
committed
download function
1 parent a246cb2 commit 34f1de6

File tree

1 file changed

+34
-10
lines changed

1 file changed

+34
-10
lines changed

supar/utils/fn.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,28 @@
11
# -*- coding: utf-8 -*-
22

3+
import os
4+
import sys
35
import unicodedata
6+
import urllib
7+
import zipfile
8+
9+
import torch
410

511

612
def ispunct(token):
    """Return True if every character of ``token`` is punctuation.

    A character counts as punctuation when its Unicode general category
    starts with ``'P'``. An empty ``token`` yields True.
    """
    for char in token:
        if not unicodedata.category(char).startswith('P'):
            return False
    return True
914

1015

1116
def isfullwidth(token):
    """Return True if every character of ``token`` is wide, fullwidth or ambiguous.

    The check uses the East Asian Width property: classes ``W``, ``F`` and
    ``A`` are accepted. An empty ``token`` yields True.
    """
    accepted_widths = ('W', 'F', 'A')
    return not any(unicodedata.east_asian_width(char) not in accepted_widths for char in token)
1418

1519

1620
def islatin(token):
    """Return True if every character of ``token`` has ``LATIN`` in its Unicode name.

    Characters without a Unicode name (e.g. control characters such as
    ``'\\n'``) are treated as non-Latin instead of raising ``ValueError``.
    An empty ``token`` yields True.
    """
    # The default '' keeps unicodedata.name from raising on unnamed characters.
    return all('LATIN' in unicodedata.name(char, '') for char in token)
1922

2023

2124
def isdigit(token):
    """Return True if every character of ``token`` has ``DIGIT`` in its Unicode name.

    Characters without a Unicode name (e.g. control characters such as
    ``'\\t'``) are treated as non-digits instead of raising ``ValueError``.
    An empty ``token`` yields True.
    """
    # The default '' keeps unicodedata.name from raising on unnamed characters.
    return all('DIGIT' in unicodedata.name(char, '') for char in token)
2426

2527

2628
def tohalfwidth(token):
@@ -69,13 +71,35 @@ def stripe(x, n, w, offset=(0, 0), dim=1):
6971
storage_offset=(offset[0]*seq_len+offset[1])*numel)
7072

7173

72-
def pad(tensors, padding_value=0, total_length=None, padding_side='right'):
    """Stack a sequence of tensors into one tensor, padding each to a common size.

    The output has shape ``[len(tensors), d_0, d_1, ...]`` where ``d_k`` is the
    largest size of dimension ``k`` over all inputs. The number of dimensions
    is taken from ``tensors[0]``; all tensors are assumed to share it.

    Args:
        tensors: Non-empty sequence of tensors to be padded.
        padding_value: Value used to fill positions not covered by an input.
        total_length: If given, the size of the first padded dimension
            (``d_0``); must be no smaller than the largest input size there.
        padding_side: ``'left'`` places each tensor at the end of every
            dimension (padding goes first); any other value places it at the
            start (padding goes last).

    Returns:
        A new tensor of the dtype/device of ``tensors[0]`` holding all inputs.
    """
    size = [len(tensors)] + [max(tensor.size(i) for tensor in tensors)
                             for i in range(len(tensors[0].size()))]
    if total_length is not None:
        assert total_length >= size[1]
        size[1] = total_length
    out_tensor = tensors[0].data.new(*size).fill_(padding_value)
    for i, tensor in enumerate(tensors):
        if padding_side == 'left':
            # Explicit bounds: slice(-n, None) would select the whole
            # dimension when n == 0 and break the assignment.
            index = tuple(slice(full - n, full) for full, n in zip(size[1:], tensor.size()))
        else:
            index = tuple(slice(0, n) for n in tensor.size())
        # A tuple (not a list) of slices is the well-defined multi-dim index.
        out_tensor[i][index] = tensor
    return out_tensor
84+
85+
86+
def download(url, reload=False):
    """Fetch ``url`` into the local cache and return the path of the cached file.

    The file is stored under ``~/.cache/supar`` using the last component of
    the URL path as its name. If the cached file is a zip archive, it must
    contain exactly one regular file, which is extracted next to the archive
    and whose path is returned instead.

    Args:
        url: Address of the file to fetch.
        reload: If True, delete any cached copy, download again and
            re-extract the archive member.

    Returns:
        Path to the cached file, or to the extracted member for zip archives.

    Raises:
        RuntimeError: If the download fails with ``urllib.error.URLError``, or
            the zip archive does not contain exactly one regular file.
    """
    # `import urllib` alone does not guarantee these submodules are loaded.
    import urllib.error
    import urllib.parse

    cache_dir = os.path.expanduser('~/.cache/supar')
    path = os.path.join(cache_dir, os.path.basename(urllib.parse.urlparse(url).path))
    os.makedirs(cache_dir, exist_ok=True)
    if reload and os.path.exists(path):
        os.remove(path)
    if not os.path.exists(path):
        sys.stderr.write(f"Downloading: {url} to {path}\n")
        try:
            torch.hub.download_url_to_file(url, path, progress=True)
        except urllib.error.URLError as err:
            raise RuntimeError(f"File {url} unavailable. Please try other sources.") from err
    if zipfile.is_zipfile(path):
        with zipfile.ZipFile(path) as f:
            members = f.infolist()
            # Validate before touching members[0] so an empty archive gets the
            # intended error rather than an IndexError.
            if len(members) != 1 or members[0].is_dir():
                raise RuntimeError('Only one file(not dir) is allowed in the zipfile.')
            path = os.path.join(cache_dir, members[0].filename)
            if reload or not os.path.exists(path):
                # Always extract relative to the cache directory so the
                # returned path matches where the member actually lands,
                # even when the member sits inside a subdirectory.
                f.extractall(cache_dir)
    return path

0 commit comments

Comments
 (0)