Skip to content

Commit b3ab971

Browse files
committed
Many pretrained embeddings available
1 parent 0461050 commit b3ab971

File tree

16 files changed

+196
-86
lines changed

16 files changed

+196
-86
lines changed

EXAMPLES.md

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,13 @@ $ python -u -m supar.cmds.biaffine_dep train -b -d 0 -c biaffine-dep-en -p model
1414
--train ptb/train.conllx \
1515
--dev ptb/dev.conllx \
1616
--test ptb/test.conllx \
17-
--embed glove.6B.100d.txt \
18-
--unk unk
17+
--embed glove-6b-100
1918
# crf2o
2019
$ python -u -m supar.cmds.crf2o_dep train -b -d 0 -c crf2o-dep-en -p model -f char \
2120
--train ptb/train.conllx \
2221
--dev ptb/dev.conllx \
2322
--test ptb/test.conllx \
24-
--embed glove.6B.100d.txt \
25-
--unk unk \
23+
--embed glove-6b-100 \
2624
--mbr \
2725
--proj
2826
```
@@ -84,8 +82,7 @@ $ python -u -m supar.cmds.crf_con train -b -d 0 -c crf-con-en -p model -f char -
8482
--train ptb/train.pid \
8583
--dev ptb/dev.pid \
8684
--test ptb/test.pid \
87-
--embed glove.6B.100d.txt \
88-
--unk unk \
85+
--embed glove-6b-100 \
8986
--mbr
9087
```
9188

