Commit 274d1f3

Integrated tokenizer (yzhangcs#47)
1 parent 052de92 commit 274d1f3

10 files changed: +107 -69 lines changed

setup.py

Lines changed: 10 additions & 3 deletions
@@ -4,7 +4,7 @@
 
 setup(
     name='supar',
-    version='1.0.1-a1',
+    version='1.0.1',
     author='Yu Zhang',
     author_email='yzhang.cs@outlook.com',
     description='Syntactic Parsing Models',
@@ -22,14 +22,21 @@
     setup_requires=[
         'setuptools>=18.0',
     ],
-    install_requires=['torch>=1.7.0', 'transformers>=3.1.0', 'nltk'],
+    install_requires=[
+        'torch>=1.7.0',
+        'transformers>=3.1.0',
+        'nltk',
+        'stanza',
+        'dill'],
     entry_points={
         'console_scripts': [
             'biaffine-dependency=supar.cmds.biaffine_dependency:main',
             'crfnp-dependency=supar.cmds.crfnp_dependency:main',
             'crf-dependency=supar.cmds.crf_dependency:main',
             'crf2o-dependency=supar.cmds.crf2o_dependency:main',
-            'crf-constituency=supar.cmds.crf_constituency:main'
+            'crf-constituency=supar.cmds.crf_constituency:main',
+            'biaffine-semantic-dependency=supar.cmds.biaffine_semantic_dependency:main',
+            'vi-semantic-dependency=supar.cmds.vi_semantic_dependency:main'
         ]
     },
     python_requires='>=3.6',

supar/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
            'VISemanticDependencyParser',
            'Parser']
 
-__version__ = '1.0.1-a1'
+__version__ = '1.0.1'
 
 PARSER = {parser.NAME: parser for parser in [BiaffineDependencyParser,
                                              CRFNPDependencyParser,

supar/parsers/constituency.py

Lines changed: 9 additions & 5 deletions
@@ -104,13 +104,17 @@ def evaluate(self, data, buckets=8, batch_size=5000, mbr=True,
 
         return super().evaluate(**Config().update(locals()))
 
-    def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False, mbr=True, verbose=True, **kwargs):
+    def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False, mbr=True, verbose=True, **kwargs):
         r"""
         Args:
             data (list[list] or str):
                 The data for prediction, both a list of instances and filename are allowed.
             pred (str):
                 If specified, the predicted results will be saved to the file. Default: ``None``.
+            lang (str):
+                Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``en``.
             buckets (int):
                 The number of buckets that sentences are assigned to. Default: 32.
             batch_size (int):
@@ -241,15 +245,15 @@ def build(cls, path,
         if args.feat == 'char':
             FEAT = SubwordField('chars', pad=pad, unk=unk, bos=bos, eos=eos, fix_len=args.fix_len)
         elif args.feat == 'bert':
-            from transformers import AutoTokenizer
+            from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
             tokenizer = AutoTokenizer.from_pretrained(args.bert)
             FEAT = SubwordField('bert',
                                 pad=tokenizer.pad_token,
                                 unk=tokenizer.unk_token,
-                                bos=tokenizer.cls_token or tokenizer.cls_token,
-                                eos=tokenizer.sep_token or tokenizer.sep_token,
+                                bos=tokenizer.bos_token or tokenizer.cls_token,
                                 fix_len=args.fix_len,
-                                tokenize=tokenizer.tokenize)
+                                tokenize=tokenizer.tokenize,
+                                fn=lambda x: ' '+x if isinstance(tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)) else None)
             FEAT.vocab = tokenizer.get_vocab()
         else:
             FEAT = Field('tags', bos=bos, eos=eos)
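
As a hedged aside (not stated in the commit itself): the new `fn` hook prepends a space for GPT-2 tokenizers because their byte-level BPE is whitespace-sensitive. A minimal sketch, assuming the public `gpt2` checkpoint is available locally or downloadable:

# Illustration only: GPT-2's byte-level BPE splits a word differently depending on
# whether it is preceded by a space, which is why word-level fields prepend ' '
# before calling tokenize(); BERT-style tokenizers are unaffected.
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
print(tokenizer.tokenize('parsing'))   # the word as it appears at the very start of a text
print(tokenizer.tokenize(' parsing'))  # the word inside running text ('Ġ'-marked space prefix)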

supar/parsers/dependency.py

Lines changed: 23 additions & 6 deletions
@@ -102,14 +102,18 @@ def evaluate(self, data, buckets=8, batch_size=5000,
 
         return super().evaluate(**Config().update(locals()))
 
-    def predict(self, data, pred=None, buckets=8, batch_size=5000,
+    def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000,
                 prob=False, tree=True, proj=False, verbose=True, **kwargs):
         r"""
         Args:
             data (list[list] or str):
                 The data for prediction, both a list of instances and filename are allowed.
             pred (str):
                 If specified, the predicted results will be saved to the file. Default: ``None``.
+            lang (str):
+                Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``en``.
             buckets (int):
                 The number of buckets that sentences are assigned to. Default: 32.
             batch_size (int):
@@ -248,14 +252,15 @@ def build(cls, path,
         if args.feat == 'char':
             FEAT = SubwordField('chars', pad=pad, unk=unk, bos=bos, fix_len=args.fix_len)
         elif args.feat == 'bert':
-            from transformers import AutoTokenizer
+            from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
             tokenizer = AutoTokenizer.from_pretrained(args.bert)
             FEAT = SubwordField('bert',
                                 pad=tokenizer.pad_token,
                                 unk=tokenizer.unk_token,
                                 bos=tokenizer.bos_token or tokenizer.cls_token,
                                 fix_len=args.fix_len,
-                                tokenize=tokenizer.tokenize)
+                                tokenize=tokenizer.tokenize,
+                                fn=lambda x: ' '+x if isinstance(tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)) else None)
             FEAT.vocab = tokenizer.get_vocab()
         else:
             FEAT = Field('tags', bos=bos)
@@ -372,14 +377,18 @@ def evaluate(self, data, buckets=8, batch_size=5000, punct=False,
 
         return super().evaluate(**Config().update(locals()))
 
-    def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False,
+    def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False,
                 mbr=True, tree=True, proj=False, verbose=True, **kwargs):
         r"""
         Args:
             data (list[list] or str):
                 The data for prediction, both a list of instances and filename are allowed.
             pred (str):
                 If specified, the predicted results will be saved to the file. Default: ``None``.
+            lang (str):
+                Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``en``.
             buckets (int):
                 The number of buckets that sentences are assigned to. Default: 32.
             batch_size (int):
@@ -554,14 +563,18 @@ def evaluate(self, data, buckets=8, batch_size=5000, punct=False,
 
         return super().evaluate(**Config().update(locals()))
 
-    def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False,
+    def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False,
                 mbr=True, tree=True, proj=True, verbose=True, **kwargs):
         r"""
         Args:
             data (list[list] or str):
                 The data for prediction, both a list of instances and filename are allowed.
             pred (str):
                 If specified, the predicted results will be saved to the file. Default: ``None``.
+            lang (str):
+                Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``en``.
             buckets (int):
                 The number of buckets that sentences are assigned to. Default: 32.
             batch_size (int):
@@ -742,14 +755,18 @@ def evaluate(self, data, buckets=8, batch_size=5000, punct=False,
 
         return super().evaluate(**Config().update(locals()))
 
-    def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False,
+    def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False,
                 mbr=True, tree=True, proj=True, verbose=True, **kwargs):
         r"""
         Args:
             data (list[list] or str):
                 The data for prediction, both a list of instances and filename are allowed.
             pred (str):
                 If specified, the predicted results will be saved to the file. Default: ``None``.
+            lang (str):
+                Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``en``.
             buckets (int):
                 The number of buckets that sentences are assigned to. Default: 32.
             batch_size (int):
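
A hedged usage sketch of the new `lang` argument (the saved-model path below is a hypothetical placeholder, not part of this commit): `predict` can now take raw, untokenized text, while pre-tokenized input skips tokenization with `lang=None`.

from supar import Parser

# 'ptb.biaffine.dep.lstm.char' is an illustrative path to a trained model.
parser = Parser.load('ptb.biaffine.dep.lstm.char')

# A raw string is tokenized internally by the stanza-based Tokenizer when lang is set.
dataset = parser.predict('She enjoys playing tennis.', lang='en', prob=True)

# Pre-tokenized input still works; pass lang=None so it is used as-is.
dataset = parser.predict([['She', 'enjoys', 'playing', 'tennis', '.']], lang=None)
print(dataset.sentences[0])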

supar/parsers/parser.py

Lines changed: 4 additions & 3 deletions
@@ -3,6 +3,7 @@
 import os
 from datetime import datetime, timedelta
 
+import dill
 import supar
 import torch
 import torch.distributed as dist
@@ -96,7 +97,7 @@ def evaluate(self, data, buckets=8, batch_size=5000, **kwargs):
 
         return loss, metric
 
-    def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False, **kwargs):
+    def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False, **kwargs):
         args = self.args.update(locals())
         init_logger(logger, verbose=args.verbose)
 
@@ -105,7 +106,7 @@ def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False, **kwa
             self.transform.append(Field('probs'))
 
         logger.info("Loading the data")
-        dataset = Dataset(self.transform, data)
+        dataset = Dataset(self.transform, data, lang=lang)
         dataset.build(args.batch_size, args.buckets)
         logger.info(f"\n{dataset}")
 
@@ -185,4 +186,4 @@ def save(self, path):
                  'state_dict': state_dict,
                  'pretrained': pretrained,
                  'transform': self.transform}
-        torch.save(state, path)
+        torch.save(state, path, pickle_module=dill)
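
A hedged aside on the `pickle_module=dill` switch (the commit does not state the reason): the saved transform now carries lambdas such as the GPT-2 `fn` hook, which the standard pickler cannot serialize but dill can. A minimal sketch:

import pickle

import dill

# Stand-in for the fn hook stored on a SubwordField.
fn = lambda x: ' ' + x

try:
    pickle.dumps(fn)                        # the stdlib pickler rejects lambdas
except Exception as e:
    print('pickle failed:', e)

restored = dill.loads(dill.dumps(fn))       # dill round-trips them
print(repr(restored('world')))              # ' world'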

supar/parsers/semantic_dependency.py

Lines changed: 18 additions & 6 deletions
@@ -79,13 +79,17 @@ def evaluate(self, data, buckets=8, batch_size=5000, verbose=True, **kwargs):
 
         return super().evaluate(**Config().update(locals()))
 
-    def predict(self, data, pred=None, buckets=8, batch_size=5000, verbose=True, **kwargs):
+    def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, verbose=True, **kwargs):
         r"""
         Args:
             data (list[list] or str):
                 The data for prediction, both a list of instances and filename are allowed.
             pred (str):
                 If specified, the predicted results will be saved to the file. Default: ``None``.
+            lang (str):
+                Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``en``.
             buckets (int):
                 The number of buckets that sentences are assigned to. Default: 32.
             batch_size (int):
@@ -219,14 +223,16 @@ def build(cls,
         if 'lemma' in args.feat:
             LEMMA = Field('lemmas', pad=pad, unk=unk, bos=bos, lower=True)
         if 'bert' in args.feat:
-            from transformers import AutoTokenizer
+            from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
             tokenizer = AutoTokenizer.from_pretrained(args.bert)
             BERT = SubwordField('bert',
                                 pad=tokenizer.pad_token,
                                 unk=tokenizer.unk_token,
                                 bos=tokenizer.bos_token or tokenizer.cls_token,
+                                eos=tokenizer.eos_token or tokenizer.sep_token,
                                 fix_len=args.fix_len,
-                                tokenize=tokenizer.tokenize)
+                                tokenize=tokenizer.tokenize,
+                                fn=lambda x: ' '+x if isinstance(tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)) else None)
             BERT.vocab = tokenizer.get_vocab()
         EDGE = ChartField('edges', use_vocab=False, fn=CoNLL.get_edges)
         LABEL = ChartField('labels', fn=CoNLL.get_labels)
@@ -324,13 +330,17 @@ def evaluate(self, data, buckets=8, batch_size=5000, verbose=True, **kwargs):
 
         return super().evaluate(**Config().update(locals()))
 
-    def predict(self, data, pred=None, buckets=8, batch_size=5000, verbose=True, **kwargs):
+    def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, verbose=True, **kwargs):
         r"""
         Args:
             data (list[list] or str):
                 The data for prediction, both a list of instances and filename are allowed.
             pred (str):
                 If specified, the predicted results will be saved to the file. Default: ``None``.
+            lang (str):
+                Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``en``.
             buckets (int):
                 The number of buckets that sentences are assigned to. Default: 32.
             batch_size (int):
@@ -465,14 +475,16 @@ def build(cls,
         if 'lemma' in args.feat:
             LEMMA = Field('lemmas', pad=pad, unk=unk, bos=bos, lower=True)
         if 'bert' in args.feat:
-            from transformers import AutoTokenizer
+            from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
             tokenizer = AutoTokenizer.from_pretrained(args.bert)
             BERT = SubwordField('bert',
                                 pad=tokenizer.pad_token,
                                 unk=tokenizer.unk_token,
                                 bos=tokenizer.bos_token or tokenizer.cls_token,
+                                eos=tokenizer.eos_token or tokenizer.sep_token,
                                 fix_len=args.fix_len,
-                                tokenize=tokenizer.tokenize)
+                                tokenize=tokenizer.tokenize,
+                                fn=lambda x: ' '+x if isinstance(tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)) else None)
             BERT.vocab = tokenizer.get_vocab()
         EDGE = ChartField('edges', use_vocab=False, fn=CoNLL.get_edges)
         LABEL = ChartField('labels', fn=CoNLL.get_labels)

supar/utils/field.py

Lines changed: 1 addition & 1 deletion
@@ -319,7 +319,7 @@ def transform(self, sequences):
             self.fix_len = max(len(token) for seq in sequences for token in seq)
         if self.use_vocab:
             sequences = [[[self.vocab[i] if i in self.vocab else self.unk_index for i in token] if token else [self.unk_index]
-                         for token in seq] for seq in sequences]
+                          for token in seq] for seq in sequences]
         if self.bos:
             sequences = [[[self.bos_index]] + seq for seq in sequences]
         if self.eos:

supar/utils/tokenizer.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+
+
+class Tokenizer:
+
+    def __init__(self, lang='en'):
+        import stanza
+        try:
+            self.pipeline = stanza.Pipeline(lang=lang, processors='tokenize', tokenize_no_ssplit=True)
+        except Exception:
+            stanza.download(lang=lang, resources_url='stanford')
+            self.pipeline = stanza.Pipeline(lang=lang, processors='tokenize', tokenize_no_ssplit=True)
+
+    def __call__(self, text):
+        return [i.text for i in self.pipeline(text).sentences[0].tokens]
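
A minimal usage sketch of the new wrapper; the first call downloads the stanza model for the requested language if it is not already cached.

from supar.utils.tokenizer import Tokenizer

tokenize = Tokenizer(lang='en')
print(tokenize('She enjoys playing tennis.'))
# ['She', 'enjoys', 'playing', 'tennis', '.']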

supar/utils/transform.py

Lines changed: 24 additions & 6 deletions
@@ -1,9 +1,11 @@
 # -*- coding: utf-8 -*-
 
+import os
 from collections.abc import Iterable
 
 import nltk
 from supar.utils.logging import get_logger, progress_bar
+from supar.utils.tokenizer import Tokenizer
 
 logger = get_logger(__name__)
 
@@ -343,14 +345,18 @@ def istree(cls, sequence, proj=False, multiroot=False):
                 return False
         return next(tarjan(sequence), None) is None
 
-    def load(self, data, proj=False, max_len=None, **kwargs):
+    def load(self, data, lang='en', proj=False, max_len=None, **kwargs):
         r"""
         Loads the data in CoNLL-X format.
         Also supports for loading data from CoNLL-U file with comments and non-integer IDs.
 
         Args:
             data (list[list] or str):
                 A list of instances or a filename.
+            lang (str):
+                Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``en``.
             proj (bool):
                 If ``True``, discards all non-projective sentences. Default: ``False``.
             max_len (int):
@@ -360,11 +366,15 @@ def load(self, data, proj=False, max_len=None, **kwargs):
             A list of :class:`CoNLLSentence` instances.
         """
 
-        if isinstance(data, str):
+        if isinstance(data, str) and os.path.exists(data):
             with open(data, 'r') as f:
                 lines = [line.strip() for line in f]
         else:
-            data = [data] if isinstance(data[0], str) else data
+            if lang is not None:
+                tokenizer = Tokenizer(lang)
+                data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)]
+            else:
+                data = [data] if isinstance(data[0], str) else data
             lines = '\n'.join([self.toconll(i) for i in data]).split('\n')
 
         i, start, sentences = 0, 0, []
@@ -680,23 +690,31 @@ def track(node):
                 return [tree]
         return nltk.Tree(root, track(iter(sequence)))
 
-    def load(self, data, max_len=None, **kwargs):
+    def load(self, data, lang='en', max_len=None, **kwargs):
         r"""
         Args:
             data (list[list] or str):
                 A list of instances or a filename.
+            lang (str):
+                Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
+                ``None`` if tokenization is not required.
+                Default: ``en``.
             max_len (int):
                 Sentences exceeding the length will be discarded. Default: ``None``.
 
         Returns:
             A list of :class:`TreeSentence` instances.
         """
-        if isinstance(data, str):
+        if isinstance(data, str) and os.path.exists(data):
             with open(data, 'r') as f:
                 trees = [nltk.Tree.fromstring(string) for string in f]
             self.root = trees[0].label()
         else:
-            data = [data] if isinstance(data[0], str) else data
+            if lang is not None:
+                tokenizer = Tokenizer(lang)
+                data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)]
+            else:
+                data = [data] if isinstance(data[0], str) else data
             trees = [self.totree(i, self.root) for i in data]
 
         i, sentences = 0, []
