Skip to content

Commit 59bf09c

Browse files
committed
Bug fix
1 parent 4830efe commit 59bf09c

File tree

6 files changed

+28
-34
lines changed

6 files changed

+28
-34
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
setup(
66
name='supar',
7-
version='1.1.1',
7+
version='1.1.2',
88
author='Yu Zhang',
99
author_email='yzhang.cs@outlook.com',
1010
description='Syntactic/Semantic Parsing Models',

supar/__init__.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
'LBPSemanticDependency',
3131
'MFVISemanticDependency']
3232

33-
__version__ = '1.1.1'
33+
__version__ = '1.1.2'
3434

3535
PARSER = {parser.NAME: parser for parser in [BiaffineDependencyParser,
3636
CRFDependencyParser,
@@ -51,6 +51,8 @@
5151
'biaffine-dep-roberta-en': 'ptb.biaffine.dep.roberta',
5252
'biaffine-dep-electra-zh': 'ctb7.biaffine.dep.electra',
5353
'biaffine-dep-xlmr': 'ud.biaffine.dep.xlmr',
54+
'mm-con-en': 'ptb.mm.con.lstm.char',
55+
# 'mm-con-zh': 'ctb7.mm.con.lstm.char',
5456
'crf-con-en': 'ptb.crf.con.lstm.char',
5557
'crf-con-zh': 'ctb7.crf.con.lstm.char',
5658
'crf-con-roberta-en': 'ptb.crf.con.roberta',
@@ -63,8 +65,7 @@
6365
'vi-sdp-roberta-en': 'dm.vi.sdp.roberta',
6466
'vi-sdp-electra-zh': 'semeval16.vi.sdp.electra'
6567
}
66-
MODEL = {n: f"{SRC['github']}/v1.1.0/{m}.zip" for n, m in NAME.items()}
67-
CONFIG = {n: f"{SRC['github']}/v1.1.0/{m}.ini" for n, m in NAME.items()}
68-
69-
MODEL['biaffine-sdp-en'] = f"{SRC['hlt']}/v1.1.1/{NAME['biaffine-sdp-en']}.zip"
70-
MODEL['biaffine-sdp-zh'] = f"{SRC['hlt']}/v1.1.1/{NAME['biaffine-sdp-zh']}.zip"
68+
MODEL = {n: f"{SRC['hlt']}/v1.1.0/{m}.zip" for n, m in NAME.items()}
69+
CONFIG = {n: f"{SRC['hlt']}/v1.1.0/{m}.ini" for n, m in NAME.items()}
70+
MODEL['biaffine-sdp-en'] = f"{SRC['hlt']}/v1.1.2/{NAME['biaffine-sdp-en']}.zip"
71+
MODEL['biaffine-sdp-zh'] = f"{SRC['hlt']}/v1.1.2/{NAME['biaffine-sdp-zh']}.zip"

supar/parsers/con.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,6 @@ class CRFConstituencyParser(Parser):
2828
def __init__(self, *args, **kwargs):
2929
super().__init__(*args, **kwargs)
3030

31-
if self.args.feat in ('char', 'bert'):
32-
self.WORD, self.FEAT = self.transform.WORD
33-
else:
34-
self.WORD, self.FEAT = self.transform.WORD, self.transform.POS
3531
self.TREE = self.transform.TREE
3632
self.CHART = self.transform.CHART
3733

supar/parsers/dep.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _train(self, loader):
158158
bar, metric = progress_bar(loader), AttachmentMetric()
159159

160160
for i, batch in enumerate(bar, 1):
161-
words, *feats, arcs, rels = batch
161+
words, texts, *feats, arcs, rels = batch
162162
word_mask = words.ne(self.args.pad_index)
163163
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
164164
# ignore the first token of each sentence
@@ -178,7 +178,7 @@ def _train(self, loader):
178178
mask &= arcs.ge(0)
179179
# ignore all punctuation if not specified
180180
if not self.args.punct:
181-
mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
181+
mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in texts for w in s]))
182182
metric(arc_preds, rel_preds, arcs, rels, mask)
183183
bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}")
184184
logger.info(f"{bar.postfix}")
@@ -190,7 +190,7 @@ def _evaluate(self, loader):
190190
total_loss, metric = 0, AttachmentMetric()
191191