@@ -179,15 +176,13 @@ $ python -u -m supar.cmds.biaffine_sdp train -b -c biaffine-sdp-en -d 0 -f tag c
179176
--train dm/train.conllu \
180177
--dev dm/dev.conllu \
181178
--test dm/test.conllu \
182-
--embed glove.6B.100d.txt \
183-
--unk unk
179+
--embed glove-6b-100
184180
# vi
185181
$ python -u -m supar.cmds.vi_sdp train -b -c vi-sdp-en -d 1 -f tag char lemma -p model \
186182
--train dm/train.conllu \
187183
--dev dm/dev.conllu \
188184
--test dm/test.conllu \
189-
--embed glove.6B.100d.txt \
190-
--unk unk \
185+
--embed glove-6b-100 \
191186
--inference mfvi
192187
```
193188

supar/cmds/biaffine_dep.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ def main():
2525
subparser.add_argument('--train', default='data/ptb/train.conllx', help='path to train file')
2626
subparser.add_argument('--dev', default='data/ptb/dev.conllx', help='path to dev file')
2727
subparser.add_argument('--test', default='data/ptb/test.conllx', help='path to test file')
28-
subparser.add_argument('--embed', default='data/glove.6B.100d.txt', help='path to pretrained embeddings')
29-
subparser.add_argument('--unk', default='unk', help='unk token in pretrained embeddings')
30-
subparser.add_argument('--n-embed', default=100, type=int, help='dimension of embeddings')
28+
subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
3129
subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
3230
# evaluate
3331
subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')

supar/cmds/biaffine_sdp.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ def main():
2121
subparser.add_argument('--train', default='data/sdp/DM/train.conllu', help='path to train file')
2222
subparser.add_argument('--dev', default='data/sdp/DM/dev.conllu', help='path to dev file')
2323
subparser.add_argument('--test', default='data/sdp/DM/test.conllu', help='path to test file')
24-
subparser.add_argument('--embed', default='data/glove.6B.100d.txt', help='path to pretrained embeddings')
25-
subparser.add_argument('--unk', default='unk', help='unk token in pretrained embeddings')
26-
subparser.add_argument('--n-embed', default=100, type=int, help='dimension of embeddings')
24+
subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
2725
subparser.add_argument('--n-embed-proj', default=125, type=int, help='dimension of projected embeddings')
2826
subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
2927
# evaluate

supar/cmds/crf2o_dep.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ def main():
2626
subparser.add_argument('--train', default='data/ptb/train.conllx', help='path to train file')
2727
subparser.add_argument('--dev', default='data/ptb/dev.conllx', help='path to dev file')
2828
subparser.add_argument('--test', default='data/ptb/test.conllx', help='path to test file')
29-
subparser.add_argument('--embed', default='data/glove.6B.100d.txt', help='path to pretrained embeddings')
30-
subparser.add_argument('--unk', default='unk', help='unk token in pretrained embeddings')
31-
subparser.add_argument('--n-embed', default=100, type=int, help='dimension of embeddings')
29+
subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
3230
subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
3331
# evaluate
3432
subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')

supar/cmds/crf_con.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@ def main():
2222
subparser.add_argument('--train', default='data/ptb/train.pid', help='path to train file')
2323
subparser.add_argument('--dev', default='data/ptb/dev.pid', help='path to dev file')
2424
subparser.add_argument('--test', default='data/ptb/test.pid', help='path to test file')
25-
subparser.add_argument('--embed', default='data/glove.6B.100d.txt', help='path to pretrained embeddings')
26-
subparser.add_argument('--unk', default='unk', help='unk token in pretrained embeddings')
27-
subparser.add_argument('--n-embed', default=100, type=int, help='dimension of embeddings')
25+
subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
2826
subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
2927
# evaluate
3028
subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')

supar/cmds/crf_dep.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ def main():
2626
subparser.add_argument('--train', default='data/ptb/train.conllx', help='path to train file')
2727
subparser.add_argument('--dev', default='data/ptb/dev.conllx', help='path to dev file')
2828
subparser.add_argument('--test', default='data/ptb/test.conllx', help='path to test file')
29-
subparser.add_argument('--embed', default='data/glove.6B.100d.txt', help='path to pretrained embeddings')
30-
subparser.add_argument('--unk', default='unk', help='unk token in pretrained embeddings')
31-
subparser.add_argument('--n-embed', default=100, type=int, help='dimension of embeddings')
29+
subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
3230
subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
3331
# evaluate
3432
subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')

supar/cmds/vi_con.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ def main():
2121
subparser.add_argument('--train', default='data/ptb/train.pid', help='path to train file')
2222
subparser.add_argument('--dev', default='data/ptb/dev.pid', help='path to dev file')
2323
subparser.add_argument('--test', default='data/ptb/test.pid', help='path to test file')
24-
subparser.add_argument('--embed', default='data/glove.6B.100d.txt', help='path to pretrained embeddings')
25-
subparser.add_argument('--unk', default='unk', help='unk token in pretrained embeddings')
26-
subparser.add_argument('--n-embed', default=100, type=int, help='dimension of embeddings')
24+
subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
2725
subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
2826
subparser.add_argument('--inference', default='mfvi', choices=['mfvi', 'lbp'], help='approximate inference methods')
2927
# evaluate

supar/cmds/vi_dep.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ def main():
2525
subparser.add_argument('--train', default='data/ptb/train.conllx', help='path to train file')
2626
subparser.add_argument('--dev', default='data/ptb/dev.conllx', help='path to dev file')
2727
subparser.add_argument('--test', default='data/ptb/test.conllx', help='path to test file')
28-
subparser.add_argument('--embed', default='data/glove.6B.100d.txt', help='path to pretrained embeddings')
29-
subparser.add_argument('--unk', default='unk', help='unk token in pretrained embeddings')
30-
subparser.add_argument('--n-embed', default=100, type=int, help='dimension of embeddings')
28+
subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
3129
subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
3230
subparser.add_argument('--inference', default='mfvi', choices=['mfvi', 'lbp'], help='approximate inference methods')
3331
# evaluate

supar/cmds/vi_sdp.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ def main():
2121
subparser.add_argument('--train', default='data/sdp/DM/train.conllu', help='path to train file')
2222
subparser.add_argument('--dev', default='data/sdp/DM/dev.conllu', help='path to dev file')
2323
subparser.add_argument('--test', default='data/sdp/DM/test.conllu', help='path to test file')
24-
subparser.add_argument('--embed', default='data/glove.6B.100d.txt', help='path to pretrained embeddings')
25-
subparser.add_argument('--unk', default='unk', help='unk token in pretrained embeddings')
26-
subparser.add_argument('--n-embed', default=100, type=int, help='dimension of embeddings')
24+
subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
2725
subparser.add_argument('--n-embed-proj', default=125, type=int, help='dimension of projected embeddings')
2826
subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
2927
subparser.add_argument('--inference', default='mfvi', choices=['mfvi', 'lbp'], help='approximate inference methods')

supar/parsers/const.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def build(cls, path, min_freq=2, fix_len=20, **kwargs):
291291

292292
train = Dataset(transform, args.train)
293293
if args.encoder != 'bert':
294-
WORD.build(train, args.min_freq, (Embedding.load(args.embed, args.unk) if args.embed else None))
294+
WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x))
295295
if TAG is not None:
296296
TAG.build(train)
297297
if CHAR is not None:

supar/parsers/dep.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def build(cls, path, min_freq=2, fix_len=20, **kwargs):
298298

299299
train = Dataset(transform, args.train)
300300
if args.encoder != 'bert':
301-
WORD.build(train, args.min_freq, (Embedding.load(args.embed, args.unk) if args.embed else None))
301+
WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x))
302302
if TAG is not None:
303303
TAG.build(train)
304304
if CHAR is not None:
@@ -823,7 +823,7 @@ def build(cls, path, min_freq=2, fix_len=20, **kwargs):
823823

824824
train = Dataset(transform, args.train)
825825
if args.encoder != 'bert':
826-
WORD.build(train, args.min_freq, (Embedding.load(args.embed, args.unk) if args.embed else None))
826+
WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x))
827827
if TAG is not None:
828828
TAG.build(train)
829829
if CHAR is not None:

supar/parsers/sdp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def build(cls, path, min_freq=7, fix_len=20, **kwargs):
265265

266266
train = Dataset(transform, args.train)
267267
if args.encoder != 'bert':
268-
WORD.build(train, args.min_freq, (Embedding.load(args.embed, args.unk) if args.embed else None))
268+
WORD.build(train, args.min_freq, (Embedding.load(args.embed) if args.embed else None), lambda x: x / torch.std(x))
269269
if TAG is not None:
270270
TAG.build(train)
271271
if CHAR is not None:

supar/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from . import field, fn, metric, transform
44
from .config import Config
55
from .data import Dataset
6-
from .embedding import Embedding
6+
from .embed import Embedding
77
from .field import ChartField, Field, RawField, SubwordField
88
from .transform import CoNLL, Transform, Tree
99
from .vocab import Vocab

supar/utils/embed.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from __future__ import annotations
4+
5+
import os
6+
from collections import Counter
7+
from typing import Optional
8+
9+
import torch
10+
from supar.utils.common import CACHE
11+
from supar.utils.fn import download
12+
from supar.utils.logging import progress_bar
13+
from supar.utils.vocab import Vocab
14+
from torch.distributions.utils import lazy_property
15+
16+
17+
class Embedding(object):
18+
19+
CACHE = os.path.join(CACHE, 'data/embeds')
20+
21+
def __init__(
22+
self,
23+
path: str,
24+
unk: Optional[str] = None,
25+
skip_first: bool = False,
26+
split: str = ' ',
27+
cache: bool = False,
28+
**kwargs
29+
) -> Embedding:
30+
super().__init__()
31+
32+
self.path = path
33+
self.unk = unk
34+
self.skip_first = skip_first
35+
self.split = split
36+
self.cache = cache
37+
self.kwargs = kwargs
38+
39+
self.vocab = Vocab(Counter(self.tokens), unk_index=self.tokens.index(unk) if unk is not None else 0)
40+
41+
def __len__(self):
42+
return len(self.vocab)
43+
44+
def __contains__(self, token):
45+
return token in self.vocab
46+
47+
def __getitem__(self, key):
48+
return self.vectors[self.vocab[key]]
49+
50+
def __repr__(self):
51+
s = f"{self.__class__.__name__}("
52+
s += f"n_tokens={len(self)}, dim={self.dim}"
53+
if self.unk is not None:
54+
s += f", unk={self.unk}"
55+
if self.skip_first:
56+
s += f", skip_first={self.skip_first}"
57+
s += ")"
58+
return s
59+
60+
@property
61+
def dim(self):
62+
return len(self[self.vocab[0]])
63+
64+
@property
65+
def unk_index(self):
66+
if self.unk is not None:
67+
return self.vocab[self.unk]
68+
raise AttributeError
69+
70+
@lazy_property
71+
def tokens(self):
72+
with open(self.path, 'r') as f:
73+
if self.skip_first:
74+
f.readline()
75+
return [line.split(self.split)[0] for line in progress_bar(f)]
76+
77+
@lazy_property
78+
def vectors(self):
79+
with open(self.path, 'r') as f:
80+
if self.skip_first:
81+
f.readline()
82+
return torch.tensor([list(map(float, line.strip().split(self.split)[1:])) for line in progress_bar(f)])
83+
84+
@classmethod
85+
def load(cls, path: str, unk: Optional[str] = None, **kwargs) -> Embedding:
86+
if path in PRETRAINED:
87+
cfg = dict(**PRETRAINED[path])
88+
embed = cfg.pop('_target_')
89+
return embed(**cfg, **kwargs)
90+
return cls(path, unk, **kwargs)
91+
92+
93+
class GloveEmbedding(Embedding):
94+
95+
def __init__(self, src: str = '6B', dim: int = 100, reload=False, *args, **kwargs) -> GloveEmbedding:
96+
if src == '6B' or src == 'twitter.27B':
97+
url = f'https://nlp.stanford.edu/data/glove.{src}.zip'
98+
else:
99+
url = f'https://nlp.stanford.edu/data/glove.{src}.{dim}d.zip'
100+
path = os.path.join(os.path.join(self.CACHE, 'glove'), f'glove.{src}.{dim}d.txt')
101+
if not os.path.exists(path) or reload:
102+
download(url, os.path.join(self.CACHE, 'glove'), clean=True)
103+
104+
super().__init__(path=path, unk='unk', *args, **kwargs)
105+
106+
107+
class FasttextEmbedding(Embedding):
108+
109+
def __init__(self, lang: str = 'en', reload=False, *args, **kwargs) -> FasttextEmbedding:
110+
url = f'https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.{lang}.300.vec.gz'
111+
path = os.path.join(self.CACHE, 'fasttext', f'cc.{lang}.300.vec')
112+
if not os.path.exists(path) or reload:
113+
download(url, os.path.join(self.CACHE, 'fasttext'), clean=True)
114+
115+
super().__init__(path=path, skip_first=True, *args, **kwargs)
116+
117+
118+
class GigaEmbedding(Embedding):
119+
120+
def __init__(self, reload=False, *args, **kwargs) -> GigaEmbedding:
121+
url = 'https://github.com/yzhangcs/parser/releases/download/v1.1.0/giga.100.zip'
122+
path = os.path.join(self.CACHE, 'giga', 'giga.100.txt')
123+
if not os.path.exists(path) or reload:
124+
download(url, os.path.join(self.CACHE, 'giga'), clean=True)
125+
126+
super().__init__(path=path, *args, **kwargs)
127+
128+
129+
class TencentEmbedding(Embedding):
130+
131+
def __init__(self, dim: int = 100, big: bool = False, reload=False, *args, **kwargs) -> TencentEmbedding:
132+
url = f'https://ai.tencent.com/ailab/nlp/zh/data/tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if big else "-s"}.tar.gz' # noqa
133+
name = f'tencent-ailab-embedding-zh-d{dim}-v0.2.0{"" if big else "-s"}'
134+
path = os.path.join(os.path.join(self.CACHE, 'tencent'), name, f'{name}.txt')
135+
if not os.path.exists(path) or reload:
136+
download(url, os.path.join(self.CACHE, 'tencent'), clean=True)
137+
138+
super().__init__(path=path, skip_first=True, *args, **kwargs)
139+
140+
141+
PRETRAINED = {
142+
'glove-6b-50': {'_target_': GloveEmbedding, 'src': '6B', 'dim': 50},
143+
'glove-6b-100': {'_target_': GloveEmbedding, 'src': '6B', 'dim': 100},
144+
'glove-6b-200': {'_target_': GloveEmbedding, 'src': '6B', 'dim': 200},
145+
'glove-6b-300': {'_target_': GloveEmbedding, 'src': '6B', 'dim': 300},
146+
'glove-42b-300': {'_target_': GloveEmbedding, 'src': '42B', 'dim': 300},
147+
'glove-840b-300': {'_target_': GloveEmbedding, 'src': '840B', 'dim': 300},
148+
'glove-twitter-27b-25': {'_target_': GloveEmbedding, 'src': 'twitter.27B', 'dim': 25},
149+
'glove-twitter-27b-50': {'_target_': GloveEmbedding, 'src': 'twitter.27B', 'dim': 50},
150+
'glove-twitter-27b-100': {'_target_': GloveEmbedding, 'src': 'twitter.27B', 'dim': 100},
151+
'glove-twitter-27b-200': {'_target_': GloveEmbedding, 'src': 'twitter.27B', 'dim': 200},
152+
'fasttext-bg': {'_target_': FasttextEmbedding, 'lang': 'bg'},
153+
'fasttext-ca': {'_target_': FasttextEmbedding, 'lang': 'ca'},
154+
'fasttext-cs': {'_target_': FasttextEmbedding, 'lang': 'cs'},
155+
'fasttext-de': {'_target_': FasttextEmbedding, 'lang': 'de'},
156+
'fasttext-en': {'_target_': FasttextEmbedding, 'lang': 'en'},
157+
'fasttext-es': {'_target_': FasttextEmbedding, 'lang': 'es'},
158+
'fasttext-fr': {'_target_': FasttextEmbedding, 'lang': 'fr'},
159+
'fasttext-it': {'_target_': FasttextEmbedding, 'lang': 'it'},
160+
'fasttext-nl': {'_target_': FasttextEmbedding, 'lang': 'nl'},
161+
'fasttext-no': {'_target_': FasttextEmbedding, 'lang': 'no'},
162+
'fasttext-ro': {'_target_': FasttextEmbedding, 'lang': 'ro'},
163+
'fasttext-ru': {'_target_': FasttextEmbedding, 'lang': 'ru'},
164+
'giga-100': {'_target_': GigaEmbedding},
165+
'tencent-100': {'_target_': TencentEmbedding, 'dim': 100},
166+
'tencent-100-b': {'_target_': TencentEmbedding, 'dim': 100, 'big': True},
167+
'tencent-200': {'_target_': TencentEmbedding, 'dim': 200},
168+
'tencent-200-b': {'_target_': TencentEmbedding, 'dim': 200, 'big': True},
169+
}

supar/utils/embedding.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

0 commit comments

Comments
 (0)