from __future__ import annotations

import os
- from collections import Counter
- from typing import Optional
+ from typing import Iterable, Optional, Union

import torch
from supar.utils.common import CACHE
from supar.utils.fn import download
from supar.utils.logging import progress_bar
- from supar.utils.vocab import Vocab
from torch.distributions.utils import lazy_property


class Embedding(object):
+ r"""
17
+ Defines a container object for holding pretrained embeddings.
18
+ This object is callable and behaves like :class:`torch.nn.Embedding`.
19
+ For huge files, this object supports lazy loading, seeking to retrieve vectors from the disk on the fly if necessary.
20
+
21
+ Currently available embeddings:
22
+ - `GloVe`_
23
+ - `Fasttext`_
24
+ - `Giga`_
25
+ - `Tencent`_
26
+
27
+ Args:
28
+ path (str):
29
+ Path to the embedding file or short name registered in ``supar.utils.embed.PRETRAINED``.
30
+ unk (Optional[str]):
31
+ The string token used to represent OOV tokens. Default: ``None``.
32
+ skip_first (bool)
33
+ If ``True``, skips the first line of the embedding file. Default: ``False``.
34
+ cache (bool):
35
+ If ``True``, instead of loading entire embeddings into memory, seeks to load vectors from the disk once called.
36
+ Default: ``True``.
37
+ sep (str):
38
+ Separator used by embedding file. Default: ``' '``.
39
+
40
+ Examples:
41
+ >>> import torch.nn as nn
42
+ >>> from supar.utils.embed import Embedding
43
+ >>> glove = Embedding.load('glove-6b-100')
44
+ >>> glove
45
+ GloVeEmbedding(n_tokens=400000, dim=100, unk=unk, cache=True)
46
+ >>> fasttext = Embedding.load('fasttext-en')
47
+ >>> fasttext
48
+ FasttextEmbedding(n_tokens=2000000, dim=300, skip_first=True, cache=True)
49
+ >>> giga = Embedding.load('giga-100')
50
+ >>> giga
51
+ GigaEmbedding(n_tokens=372846, dim=100, cache=True)
52
+ >>> indices = torch.tensor([glove.vocab[i.lower()] for i in ['She', 'enjoys', 'playing', 'tennis', '.']])
53
+ >>> indices
54
+ tensor([ 67, 8371, 697, 2140, 2])
55
+ >>> glove(indices).shape
56
+ torch.Size([5, 100])
57
+ >>> glove(indices).equal(nn.Embedding.from_pretrained(glove.vectors)(indices))
58
+ True
59
+
60
+ .. _GloVe:
61
+ https://nlp.stanford.edu/projects/glove/
62
+ .. _Fasttext:
63
+ https://fasttext.cc/docs/en/crawl-vectors.html
64
+ .. _Giga:
65
+ https://github.com/yzhangcs/parser/releases/download/v1.1.0/giga.100.zip
66
+ .. _Tencent:
67
+ https://ai.tencent.com/ailab/nlp/zh/download.html
68
+ """
18
69
19
70
CACHE = os .path .join (CACHE , 'data/embeds' )
20
71
@@ -23,43 +74,61 @@ def __init__(
        path: str,
        unk: Optional[str] = None,
        skip_first: bool = False,
-         split: str = ' ',
-         cache: bool = False,
+         cache: bool = True,
+         sep: str = ' ',
        **kwargs
    ) -> Embedding:
        super().__init__()

        self.path = path
        self.unk = unk
        self.skip_first = skip_first
-         self.split = split
        self.cache = cache
+         self.sep = sep
        self.kwargs = kwargs

-         self.vocab = Vocab(Counter(self.tokens), unk_index=self.tokens.index(unk) if unk is not None else 0)
+         self.vocab = {token: i for i, token in enumerate(self.tokens)}

    def __len__(self):
        return len(self.vocab)
-     def __contains__(self, token):
-         return token in self.vocab
-
-     def __getitem__(self, key):
-         return self.vectors[self.vocab[key]]
-
    def __repr__(self):
        s = f"{self.__class__.__name__}("
        s += f"n_tokens={len(self)}, dim={self.dim}"
        if self.unk is not None:
            s += f", unk={self.unk}"
        if self.skip_first:
            s += f", skip_first={self.skip_first}"
+         if self.cache:
+             s += f", cache={self.cache}"
        s += ")"
        return s
+     def __contains__(self, token):
+         return token in self.vocab
+
+     def __getitem__(self, key: Union[int, Iterable[int], torch.Tensor]) -> torch.Tensor:
+         indices = key
+         if not isinstance(indices, torch.Tensor):
+             indices = torch.tensor(key)
+         if self.cache:
+             elems, indices = indices.unique(return_inverse=True)
+             with open(self.path) as f:
+                 vectors = []
+                 for index in elems.tolist():
+                     f.seek(self.positions[index])
+                     vectors.append(list(map(float, f.readline().strip().split(self.sep)[1:])))
+             vectors = torch.tensor(vectors)
+         else:
+             vectors = self.vectors
+         return torch.embedding(vectors, indices)
+
+     def __call__(self, key: Union[int, Iterable[int], torch.Tensor]) -> torch.Tensor:
+         return self[key]
+

    @property
    def dim(self):
-         return len(self[self.vocab[0]])
+         return len(self[0])

    @property
    def unk_index(self):
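A minimal end-to-end check of the two lookup paths introduced above (an editor's sketch, assuming this version of supar.utils.embed is importable; 'toy.vec' and its contents are made up):

    import torch
    from supar.utils.embed import Embedding

    # write a tiny three-token embedding file
    with open('toy.vec', 'w') as f:
        f.write('apple 0.1 0.2 0.3\n')
        f.write('banana 0.4 0.5 0.6\n')
        f.write('cherry 0.7 0.8 0.9\n')

    lazy = Embedding('toy.vec', cache=True)     # reads rows on demand by seeking to self.positions[index]
    eager = Embedding('toy.vec', cache=False)   # materializes the full self.vectors tensor up front
    indices = torch.tensor([2, 0, 2])           # duplicate indices are fetched from disk only once (unique + inverse)
    assert lazy(indices).equal(eager(indices))  # both yield the same (3, 3) tensor of the rows above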
@@ -69,17 +138,31 @@ def unk_index(self):

    @lazy_property
    def tokens(self):
-         with open(self.path, 'r') as f:
+         with open(self.path) as f:
            if self.skip_first:
                f.readline()
-             return [line.split(self.split)[0] for line in progress_bar(f)]
+             return [line.strip().split(self.sep)[0] for line in progress_bar(f)]

    @lazy_property
    def vectors(self):
-         with open(self.path, 'r') as f:
+         with open(self.path) as f:
+             if self.skip_first:
+                 f.readline()
+             return torch.tensor([list(map(float, line.strip().split(self.sep)[1:])) for line in progress_bar(f)])
+
+     @lazy_property
+     def positions(self):
+         with open(self.path) as f:
            if self.skip_first:
                f.readline()
-             return torch.tensor([list(map(float, line.strip().split(self.split)[1:])) for line in progress_bar(f)])
+             positions = [f.tell()]
+             while True:
+                 line = f.readline()
+                 if line:
+                     positions.append(f.tell())
+                 else:
+                     break
+             return positions

    @classmethod
    def load(cls, path: str, unk: Optional[str] = None, **kwargs) -> Embedding:
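The positions index above also records one final offset taken at end-of-file (the last f.tell() before readline() returns the empty string), so, continuing the sketch above:

    assert len(lazy.positions) == len(lazy) + 1  # 3 vector rows -> 4 recorded byte offsets
    # only positions[0] .. positions[len(lazy) - 1] are ever dereferenced by __getitem__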
@@ -90,9 +173,29 @@ def load(cls, path: str, unk: Optional[str] = None, **kwargs) -> Embedding:
        return cls(path, unk, **kwargs)


- class GloveEmbedding(Embedding):
+ class GloVeEmbedding(Embedding):
+
+     r"""
+     `GloVe`_: Global Vectors for Word Representation.
+     Training is performed on aggregated global word-word co-occurrence statistics from a corpus,
+     and the resulting representations showcase interesting linear substructures of the word vector space.
+
+     Args:
+         src (str):
+             Size of the source data for training. Default: ``6B``.
+         dim (int):
+             Which dimension of the embeddings to use. Default: 100.
+         reload (bool):
+             If ``True``, forces a fresh download. Default: ``False``.
+
+     Examples:
+         >>> from supar.utils.embed import Embedding
+         >>> Embedding.load('glove-6b-100')
+         GloVeEmbedding(n_tokens=400000, dim=100, unk=unk, cache=True)
+
-     def __init__(self, src: str = '6B', dim: int = 100, reload=False, *args, **kwargs) -> GloveEmbedding:
+     .. _GloVe:
+         https://nlp.stanford.edu/projects/glove/
+     """
+
+     def __init__(self, src: str = '6B', dim: int = 100, reload=False, *args, **kwargs) -> GloVeEmbedding:
        if src == '6B' or src == 'twitter.27B':
            url = f'https://nlp.stanford.edu/data/glove.{src}.zip'
        else:
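With the default src='6B', the branch above resolves the download URL to https://nlp.stanford.edu/data/glove.6B.zip (a single archive covering all dimensions); src='twitter.27B' follows the same pattern with glove.twitter.27B.zip.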
@@ -106,6 +209,25 @@ def __init__(self, src: str = '6B', dim: int = 100, reload=False, *args, **kwarg

class FasttextEmbedding(Embedding):

+     r"""
+     `Fasttext`_ word embeddings for 157 languages, trained using CBOW, in dimension 300,
+     with character n-grams of length 5, a window of size 5 and 10 negatives.
+
+     Args:
+         lang (str):
+             Language code. Default: ``en``.
+         reload (bool):
+             If ``True``, forces a fresh download. Default: ``False``.
+
+     Examples:
+         >>> from supar.utils.embed import Embedding
+         >>> Embedding.load('fasttext-en')
+         FasttextEmbedding(n_tokens=2000000, dim=300, skip_first=True, cache=True)
+
+     .. _Fasttext:
+         https://fasttext.cc/docs/en/crawl-vectors.html
+     """
+
    def __init__(self, lang: str = 'en', reload=False, *args, **kwargs) -> FasttextEmbedding:
        url = f'https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.{lang}.300.vec.gz'
        path = os.path.join(self.CACHE, 'fasttext', f'cc.{lang}.300.vec')
@@ -117,6 +239,23 @@ def __init__(self, lang: str = 'en', reload=False, *args, **kwargs) -> FasttextE

class GigaEmbedding(Embedding):

+     r"""
+     `Giga`_ word embeddings for Chinese, trained on the Chinese Gigaword Third Edition with word2vec,
+     used by :cite:`zhang-etal-2020-efficient` and :cite:`zhang-etal-2020-fast`.
+
+     Args:
+         reload (bool):
+             If ``True``, forces a fresh download. Default: ``False``.
+
+     Examples:
+         >>> from supar.utils.embed import Embedding
+         >>> Embedding.load('giga-100')
+         GigaEmbedding(n_tokens=372846, dim=100, cache=True)
+
+     .. _Giga:
+         https://github.com/yzhangcs/parser/releases/download/v1.1.0/giga.100.zip
+     """
+
    def __init__(self, reload=False, *args, **kwargs) -> GigaEmbedding:
        url = 'https://github.com/yzhangcs/parser/releases/download/v1.1.0/giga.100.zip'
        path = os.path.join(self.CACHE, 'giga', 'giga.100.txt')
@@ -128,9 +267,26 @@ def __init__(self, reload=False, *args, **kwargs) -> GigaEmbedding:

class TencentEmbedding(Embedding):

-     def __init__(self, dim: int = 100, big: bool = False, reload=False, *args, **kwargs) -> TencentEmbedding:
-         url = f'https://ai.tencent.com/ailab/nlp/zh/data/tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if big else "-s"}.tar.gz'  # noqa
-         name = f'tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if big else "-s"}'
+     r"""
+     `Tencent`_ word embeddings.
+     The embeddings are trained on large-scale text collected from news, webpages, and novels with Directional Skip-Gram.
+     100-dimension and 200-dimension embeddings for over 12 million Chinese words are provided.
+
+     Args:
+         dim (int):
+             Which dimension of the embeddings to use. Currently 100 and 200 are available. Default: 100.
+         large (bool):
+             If ``True``, uses the large version with a vocab size of 12,287,936; otherwise the 2,000,000-word version. Default: ``False``.
+         reload (bool):
+             If ``True``, forces a fresh download. Default: ``False``.
+
+     .. _Tencent:
+         https://ai.tencent.com/ailab/nlp/zh/download.html
+     """
+
+     def __init__(self, dim: int = 100, large: bool = False, reload=False, *args, **kwargs) -> TencentEmbedding:
+         url = f'https://ai.tencent.com/ailab/nlp/zh/data/tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if large else "-s"}.tar.gz'  # noqa
+         name = f'tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if large else "-s"}'
        path = os.path.join(os.path.join(self.CACHE, 'tencent'), name, f'{name}.txt')
        if not os.path.exists(path) or reload:
            download(url, os.path.join(self.CACHE, 'tencent'), clean=True)
@@ -139,16 +295,16 @@ def __init__(self, dim: int = 100, big: bool = False, reload=False, *args, **kwa

PRETRAINED = {
-     'glove-6b-50': {'_target_': GloveEmbedding, 'src': '6B', 'dim': 50},
-     'glove-6b-100': {'_target_': GloveEmbedding, 'src': '6B', 'dim': 100},
-     'glove-6b-200': {'_target_': GloveEmbedding, 'src': '6B', 'dim': 200},
-     'glove-6b-300': {'_target_': GloveEmbedding, 'src': '6B', 'dim': 300},
-     'glove-42b-300': {'_target_': GloveEmbedding, 'src': '42B', 'dim': 300},
-     'glove-840b-300': {'_target_': GloveEmbedding, 'src': '840B', 'dim': 300},
-     'glove-twitter-27b-25': {'_target_': GloveEmbedding, 'src': 'twitter.27B', 'dim': 25},
-     'glove-twitter-27b-50': {'_target_': GloveEmbedding, 'src': 'twitter.27B', 'dim': 50},
-     'glove-twitter-27b-100': {'_target_': GloveEmbedding, 'src': 'twitter.27B', 'dim': 100},
-     'glove-twitter-27b-200': {'_target_': GloveEmbedding, 'src': 'twitter.27B', 'dim': 200},
+     'glove-6b-50': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 50},
+     'glove-6b-100': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 100},
+     'glove-6b-200': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 200},
+     'glove-6b-300': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 300},
+     'glove-42b-300': {'_target_': GloVeEmbedding, 'src': '42B', 'dim': 300},
+     'glove-840b-300': {'_target_': GloVeEmbedding, 'src': '840B', 'dim': 300},
+     'glove-twitter-27b-25': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 25},
+     'glove-twitter-27b-50': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 50},
+     'glove-twitter-27b-100': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 100},
+     'glove-twitter-27b-200': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 200},
    'fasttext-bg': {'_target_': FasttextEmbedding, 'lang': 'bg'},
    'fasttext-ca': {'_target_': FasttextEmbedding, 'lang': 'ca'},
    'fasttext-cs': {'_target_': FasttextEmbedding, 'lang': 'cs'},
@@ -163,7 +319,7 @@ def __init__(self, dim: int = 100, big: bool = False, reload=False, *args, **kwa
    'fasttext-ru': {'_target_': FasttextEmbedding, 'lang': 'ru'},
    'giga-100': {'_target_': GigaEmbedding},
    'tencent-100': {'_target_': TencentEmbedding, 'dim': 100},
-     'tencent-100-b': {'_target_': TencentEmbedding, 'dim': 100, 'big': True},
+     'tencent-100-b': {'_target_': TencentEmbedding, 'dim': 100, 'large': True},
    'tencent-200': {'_target_': TencentEmbedding, 'dim': 200},
-     'tencent-200-b': {'_target_': TencentEmbedding, 'dim': 200, 'big': True},
+     'tencent-200-b': {'_target_': TencentEmbedding, 'dim': 200, 'large': True},
}
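Each PRETRAINED entry maps a short registry name to a '_target_' embedding class plus its constructor kwargs. The body of Embedding.load is only partially shown in this diff; assuming it consults this table before treating its argument as a raw file path, an entry would be resolved roughly as follows (a hypothetical sketch, not the library's actual code):

    cfg = dict(PRETRAINED['glove-6b-100'])  # {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 100}
    target = cfg.pop('_target_')            # the embedding class to instantiate
    embed = target(**cfg)                   # GloVeEmbedding(src='6B', dim=100)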