# -*- coding: utf-8 -*-

+import os
+import sys
import unicodedata
+import urllib
+import zipfile
+
+import torch


def ispunct(token):
-    return all(unicodedata.category(char).startswith('P')
-               for char in token)
+    return all(unicodedata.category(char).startswith('P') for char in token)


def isfullwidth(token):
-    return all(unicodedata.east_asian_width(char) in ['W', 'F', 'A']
-               for char in token)
+    return all(unicodedata.east_asian_width(char) in ['W', 'F', 'A'] for char in token)


def islatin(token):
-    return all('LATIN' in unicodedata.name(char)
-               for char in token)
+    return all('LATIN' in unicodedata.name(char) for char in token)


def isdigit(token):
-    return all('DIGIT' in unicodedata.name(char)
-               for char in token)
+    return all('DIGIT' in unicodedata.name(char) for char in token)


def tohalfwidth(token):
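For context (not part of the commit), each of these predicates just folds a per-character unicodedata check over the token. A minimal sketch of the underlying checks, on illustrative strings:

import unicodedata

# ispunct: every character has a Unicode category in the 'P*' (punctuation) group
assert all(unicodedata.category(c).startswith('P') for c in '!?')
# isfullwidth: every character is Wide, Fullwidth, or Ambiguous in East Asian width
assert unicodedata.east_asian_width('，') in ['W', 'F', 'A']
# islatin / isdigit: the Unicode character name contains 'LATIN' / 'DIGIT'
assert 'LATIN' in unicodedata.name('a')   # 'LATIN SMALL LETTER A'
assert 'DIGIT' in unicodedata.name('7')   # 'DIGIT SEVEN'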
@@ -69,13 +71,35 @@ def stripe(x, n, w, offset=(0, 0), dim=1):
                        storage_offset=(offset[0]*seq_len+offset[1])*numel)


-def pad(tensors, padding_value=0, total_length=None):
+def pad(tensors, padding_value=0, total_length=None, padding_side='right'):
    size = [len(tensors)] + [max(tensor.size(i) for tensor in tensors)
                             for i in range(len(tensors[0].size()))]
    if total_length is not None:
        assert total_length >= size[1]
        size[1] = total_length
    out_tensor = tensors[0].data.new(*size).fill_(padding_value)
    for i, tensor in enumerate(tensors):
-        out_tensor[i][[slice(0, i) for i in tensor.size()]] = tensor
+        out_tensor[i][[slice(-i, None) if padding_side == 'left' else slice(0, i) for i in tensor.size()]] = tensor
    return out_tensor
+
+
+def download(url, reload=False):
+    path = os.path.join(os.path.expanduser('~/.cache/supar'), os.path.basename(urllib.parse.urlparse(url).path))
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    if reload:
+        os.remove(path) if os.path.exists(path) else None
+    if not os.path.exists(path):
+        sys.stderr.write(f"Downloading: {url} to {path}\n")
+        try:
+            torch.hub.download_url_to_file(url, path, progress=True)
+        except urllib.error.URLError:
+            raise RuntimeError(f"File {url} unavailable. Please try other sources.")
+    if zipfile.is_zipfile(path):
+        with zipfile.ZipFile(path) as f:
+            members = f.infolist()
+            path = os.path.join(os.path.dirname(path), members[0].filename)
+            if len(members) != 1:
+                raise RuntimeError('Only one file(not dir) is allowed in the zipfile.')
+            if reload or not os.path.exists(path):
+                f.extractall(os.path.dirname(path))
+    return path
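The page carries no usage notes, so the following is a minimal, hedged sketch of what the new padding_side argument and the download helper do; the import path and the URL are assumptions for illustration, not taken from this commit:

import torch
# Assumed import path; adjust to wherever this module lives in the project.
from supar.utils.fn import pad, download

a, b = torch.tensor([1, 2, 3]), torch.tensor([4, 5])
print(pad([a, b], padding_value=0))
# tensor([[1, 2, 3],
#         [4, 5, 0]])   right padding, the previous default behaviour
print(pad([a, b], padding_value=0, padding_side='left'))
# tensor([[1, 2, 3],
#         [0, 4, 5]])   left padding via the new slice(-len, None) branch

# download() caches a file under ~/.cache/supar and, for a single-file zip,
# extracts it next to the archive and returns the extracted path.
# path = download('https://example.com/model.zip')  # placeholder URL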