Skip to content

Commit 7e1e37c

Browse files
committed
Save predictions by batch
1 parent 575e278 commit 7e1e37c

File tree

5 files changed

+40
-70
lines changed

supar/parsers/const.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,6 @@ def _evaluate(self, loader):
205205
def _predict(self, loader):
206206
self.model.eval()
207207

208-
preds = {'trees': [], 'probs': [] if self.args.prob else None}
209208
for batch in progress_bar(loader):
210209
words, *feats, trees = batch
211210
word_mask = words.ne(self.args.pad_index)[:, 1:]
@@ -215,12 +214,10 @@ def _predict(self, loader):
215214
s_span, s_label = self.model(words, feats)
216215
s_span = ConstituencyCRF(s_span, mask[:, 0].sum(-1)).marginals if self.args.mbr else s_span
217216
chart_preds = self.model.decode(s_span, s_label, mask)
218-
preds['trees'].extend([Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart])
219-
for tree, chart in zip(trees, chart_preds)])
217+
batch.trees = [Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart])
218+
for tree, chart in zip(trees, chart_preds)]
220219
if self.args.prob:
221-
preds['probs'].extend([prob[:i-1, 1:i].cpu() for i, prob in zip(lens, s_span)])
222-
223-
return preds
220+
batch.probs = [prob[:i-1, 1:i].cpu() for i, prob in zip(lens, s_span)]
224221

225222
@classmethod
226223
def build(cls, path, min_freq=2, fix_len=20, **kwargs):
@@ -498,7 +495,6 @@ def _evaluate(self, loader):
498495
def _predict(self, loader):
499496
self.model.eval()
500497

501-
preds = {'trees': [], 'probs': [] if self.args.prob else None}
502498
for batch in progress_bar(loader):
503499
words, *feats, trees = batch
504500
word_mask = words.ne(self.args.pad_index)[:, 1:]
@@ -508,9 +504,7 @@ def _predict(self, loader):
508504
s_span, s_pair, s_label = self.model(words, feats)
509505
s_span = self.model.inference((s_span, s_pair), mask)
510506
chart_preds = self.model.decode(s_span, s_label, mask)
511-
preds['trees'].extend([Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart])
512-
for tree, chart in zip(trees, chart_preds)])
507+
batch.trees = [Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart])
508+
for tree, chart in zip(trees, chart_preds)]
513509
if self.args.prob:
514-
preds['probs'].extend([prob[:i-1, 1:i].cpu() for i, prob in zip(lens, s_span)])
515-
516-
return preds
510+
batch.probs = [prob[:i-1, 1:i].cpu() for i, prob in zip(lens, s_span)]

supar/parsers/dep.py

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,6 @@ def _evaluate(self, loader):
212212
def _predict(self, loader):
213213
self.model.eval()
214214

215-
preds = {'arcs': [], 'rels': [], 'probs': [] if self.args.prob else None}
216215
for batch in progress_bar(loader):
217216
words, texts, *feats = batch
218217
word_mask = words.ne(self.args.pad_index)
@@ -222,14 +221,10 @@ def _predict(self, loader):
222221
lens = mask.sum(1).tolist()
223222
s_arc, s_rel = self.model(words, feats)
224223
arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
225-
preds['arcs'].extend(arc_preds[mask].split(lens))
226-
preds['rels'].extend(rel_preds[mask].split(lens))
224+
batch.arcs = [i.tolist() for i in arc_preds[mask].split(lens)]
225+
batch.rels = [self.REL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)]
227226
if self.args.prob:
228-
preds['probs'].extend([prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, s_arc.softmax(-1).unbind())])
229-
preds['arcs'] = [seq.tolist() for seq in preds['arcs']]
230-
preds['rels'] = [self.REL.vocab[seq.tolist()] for seq in preds['rels']]
231-
232-
return preds
227+
batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, s_arc.softmax(-1).unbind())]
233228

234229
@classmethod
235230
def build(cls, path, min_freq=2, fix_len=20, **kwargs):
@@ -526,7 +521,6 @@ def _predict(self, loader):
526521
self.model.eval()
527522

