Skip to content

Commit aa83656

Browse files
committed
Fix bug of incorrect masking for ConstituencyCRF
1 parent 444b515 commit aa83656

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

supar/models/const.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def loss(self, s_span, s_label, charts, mask, mbr=True):
190190
"""
191191

192192
span_mask = charts.ge(0) & mask
193-
span_dist = ConstituencyCRF(s_span, mask[:, 0].sum())
193+
span_dist = ConstituencyCRF(s_span, mask[:, 0].sum(-1))
194194
span_loss = -span_dist.log_prob(span_mask).sum() / mask[:, 0].sum()
195195
span_probs = span_dist.marginals if mbr else s_span
196196
label_loss = self.criterion(s_label[span_mask], charts[span_mask])
@@ -213,7 +213,7 @@ def decode(self, s_span, s_label, mask):
213213
Sequences of factorized labeled trees traversed in pre-order.
214214
"""
215215

216-
span_preds = ConstituencyCRF(s_span, mask[:, 0].sum()).argmax
216+
span_preds = ConstituencyCRF(s_span, mask[:, 0].sum(-1)).argmax
217217
label_preds = s_label.argmax(-1).tolist()
218218
return [[(i, j, labels[i][j]) for i, j in spans] for spans, labels in zip(span_preds, label_preds)]
219219

@@ -439,6 +439,6 @@ def decode(self, s_span, s_label, mask):
439439
Sequences of factorized labeled trees traversed in pre-order.
440440
"""
441441

442-
span_preds = ConstituencyCRF(s_span, mask[:, 0].sum()).argmax
442+
span_preds = ConstituencyCRF(s_span, mask[:, 0].sum(-1)).argmax
443443
label_preds = s_label.argmax(-1).tolist()
444444
return [[(i, j, labels[i][j]) for i, j in spans] for spans, labels in zip(span_preds, label_preds)]

supar/parsers/const.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def _predict(self, loader):
213213
mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
214214
lens = mask[:, 0].sum(-1)
215215
s_span, s_label = self.model(words, feats)
216-
s_span = ConstituencyCRF(s_span, mask[:, 0].sum()).marginals if self.args.mbr else s_span
216+
s_span = ConstituencyCRF(s_span, mask[:, 0].sum(-1)).marginals if self.args.mbr else s_span
217217
chart_preds = self.model.decode(s_span, s_label, mask)
218218
preds['trees'].extend([Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart])
219219
for tree, chart in zip(trees, chart_preds)])

0 commit comments

Comments
 (0)