192192
for batch in loader:
193-
words, *feats, arcs, rels = batch
193+
words, texts, *feats, arcs, rels = batch
194194
word_mask = words.ne(self.args.pad_index)
195195
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
196196
# ignore the first token of each sentence
@@ -202,7 +202,7 @@ def _evaluate(self, loader):
202202
mask &= arcs.ge(0)
203203
# ignore all punctuation if not specified
204204
if not self.args.punct:
205-
mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
205+
mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in texts for w in s]))
206206
total_loss += loss.item()
207207
metric(arc_preds, rel_preds, arcs, rels, mask)
208208
total_loss /= len(loader)
@@ -215,7 +215,7 @@ def _predict(self, loader):
215215

216216
preds = {'arcs': [], 'rels': [], 'probs': [] if self.args.prob else None}
217217
for batch in progress_bar(loader):
218-
words, *feats = batch
218+
words, texts, *feats = batch
219219
word_mask = words.ne(self.args.pad_index)
220220
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
221221
# ignore the first token of each sentence
@@ -470,7 +470,7 @@ def _train(self, loader):
470470
bar, metric = progress_bar(loader), AttachmentMetric()
471471

472472
for i, batch in enumerate(bar, 1):
473-
words, *feats, arcs, rels = batch
473+
words, texts, *feats, arcs, rels = batch
474474
word_mask = words.ne(self.args.pad_index)
475475
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
476476
# ignore the first token of each sentence
@@ -490,7 +490,7 @@ def _train(self, loader):
490490
mask &= arcs.ge(0)
491491
# ignore all punctuation if not specified
492492
if not self.args.punct:
493-
mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
493+
mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in texts for w in s]))
494494
metric(arc_preds, rel_preds, arcs, rels, mask)
495495
bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}")
496496
logger.info(f"{bar.postfix}")
@@ -502,7 +502,7 @@ def _evaluate(self, loader):
502502
total_loss, metric = 0, AttachmentMetric()
503503

504504
for batch in loader:
505-
words, *feats, arcs, rels = batch
505+
words, texts, *feats, arcs, rels = batch
506506
word_mask = words.ne(self.args.pad_index)
507507
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
508508
# ignore the first token of each sentence
@@ -514,7 +514,7 @@ def _evaluate(self, loader):
514514
mask &= arcs.ge(0)
515515
# ignore all punctuation if not specified
516516
if not self.args.punct:
517-
mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
517+
mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in texts for w in s]))
518518
total_loss += loss.item()
519519
metric(arc_preds, rel_preds, arcs, rels, mask)
520520
total_loss /= len(loader)
@@ -527,7 +527,7 @@ def _predict(self, loader):
527527

528528
preds = {'arcs': [], 'rels': [], 'probs': [] if self.args.prob else None}
529529
for batch in progress_bar(loader):
530-
words, *feats = batch
530+
words, texts, *feats = batch
531531
word_mask = words.ne(self.args.pad_index)
532532
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
533533
# ignore the first token of each sentence
@@ -688,7 +688,7 @@ def _train(self, loader):
688688
bar, metric = progress_bar(loader), AttachmentMetric()
689689

690690
for i, batch in enumerate(bar, 1):
691-
words, *feats, arcs, sibs, rels = batch
691+
words, texts, *feats, arcs, sibs, rels = batch
692692
word_mask = words.ne(self.args.pad_index)
693693
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
694694
# ignore the first token of each sentence
@@ -708,7 +708,7 @@ def _train(self, loader):
708708
mask &= arcs.ge(0)
709709
# ignore all punctuation if not specified
710710
if not self.args.punct:
711-
mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
711+
mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in texts for w in s]))
712712
metric(arc_preds, rel_preds, arcs, rels, mask)
713713
bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}")
714714
logger.info(f"{bar.postfix}")
@@ -720,7 +720,7 @@ def _evaluate(self, loader):
720720
total_loss, metric = 0, AttachmentMetric()
721721

722722
for batch in loader:
723-
words, *feats, arcs, sibs, rels = batch
723+
words, texts, *feats, arcs, sibs, rels = batch
724724
word_mask = words.ne(self.args.pad_index)
725725
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
726726
# ignore the first token of each sentence
@@ -732,7 +732,7 @@ def _evaluate(self, loader):
732732
mask &= arcs.ge(0)
733733
# ignore all punctuation if not specified
734734
if not self.args.punct:
735-
mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
735+
mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in texts for w in s]))
736736
total_loss += loss.item()
737737
metric(arc_preds, rel_preds, arcs, rels, mask)
738738
total_loss /= len(loader)
@@ -745,7 +745,7 @@ def _predict(self, loader):
745745

