
Commit c727e56

Update docs
1 parent b3ab971 commit c727e56

5 files changed: +223 -37 lines changed

docs/source/structs/fn.rst

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-Functions
+Function
 ==================================================================
 
 .. currentmodule:: supar.structs.fn

docs/source/utils/embed.rst

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+Embedding
+==================================================================
+
+.. currentmodule:: supar.utils.embed
+
+Embedding
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: Embedding
+    :members:
+
+GloVeEmbedding
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: GloVeEmbedding
+    :members:
+
+FasttextEmbedding
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: FasttextEmbedding
+    :members:
+
+GigaEmbedding
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: GigaEmbedding
+    :members:
+
+TencentEmbedding
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: TencentEmbedding
+    :members:

docs/source/utils/fn.rst

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-Functions
+Function
 ==================================================================
 
 .. currentmodule:: supar.utils.fn

docs/source/utils/index.rst

Lines changed: 1 addition & 0 deletions
@@ -10,4 +10,5 @@ Utils
     field
     transform
     vocab
+    embed
     fn

supar/utils/embed.py

Lines changed: 191 additions & 35 deletions
@@ -3,18 +3,69 @@
 from __future__ import annotations
 
 import os
-from collections import Counter
-from typing import Optional
+from typing import Iterable, Optional, Union
 
 import torch
 from supar.utils.common import CACHE
 from supar.utils.fn import download
 from supar.utils.logging import progress_bar
-from supar.utils.vocab import Vocab
 from torch.distributions.utils import lazy_property
 
 
 class Embedding(object):
+    r"""
+    Defines a container object for holding pretrained embeddings.
+    This object is callable and behaves like :class:`torch.nn.Embedding`.
+    For huge files, this object supports lazy loading, seeking to retrieve vectors from the disk on the fly if necessary.
+
+    Currently available embeddings:
+    - `GloVe`_
+    - `Fasttext`_
+    - `Giga`_
+    - `Tencent`_
+
+    Args:
+        path (str):
+            Path to the embedding file or short name registered in ``supar.utils.embed.PRETRAINED``.
+        unk (Optional[str]):
+            The string token used to represent OOV tokens. Default: ``None``.
+        skip_first (bool):
+            If ``True``, skips the first line of the embedding file. Default: ``False``.
+        cache (bool):
+            If ``True``, instead of loading entire embeddings into memory, seeks to load vectors from the disk once called.
+            Default: ``True``.
+        sep (str):
+            Separator used by the embedding file. Default: ``' '``.
+
+    Examples:
+        >>> import torch.nn as nn
+        >>> from supar.utils.embed import Embedding
+        >>> glove = Embedding.load('glove-6b-100')
+        >>> glove
+        GloVeEmbedding(n_tokens=400000, dim=100, unk=unk, cache=True)
+        >>> fasttext = Embedding.load('fasttext-en')
+        >>> fasttext
+        FasttextEmbedding(n_tokens=2000000, dim=300, skip_first=True, cache=True)
+        >>> giga = Embedding.load('giga-100')
+        >>> giga
+        GigaEmbedding(n_tokens=372846, dim=100, cache=True)
+        >>> indices = torch.tensor([glove.vocab[i.lower()] for i in ['She', 'enjoys', 'playing', 'tennis', '.']])
+        >>> indices
+        tensor([ 67, 8371, 697, 2140, 2])
+        >>> glove(indices).shape
+        torch.Size([5, 100])
+        >>> glove(indices).equal(nn.Embedding.from_pretrained(glove.vectors)(indices))
+        True
+
+    .. _GloVe:
+        https://nlp.stanford.edu/projects/glove/
+    .. _Fasttext:
+        https://fasttext.cc/docs/en/crawl-vectors.html
+    .. _Giga:
+        https://github.com/yzhangcs/parser/releases/download/v1.1.0/giga.100.zip
+    .. _Tencent:
+        https://ai.tencent.com/ailab/nlp/zh/download.html
+    """
 
     CACHE = os.path.join(CACHE, 'data/embeds')
 
@@ -23,43 +74,61 @@ def __init__(
         path: str,
         unk: Optional[str] = None,
         skip_first: bool = False,
-        split: str = ' ',
-        cache: bool = False,
+        cache: bool = True,
+        sep: str = ' ',
         **kwargs
     ) -> Embedding:
         super().__init__()
 
         self.path = path
         self.unk = unk
         self.skip_first = skip_first
-        self.split = split
         self.cache = cache
+        self.sep = sep
         self.kwargs = kwargs
 
-        self.vocab = Vocab(Counter(self.tokens), unk_index=self.tokens.index(unk) if unk is not None else 0)
+        self.vocab = {token: i for i, token in enumerate(self.tokens)}
 
     def __len__(self):
         return len(self.vocab)
 
-    def __contains__(self, token):
-        return token in self.vocab
-
-    def __getitem__(self, key):
-        return self.vectors[self.vocab[key]]
-
     def __repr__(self):
         s = f"{self.__class__.__name__}("
         s += f"n_tokens={len(self)}, dim={self.dim}"
         if self.unk is not None:
             s += f", unk={self.unk}"
         if self.skip_first:
             s += f", skip_first={self.skip_first}"
+        if self.cache:
+            s += f", cache={self.cache}"
         s += ")"
         return s
 
+    def __contains__(self, token):
+        return token in self.vocab
+
+    def __getitem__(self, key: Union[int, Iterable[int], torch.Tensor]) -> torch.Tensor:
+        indices = key
+        if not isinstance(indices, torch.Tensor):
+            indices = torch.tensor(key)
+        if self.cache:
+            elems, indices = indices.unique(return_inverse=True)
+            with open(self.path) as f:
+                vectors = []
+                for index in elems.tolist():
+                    f.seek(self.positions[index])
+                    vectors.append(list(map(float, f.readline().strip().split(self.sep)[1:])))
+                vectors = torch.tensor(vectors)
+        else:
+            vectors = self.vectors
+        return torch.embedding(vectors, indices)
+
+    def __call__(self, key: Union[int, Iterable[int], torch.Tensor]) -> torch.Tensor:
+        return self[key]
+
     @property
     def dim(self):
-        return len(self[self.vocab[0]])
+        return len(self[0])
 
     @property
     def unk_index(self):
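The hunk above replaces the token-keyed lookup with an index-based one: ``vocab`` is now a plain ``dict``, ``__getitem__`` takes integer indices, and calling the object mimics :class:`torch.nn.Embedding`. A rough usage sketch, assuming a hypothetical local file ``./sgns.100.txt`` in plain ``token v1 v2 ...`` format (not part of this commit):

    import torch
    from supar.utils.embed import Embedding

    # Hypothetical local file: one `token v1 v2 ... vN` entry per line.
    embed = Embedding('./sgns.100.txt', cache=False)
    print(len(embed), embed.dim)      # number of tokens and vector size
    print('the' in embed)             # membership test via __contains__
    indices = torch.tensor([embed.vocab.get('the', 0)])
    print(embed(indices).shape)       # (1, dim), like an nn.Embedding lookup
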
@@ -69,17 +138,31 @@ def unk_index(self):
 
     @lazy_property
     def tokens(self):
-        with open(self.path, 'r') as f:
+        with open(self.path) as f:
             if self.skip_first:
                 f.readline()
-            return [line.split(self.split)[0] for line in progress_bar(f)]
+            return [line.strip().split(self.sep)[0] for line in progress_bar(f)]
 
     @lazy_property
     def vectors(self):
-        with open(self.path, 'r') as f:
+        with open(self.path) as f:
+            if self.skip_first:
+                f.readline()
+            return torch.tensor([list(map(float, line.strip().split(self.sep)[1:])) for line in progress_bar(f)])
+
+    @lazy_property
+    def positions(self):
+        with open(self.path) as f:
             if self.skip_first:
                 f.readline()
-            return torch.tensor([list(map(float, line.strip().split(self.split)[1:])) for line in progress_bar(f)])
+            positions = [f.tell()]
+            while True:
+                line = f.readline()
+                if line:
+                    positions.append(f.tell())
+                else:
+                    break
+            return positions
 
     @classmethod
     def load(cls, path: str, unk: Optional[str] = None, **kwargs) -> Embedding:
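The new ``positions`` property is what makes ``cache=True`` work: byte offsets of all lines are recorded once with ``f.tell()``, and ``__getitem__`` later seeks straight to the requested rows instead of materializing the whole matrix. A minimal standalone sketch of the same idea, assuming some hypothetical whitespace-separated vector file ``vectors.txt`` (not part of this commit):

    import torch

    # Pass 1: record the byte offset at which every line starts (no parsing yet).
    with open('vectors.txt') as f:
        positions = [f.tell()]
        while f.readline():
            positions.append(f.tell())

    # Pass 2: read back only the requested rows by seeking, never loading the whole file.
    def lookup(indices):
        with open('vectors.txt') as f:
            rows = []
            for i in indices:
                f.seek(positions[i])
                rows.append([float(x) for x in f.readline().split()[1:]])
        return torch.tensor(rows)

    print(lookup([0, 2]).shape)   # two vectors fetched directly from disk
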
@@ -90,9 +173,29 @@ def load(cls, path: str, unk: Optional[str] = None, **kwargs) -> Embedding:
         return cls(path, unk, **kwargs)
 
 
-class GloveEmbedding(Embedding):
+class GloVeEmbedding(Embedding):
+
+    r"""
+    `GloVe`_: Global Vectors for Word Representation.
+    Training is performed on aggregated global word-word co-occurrence statistics from a corpus,
+    and the resulting representations showcase interesting linear substructures of the word vector space.
+
+    Args:
+        src (str):
+            Size of the source data for training. Default: ``'6B'``.
+        dim (int):
+            Which dimension of the embeddings to use. Default: 100.
+        reload (bool):
+            If ``True``, forces a fresh download. Default: ``False``.
+
+    Examples:
+        >>> from supar.utils.embed import Embedding
+        >>> Embedding.load('glove-6b-100')
+        GloVeEmbedding(n_tokens=400000, dim=100, unk=unk, cache=True)
 
-    def __init__(self, src: str = '6B', dim: int = 100, reload=False, *args, **kwargs) -> GloveEmbedding:
+    .. _GloVe:
+        https://nlp.stanford.edu/projects/glove/
+    """
+
+    def __init__(self, src: str = '6B', dim: int = 100, reload=False, *args, **kwargs) -> GloVeEmbedding:
         if src == '6B' or src == 'twitter.27B':
             url = f'https://nlp.stanford.edu/data/glove.{src}.zip'
         else:
@@ -106,6 +209,25 @@ def __init__(self, src: str = '6B', dim: int = 100, reload=False, *args, **kwarg
 
 class FasttextEmbedding(Embedding):
 
+    r"""
+    `Fasttext`_ word embeddings for 157 languages, trained using CBOW, in dimension 300,
+    with character n-grams of length 5, a window of size 5 and 10 negatives.
+
+    Args:
+        lang (str):
+            Language code. Default: ``en``.
+        reload (bool):
+            If ``True``, forces a fresh download. Default: ``False``.
+
+    Examples:
+        >>> from supar.utils.embed import Embedding
+        >>> Embedding.load('fasttext-en')
+        FasttextEmbedding(n_tokens=2000000, dim=300, skip_first=True, cache=True)
+
+    .. _Fasttext:
+        https://fasttext.cc/docs/en/crawl-vectors.html
+    """
+
     def __init__(self, lang: str = 'en', reload=False, *args, **kwargs) -> FasttextEmbedding:
         url = f'https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.{lang}.300.vec.gz'
         path = os.path.join(self.CACHE, 'fasttext', f'cc.{lang}.300.vec')
@@ -117,6 +239,23 @@ def __init__(self, lang: str = 'en', reload=False, *args, **kwargs) -> FasttextE
 
 class GigaEmbedding(Embedding):
 
+    r"""
+    `Giga`_ word embeddings, trained on Chinese Gigaword Third Edition for Chinese using word2vec,
+    used by :cite:`zhang-etal-2020-efficient` and :cite:`zhang-etal-2020-fast`.
+
+    Args:
+        reload (bool):
+            If ``True``, forces a fresh download. Default: ``False``.
+
+    Examples:
+        >>> from supar.utils.embed import Embedding
+        >>> Embedding.load('giga-100')
+        GigaEmbedding(n_tokens=372846, dim=100, cache=True)
+
+    .. _Giga:
+        https://github.com/yzhangcs/parser/releases/download/v1.1.0/giga.100.zip
+    """
+
     def __init__(self, reload=False, *args, **kwargs) -> GigaEmbedding:
         url = 'https://github.com/yzhangcs/parser/releases/download/v1.1.0/giga.100.zip'
         path = os.path.join(self.CACHE, 'giga', 'giga.100.txt')
@@ -128,9 +267,26 @@ def __init__(self, reload=False, *args, **kwargs) -> GigaEmbedding:
 
 class TencentEmbedding(Embedding):
 
-    def __init__(self, dim: int = 100, big: bool = False, reload=False, *args, **kwargs) -> TencentEmbedding:
-        url = f'https://ai.tencent.com/ailab/nlp/zh/data/tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if big else "-s"}.tar.gz'  # noqa
-        name = f'tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if big else "-s"}'
+    r"""
+    `Tencent`_ word embeddings.
+    The embeddings are trained on large-scale text collected from news, webpages, and novels with Directional Skip-Gram.
+    100-dimension and 200-dimension embeddings for over 12 million Chinese words are provided.
+
+    Args:
+        dim (int):
+            Which dimension of the embeddings to use. Currently 100 and 200 are available. Default: 100.
+        large (bool):
+            If ``True``, uses the large version with a larger vocab size (12,287,936); 2,000,000 otherwise. Default: ``False``.
+        reload (bool):
+            If ``True``, forces a fresh download. Default: ``False``.
+
+    .. _Tencent:
+        https://ai.tencent.com/ailab/nlp/zh/download.html
+    """
+
+    def __init__(self, dim: int = 100, large: bool = False, reload=False, *args, **kwargs) -> TencentEmbedding:
+        url = f'https://ai.tencent.com/ailab/nlp/zh/data/tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if large else "-s"}.tar.gz'  # noqa
+        name = f'tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if large else "-s"}'
         path = os.path.join(os.path.join(self.CACHE, 'tencent'), name, f'{name}.txt')
         if not os.path.exists(path) or reload:
             download(url, os.path.join(self.CACHE, 'tencent'), clean=True)
@@ -139,16 +295,16 @@ def __init__(self, dim: int = 100, big: bool = False, reload=False, *args, **kwa
 
 
 PRETRAINED = {
-    'glove-6b-50': {'_target_': GloveEmbedding, 'src': '6B', 'dim': 50},
-    'glove-6b-100': {'_target_': GloveEmbedding, 'src': '6B', 'dim': 100},
-    'glove-6b-200': {'_target_': GloveEmbedding, 'src': '6B', 'dim': 200},
-    'glove-6b-300': {'_target_': GloveEmbedding, 'src': '6B', 'dim': 300},
-    'glove-42b-300': {'_target_': GloveEmbedding, 'src': '42B', 'dim': 300},
-    'glove-840b-300': {'_target_': GloveEmbedding, 'src': '84B', 'dim': 300},
-    'glove-twitter-27b-25': {'_target_': GloveEmbedding, 'src': 'twitter.27B', 'dim': 25},
-    'glove-twitter-27b-50': {'_target_': GloveEmbedding, 'src': 'twitter.27B', 'dim': 50},
-    'glove-twitter-27b-100': {'_target_': GloveEmbedding, 'src': 'twitter.27B', 'dim': 100},
-    'glove-twitter-27b-200': {'_target_': GloveEmbedding, 'src': 'twitter.27B', 'dim': 200},
+    'glove-6b-50': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 50},
+    'glove-6b-100': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 100},
+    'glove-6b-200': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 200},
+    'glove-6b-300': {'_target_': GloVeEmbedding, 'src': '6B', 'dim': 300},
+    'glove-42b-300': {'_target_': GloVeEmbedding, 'src': '42B', 'dim': 300},
+    'glove-840b-300': {'_target_': GloVeEmbedding, 'src': '84B', 'dim': 300},
+    'glove-twitter-27b-25': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 25},
+    'glove-twitter-27b-50': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 50},
+    'glove-twitter-27b-100': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 100},
+    'glove-twitter-27b-200': {'_target_': GloVeEmbedding, 'src': 'twitter.27B', 'dim': 200},
     'fasttext-bg': {'_target_': FasttextEmbedding, 'lang': 'bg'},
     'fasttext-ca': {'_target_': FasttextEmbedding, 'lang': 'ca'},
     'fasttext-cs': {'_target_': FasttextEmbedding, 'lang': 'cs'},
@@ -163,7 +319,7 @@ def __init__(self, dim: int = 100, big: bool = False, reload=False, *args, **kwa
     'fasttext-ru': {'_target_': FasttextEmbedding, 'lang': 'ru'},
     'giga-100': {'_target_': GigaEmbedding},
     'tencent-100': {'_target_': TencentEmbedding, 'dim': 100},
-    'tencent-100-b': {'_target_': TencentEmbedding, 'dim': 100, 'big': True},
+    'tencent-100-b': {'_target_': TencentEmbedding, 'dim': 100, 'large': True},
     'tencent-200': {'_target_': TencentEmbedding, 'dim': 200},
-    'tencent-200-b': {'_target_': TencentEmbedding, 'dim': 200, 'big': True},
+    'tencent-200-b': {'_target_': TencentEmbedding, 'dim': 200, 'large': True},
 }
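Each ``PRETRAINED`` entry maps a short name to a target class plus keyword arguments, which ``Embedding.load`` dispatches on; per the ``path`` argument documented in the class docstring, ``load`` also accepts a plain file path. A brief, hedged sketch of wiring a loaded matrix into a model (illustrative only, not part of this commit):

    import torch.nn as nn
    from supar.utils.embed import Embedding

    glove = Embedding.load('glove-6b-100')                      # short name resolved via PRETRAINED
    embed_layer = nn.Embedding.from_pretrained(glove.vectors)   # frozen lookup layer initialized from GloVe
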