528523
CRF = DependencyCRF if self.args.proj else MatrixTree
529-
preds = {'arcs': [], 'rels': [], 'probs': [] if self.args.prob else None}
530524
for batch in progress_bar(loader):
531525
words, _, *feats = batch
532526
word_mask = words.ne(self.args.pad_index)
@@ -538,15 +532,11 @@ def _predict(self, loader):
538532
s_arc = CRF(s_arc, lens).marginals if self.args.mbr else s_arc
539533
arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
540534
lens = lens.tolist()
541-
preds['arcs'].extend(arc_preds[mask].split(lens))
542-
preds['rels'].extend(rel_preds[mask].split(lens))
535+
batch.arcs = [i.tolist() for i in arc_preds[mask].split(lens)]
536+
batch.rels = [self.REL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)]
543537
if self.args.prob:
544538
arc_probs = s_arc if self.args.mbr else s_arc.softmax(-1)
545-
preds['probs'].extend([prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, arc_probs.unbind())])
546-
preds['arcs'] = [seq.tolist() for seq in preds['arcs']]
547-
preds['rels'] = [self.REL.vocab[seq.tolist()] for seq in preds['rels']]
548-
549-
return preds
539+
batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, arc_probs.unbind())]
550540

551541

552542
class CRF2oDependencyParser(BiaffineDependencyParser):
@@ -745,7 +735,6 @@ def _evaluate(self, loader):
745735
def _predict(self, loader):
746736
self.model.eval()
747737

748-
preds = {'arcs': [], 'rels': [], 'probs': [] if self.args.prob else None}
749738
for batch in progress_bar(loader):
750739
words, texts, *feats = batch
751740
word_mask = words.ne(self.args.pad_index)
@@ -757,15 +746,11 @@ def _predict(self, loader):
757746
s_arc, s_sib = Dependency2oCRF((s_arc, s_sib), lens).marginals if self.args.mbr else (s_arc, s_sib)
758747
arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask, self.args.tree, self.args.mbr, self.args.proj)
759748
lens = lens.tolist()
760-
preds['arcs'].extend(arc_preds[mask].split(lens))
761-
preds['rels'].extend(rel_preds[mask].split(lens))
749+
batch.arcs = [i.tolist() for i in arc_preds[mask].split(lens)]
750+
batch.rels = [self.REL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)]
762751
if self.args.prob:
763752
arc_probs = s_arc if self.args.mbr else s_arc.softmax(-1)
764-
preds['probs'].extend([prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, arc_probs.unbind())])
765-
preds['arcs'] = [seq.tolist() for seq in preds['arcs']]
766-
preds['rels'] = [self.REL.vocab[seq.tolist()] for seq in preds['rels']]
767-
768-
return preds
753+
batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, arc_probs.unbind())]
769754

770755
@classmethod
771756
def build(cls, path, min_freq=2, fix_len=20, **kwargs):
@@ -1054,7 +1039,6 @@ def _evaluate(self, loader):
10541039
def _predict(self, loader):
10551040
self.model.eval()
10561041

1057-
preds = {'arcs': [], 'rels': [], 'probs': [] if self.args.prob else None}
10581042
for batch in progress_bar(loader):
10591043
words, texts, *feats = batch
10601044
word_mask = words.ne(self.args.pad_index)
@@ -1065,11 +1049,7 @@ def _predict(self, loader):
10651049
s_arc, s_sib, s_rel = self.model(words, feats)
10661050
s_arc = self.model.inference((s_arc, s_sib), mask)
10671051
arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
1068-
preds['arcs'].extend(arc_preds[mask].split(lens))
1069-
preds['rels'].extend(rel_preds[mask].split(lens))
1052+
batch.arcs = [i.tolist() for i in arc_preds[mask].split(lens)]
1053+
batch.rels = [self.REL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)]
10701054
if self.args.prob:
1071-
preds['probs'].extend([prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, s_arc.unbind())])
1072-
preds['arcs'] = [seq.tolist() for seq in preds['arcs']]
1073-
preds['rels'] = [self.REL.vocab[seq.tolist()] for seq in preds['rels']]
1074-
1075-
return preds
1055+
batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, s_arc.unbind())]

supar/parsers/parser.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,12 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=F
134134

135135
logger.info("Making predictions on the dataset")
136136
start = datetime.now()
137-
preds = self._predict(dataset.loader)
137+
self._predict(dataset.loader)
138138
elapsed = datetime.now() - start
139139

140-
for name, value in preds.items():
141-
setattr(dataset, name, value)
142140
if pred is not None and is_master():
143141
logger.info(f"Saving predicted results to {pred}")
144-
self.transform.save(pred, dataset.sentences)
142+
self.transform.save(pred, dataset)
145143
logger.info(f"{elapsed}s elapsed, {len(dataset) / elapsed.total_seconds():.2f} Sents/s")
146144

147145
return dataset

supar/parsers/sdp.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,6 @@ def _evaluate(self, loader):
179179
def _predict(self, loader):
180180
self.model.eval()
181181