746746
preds = {'arcs': [], 'rels': [], 'probs': [] if self.args.prob else None}
747747
for batch in progress_bar(loader):
748-
words, *feats = batch
748+
words, texts, *feats = batch
749749
word_mask = words.ne(self.args.pad_index)
750750
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
751751
# ignore the first token of each sentence
@@ -995,7 +995,7 @@ def _train(self, loader):
995995
bar, metric = progress_bar(loader), AttachmentMetric()
996996

997997
for i, batch in enumerate(bar, 1):
998-
words, *feats, arcs, rels = batch
998+
words, texts, *feats, arcs, rels = batch
999999
word_mask = words.ne(self.args.pad_index)
10001000
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
10011001
# ignore the first token of each sentence
@@ -1015,7 +1015,7 @@ def _train(self, loader):
10151015
mask &= arcs.ge(0)
10161016
# ignore all punctuation if not specified
10171017
if not self.args.punct:
1018-
mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
1018+
mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in texts for w in s]))
10191019
metric(arc_preds, rel_preds, arcs, rels, mask)
10201020
bar.set_postfix_str(f"lr: {self.scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}")
10211021
logger.info(f"{bar.postfix}")
@@ -1027,7 +1027,7 @@ def _evaluate(self, loader):
10271027
total_loss, metric = 0, AttachmentMetric()
10281028

10291029
for batch in loader:
1030-
words, *feats, arcs, rels = batch
1030+
words, texts, *feats, arcs, rels = batch
10311031
word_mask = words.ne(self.args.pad_index)
10321032
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
10331033
# ignore the first token of each sentence
@@ -1039,7 +1039,7 @@ def _evaluate(self, loader):
10391039
mask &= arcs.ge(0)
10401040
# ignore all punctuation if not specified
10411041
if not self.args.punct:
1042-
mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in batch.sentences for w in s.words]))
1042+
mask.masked_scatter_(mask, ~mask.new_tensor([ispunct(w) for s in texts for w in s]))
10431043
total_loss += loss.item()
10441044
metric(arc_preds, rel_preds, arcs, rels, mask)
10451045
total_loss /= len(loader)
@@ -1052,7 +1052,7 @@ def _predict(self, loader):
10521052

10531053
preds = {'arcs': [], 'rels': [], 'probs': [] if self.args.prob else None}
10541054
for batch in progress_bar(loader):
1055-
words, *feats = batch
1055+
words, texts, *feats = batch
10561056
word_mask = words.ne(self.args.pad_index)
10571057
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
10581058
# ignore the first token of each sentence

supar/parsers/sdp.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ class BiaffineSemanticDependencyParser(Parser):
2828
def __init__(self, *args, **kwargs):
2929
super().__init__(*args, **kwargs)
3030

31-
self.WORD, self.CHAR, self.ELMO, self.BERT = self.transform.FORM
3231
self.LEMMA = self.transform.LEMMA
3332
self.TAG = self.transform.POS
3433
self.LABEL = self.transform.PHEAD
@@ -301,7 +300,6 @@ class VISemanticDependencyParser(BiaffineSemanticDependencyParser):
301300
def __init__(self, *args, **kwargs):
302301
super().__init__(*args, **kwargs)
303302

304-
self.WORD, self.CHAR, self.ELMO, self.BERT = self.transform.FORM
305303
self.LEMMA = self.transform.LEMMA
306304
self.TAG = self.transform.POS
307305
self.LABEL = self.transform.PHEAD

supar/parsers/srl.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class VISemanticRoleLabelingParser(Parser):
2727
def __init__(self, *args, **kwargs):
2828
super().__init__(*args, **kwargs)
2929

30-
self.WORD, self.CHAR, self.ELMO, self.BERT = self.transform.FORM
3130
self.LEMMA = self.transform.LEMMA
3231
self.TAG = self.transform.POS
3332
self.LABEL = self.transform.PHEAD

0 commit comments

Comments (0)