Skip to content

Commit 4830efe

Browse files
committed
Fix a bug
1 parent d52d71c commit 4830efe

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

supar/models/dep.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,13 +558,15 @@ def loss(self, s_arc, s_sib, s_rel, arcs, sibs, rels, mask, mbr=True, partial=Fa
558558
original arc scores of shape ``[batch_size, seq_len, seq_len]`` if ``mbr=False``, or marginals otherwise.
559559
"""
560560

561+
batch_size, seq_len = mask.shape
561562
if self.args.loss == 'crf':
562563
arc_dist = CRF2oDependency((s_arc, s_sib), mask, partial=partial)
563564
arc_loss = -arc_dist.log_prob((arcs, sibs)).sum() / mask.sum()
564565
elif self.args.loss == 'max-margin':
565566
s_arc = s_arc + torch.full_like(s_arc, 1).scatter_(-1, arcs.unsqueeze(-1), 0)
567+
s_sib = s_sib + torch.full_like(s_sib, 1).masked_fill_(sibs.unsqueeze(-1).eq(sibs.new_tensor(range(seq_len))), 0)
566568
arc_dist = CRF2oDependency((s_arc, s_sib), mask, partial=partial)
567-
arc_loss = (arc_dist.max - arc_dist.score(arcs)).sum() / mask.sum()
569+
arc_loss = (arc_dist.max - arc_dist.score((arcs, sibs))).sum() / mask.sum()
568570
arc_probs = arc_dist.marginals if mbr else s_arc
569571
# -1 denotes un-annotated arcs
570572
if partial:

0 commit comments

Comments
 (0)