182-
preds = {'labels': [], 'probs': [] if self.args.prob else None}
183182
for batch in progress_bar(loader):
184183
words, *feats = batch
185184
word_mask = words.ne(self.args.pad_index)
@@ -189,13 +188,11 @@ def _predict(self, loader):
189188
lens = mask[:, 1].sum(-1).tolist()
190189
s_edge, s_label = self.model(words, feats)
191190
label_preds = self.model.decode(s_edge, s_label).masked_fill(~mask, -1)
192-
preds['labels'].extend(chart[1:i, :i].tolist() for i, chart in zip(lens, label_preds))
191+
batch.labels = [CoNLL.build_relations([[self.LABEL.vocab[i] if i >= 0 else None for i in row]
192+
for row in chart[1:i, :i].tolist()])
193+
for i, chart in zip(lens, label_preds)]
193194
if self.args.prob:
194-
preds['probs'].extend([prob[1:i, :i].cpu() for i, prob in zip(lens, s_edge.softmax(-1).unbind())])
195-
preds['labels'] = [CoNLL.build_relations([[self.LABEL.vocab[i] if i >= 0 else None for i in row] for row in chart])
196-
for chart in preds['labels']]
197-
198-
return preds
195+
batch.probs = [prob[1:i, :i].cpu() for i, prob in zip(lens, s_edge.softmax(-1).unbind())]
199196

200197
@classmethod
201198
def build(cls, path, min_freq=7, fix_len=20, **kwargs):
@@ -459,7 +456,6 @@ def _evaluate(self, loader):
459456
def _predict(self, loader):
460457
self.model.eval()
461458

462-
preds = {'labels': [], 'probs': [] if self.args.prob else None}
463459
for batch in progress_bar(loader):
464460
words, *feats = batch
465461
word_mask = words.ne(self.args.pad_index)
@@ -470,10 +466,8 @@ def _predict(self, loader):
470466
s_edge, s_sib, s_cop, s_grd, s_label = self.model(words, feats)
471467
s_edge = self.model.inference((s_edge, s_sib, s_cop, s_grd), mask)
472468
label_preds = self.model.decode(s_edge, s_label).masked_fill(~mask, -1)
473-
preds['labels'].extend(chart[1:i, :i].tolist() for i, chart in zip(lens, label_preds))
469+
batch.labels = [CoNLL.build_relations([[self.LABEL.vocab[i] if i >= 0 else None for i in row]
470+
for row in chart[1:i, :i].tolist()])
471+
for i, chart in zip(lens, label_preds)]
474472
if self.args.prob:
475-
preds['probs'].extend([prob[1:i, :i].cpu() for i, prob in zip(lens, s_edge.unbind())])
476-
preds['labels'] = [CoNLL.build_relations([[self.LABEL.vocab[i] if i >= 0 else None for i in row] for row in chart])
477-
for chart in preds['labels']]
478-
479-
return preds
473+
batch.probs = [prob[1:i, :i].cpu() for i, prob in zip(lens, s_edge.unbind())]

supar/utils/transform.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from supar.utils.tokenizer import Tokenizer
1212

1313
if TYPE_CHECKING:
14-
from supar.utils import Field
14+
from supar.utils import Dataset, Field
1515

1616
logger = get_logger(__name__)
1717

@@ -83,9 +83,10 @@ def src(self):
8383
def tgt(self):
8484
raise AttributeError
8585

86-
def save(self, path: str, sentences: List['Sentence']) -> None:
86+
def save(self, path: str, data: Dataset) -> None:
8787
with open(path, 'w') as f:
88-
f.write('\n'.join([str(i) for i in sentences]) + '\n')
88+
for i in data:
89+
f.write(str(i) + '\n')
8990

9091

9192
class CoNLL(Transform):
@@ -679,11 +680,14 @@ def __getitem__(self, index):
679680
def __getattr__(self, name):
680681
if name in self.__dict__:
681682
return self.__dict__[name]
682-
if name in self.fields:
683-
return self.fields[name]
684-
if hasattr(self.sentences[0], name):
685-
return [getattr(s, name) for s in self.sentences]
686-
raise AttributeError
683+
return [getattr(s, name) for s in self.sentences]
684+
685+
def __setattr__(self, name, value):
686+
if name not in ('sentences', 'fields', 'names'):
687+
for s, v in zip(self.sentences, value):
688+
setattr(s, name, v)
689+
else:
690+
self.__dict__[name] = value
687691

688692

689693
class Sentence(object):

0 commit comments

Comments (0)