Commit 0c3ec4e

Softmax margin training

1 parent 59bf09c commit 0c3ec4e

5 files changed (+17, -6 lines changed)

supar/cmds/crf2o_dep.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def main():
     subparser.add_argument('--unk', default='unk', help='unk token in pretrained embeddings')
     subparser.add_argument('--n-embed', default=100, type=int, help='dimension of embeddings')
     subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
-    subparser.add_argument('--loss', choices=['crf', 'max-margin'], default='crf', help='loss for global training')
+    subparser.add_argument('--loss', choices=['crf', 'max-marg', 'softmax-marg'], default='crf', help='training criteria')
     # evaluate
     subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
     subparser.add_argument('--punct', action='store_true', help='whether to include punctuation')

supar/cmds/crf_con.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ def main():
     subparser.add_argument('--unk', default='unk', help='unk token in pretrained embeddings')
     subparser.add_argument('--n-embed', default=100, type=int, help='dimension of embeddings')
     subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
-    subparser.add_argument('--loss', choices=['crf', 'max-margin'], default='crf', help='loss for global training')
+    subparser.add_argument('--loss', choices=['crf', 'max-marg', 'softmax-marg'], default='crf', help='training criteria')
     # evaluate
     subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
     subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')

supar/cmds/crf_dep.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def main():
     subparser.add_argument('--unk', default='unk', help='unk token in pretrained embeddings')
     subparser.add_argument('--n-embed', default=100, type=int, help='dimension of embeddings')
     subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
-    subparser.add_argument('--loss', choices=['crf', 'max-margin'], default='crf', help='loss for global training')
+    subparser.add_argument('--loss', choices=['crf', 'max-marg', 'softmax-marg'], default='crf', help='training criteria')
     # evaluate
     subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
     subparser.add_argument('--punct', action='store_true', help='whether to include punctuation')
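
Note that across all three commands the existing choice is renamed from 'max-margin' to 'max-marg' alongside the new 'softmax-marg' option, so scripts still passing the old spelling will now be rejected by argparse. A minimal standalone sketch of the new flag (plain argparse only, not supar's full command-line entry point):

    import argparse

    parser = argparse.ArgumentParser()
    # the argument exactly as added in this commit
    parser.add_argument('--loss', choices=['crf', 'max-marg', 'softmax-marg'], default='crf',
                        help='training criteria')

    print(parser.parse_args(['--loss', 'softmax-marg']).loss)  # -> softmax-marg
    # parser.parse_args(['--loss', 'max-margin'])              # now exits with "invalid choice"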

supar/models/con.py

Lines changed: 4 additions & 1 deletion
@@ -193,9 +193,12 @@ def loss(self, s_span, s_label, charts, mask, mbr=True):
         if self.args.loss == 'crf':
             span_dist = CRFConstituency(s_span, mask)
             span_loss = -span_dist.log_prob(span_mask).sum() / mask[:, 0].sum()
-        elif self.args.loss == 'max-margin':
+        elif self.args.loss == 'max-marg':
             span_dist = CRFConstituency(s_span + torch.full_like(s_span, 1) - span_mask.float(), mask)
             span_loss = (span_dist.max - span_dist.score(span_mask)).sum() / mask[:, 0].sum()
+        elif self.args.loss == 'softmax-marg':
+            span_dist = CRFConstituency(s_span + torch.full_like(s_span, 1) - span_mask.float(), mask)
+            span_loss = -span_dist.log_prob(span_mask).sum() / mask[:, 0].sum()
         span_probs = span_dist.marginals if mbr else s_span
         label_loss = self.criterion(s_label[span_mask], charts[span_mask])
         loss = span_loss + label_loss
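
In both margin branches above, the distribution is built over cost-augmented scores: torch.full_like(s_span, 1) - span_mask.float() adds a Hamming cost of 1 to every candidate span and cancels it at gold spans. Max-margin then takes the gap between the best augmented score and the gold score, while softmax-margin keeps the CRF negative log-likelihood but evaluates it under the augmented scores. A self-contained toy sketch of the same two criteria for an ordinary multi-class softmax (illustration only, not supar's structured CRFConstituency):

    import torch
    import torch.nn.functional as F

    def softmax_margin_loss(scores, gold, cost=1.0):
        # Hamming cost augmentation: +cost on every class, cancelled at the gold class,
        # mirroring s_span + 1 - span_mask.float() above.
        aug = scores + cost - cost * F.one_hot(gold, scores.size(-1)).float()
        # softmax-margin: negative log-likelihood of the gold class under augmented scores
        return F.cross_entropy(aug, gold)

    def max_margin_loss(scores, gold, cost=1.0):
        aug = scores + cost - cost * F.one_hot(gold, scores.size(-1)).float()
        gold_score = scores.gather(-1, gold.unsqueeze(-1)).squeeze(-1)
        # max-margin (structured hinge): best augmented score minus the gold score
        return (aug.max(-1).values - gold_score).mean()

    scores = torch.randn(4, 10)        # 4 toy items, 10 candidate classes
    gold = torch.randint(0, 10, (4,))  # gold class indices
    print(softmax_margin_loss(scores, gold), max_margin_loss(scores, gold))

Because the gold class carries zero cost, the hinge term is non-negative, and the softmax-margin loss upper-bounds it (log-sum-exp is at least the max).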

supar/models/dep.py

Lines changed: 10 additions & 2 deletions
@@ -346,9 +346,12 @@ def loss(self, s_arc, s_rel, arcs, rels, mask, mbr=True, partial=False):
         if self.args.loss == 'crf':
             arc_dist = CRF(s_arc, mask, partial=partial)
             arc_loss = -arc_dist.log_prob(arcs).sum() / mask.sum()
-        elif self.args.loss == 'max-margin':
+        elif self.args.loss == 'max-marg':
             arc_dist = CRF(s_arc + torch.full_like(s_arc, 1).scatter_(-1, arcs.unsqueeze(-1), 0), mask, partial=partial)
             arc_loss = (arc_dist.max - arc_dist.score(arcs)).sum() / mask.sum()
+        elif self.args.loss == 'softmax-marg':
+            arc_dist = CRF(s_arc + torch.full_like(s_arc, 1).scatter_(-1, arcs.unsqueeze(-1), 0), mask, partial=partial)
+            arc_loss = -arc_dist.log_prob(arcs).sum() / mask.sum()
         arc_probs = arc_dist.marginals if mbr else s_arc
         # -1 denotes un-annotated arcs
         if partial:
@@ -562,11 +565,16 @@ def loss(self, s_arc, s_sib, s_rel, arcs, sibs, rels, mask, mbr=True, partial=Fa
         if self.args.loss == 'crf':
             arc_dist = CRF2oDependency((s_arc, s_sib), mask, partial=partial)
             arc_loss = -arc_dist.log_prob((arcs, sibs)).sum() / mask.sum()
-        elif self.args.loss == 'max-margin':
+        elif self.args.loss == 'max-marg':
             s_arc = s_arc + torch.full_like(s_arc, 1).scatter_(-1, arcs.unsqueeze(-1), 0)
             s_sib = s_sib + torch.full_like(s_sib, 1).masked_fill_(sibs.unsqueeze(-1).eq(sibs.new_tensor(range(seq_len))), 0)
             arc_dist = CRF2oDependency((s_arc, s_sib), mask, partial=partial)
             arc_loss = (arc_dist.max - arc_dist.score((arcs, sibs))).sum() / mask.sum()
+        elif self.args.loss == 'softmax-marg':
+            s_arc = s_arc + torch.full_like(s_arc, 1).scatter_(-1, arcs.unsqueeze(-1), 0)
+            s_sib = s_sib + torch.full_like(s_sib, 1).masked_fill_(sibs.unsqueeze(-1).eq(sibs.new_tensor(range(seq_len))), 0)
+            arc_dist = CRF2oDependency((s_arc, s_sib), mask, partial=partial)
+            arc_loss = -arc_dist.log_prob((arcs, sibs)).sum() / mask.sum()
         arc_probs = arc_dist.marginals if mbr else s_arc
         # -1 denotes un-annotated arcs
         if partial:
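
For dependency arcs the Hamming cost is added with scatter_: every candidate head receives +1 except the gold head, whose entry is zeroed out; the sibling scores get the analogous treatment via masked_fill_. A small standalone illustration of the arc cost term with made-up shapes and gold heads (not supar code):

    import torch

    batch, seq_len = 1, 4
    s_arc = torch.randn(batch, seq_len, seq_len)  # head scores, one row per dependent
    arcs = torch.tensor([[0, 2, 0, 2]])           # toy gold head index for each position

    # +1 for every candidate head, 0 at the gold head, the same expression as
    # torch.full_like(s_arc, 1).scatter_(-1, arcs.unsqueeze(-1), 0) above
    cost = torch.full_like(s_arc, 1).scatter_(-1, arcs.unsqueeze(-1), 0)
    print(cost[0])            # each row has a single 0 in its gold-head column
    print((s_arc + cost)[0])  # cost-augmented scores fed to the structured distribution

The softmax-marg branch then differs from the crf branch only in that log_prob is evaluated on this cost-augmented distribution instead of the raw scores.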
