Skip to content

Commit c6ca288

Browse files
committed
Max margin training
1 parent dd0696e commit c6ca288

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

supar/cmds/crf_con.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def main():
2626
subparser.add_argument('--unk', default='unk', help='unk token in pretrained embeddings')
2727
subparser.add_argument('--n-embed', default=100, type=int, help='dimension of embeddings')
2828
subparser.add_argument('--bert', default='bert-base-cased', help='which BERT model to use')
29+
subparser.add_argument('--loss', choices=['crf', 'max-margin'], default='crf', help='loss for global training')
2930
# evaluate
3031
subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
3132
subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')

supar/models/con.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def loss(self, s_span, s_label, charts, mask, mbr=True):
175175
s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
176176
Scores of all constituents.
177177
s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
178-
Scores of all labels on each constituent.
178+
Scores of all constituent labels.
179179
charts (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
180180
The tensor of gold-standard labels. Positions without labels are filled with -1.
181181
mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
@@ -190,8 +190,13 @@ def loss(self, s_span, s_label, charts, mask, mbr=True):
190190
"""
191191

192192
span_mask = charts.ge(0) & mask
193-
span_dist = CRFConstituency(s_span, mask)
194-
span_loss = -span_dist.log_prob(span_mask).sum() / mask[:, 0].sum()
193+
if self.args.loss == 'crf':
194+
span_dist = CRFConstituency(s_span, mask)
195+
span_loss = -span_dist.log_prob(span_mask).sum()
196+
elif self.args.loss == 'max-margin':
197+
span_dist = CRFConstituency(s_span + torch.full_like(s_span, 1) - span_mask.float(), mask)
198+
span_loss = span_dist.max.sum() - s_span[span_mask].sum()
199+
span_loss = span_loss / mask[:, 0].sum()
195200
span_probs = span_dist.marginals if mbr else s_span
196201
label_loss = self.criterion(s_label[span_mask], charts[span_mask])
197202
loss = span_loss + label_loss
@@ -204,7 +209,7 @@ def decode(self, s_span, s_label, mask):
204209
s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
205210
Scores of all constituents.
206211
s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
207-
Scores of all labels on each constituent.
212+
Scores of all constituent labels.
208213
mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
209214
The mask for covering the unpadded tokens in each chart.
210215
@@ -406,7 +411,7 @@ def loss(self, s_span, s_pair, s_label, charts, mask):
406411
s_pair (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
407412
Scores of second-order triples.
408413
s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
409-
Scores of all labels on each constituent.
414+
Scores of all constituent labels.
410415
charts (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
411416
The tensor of gold-standard labels. Positions without labels are filled with -1.
412417
mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
@@ -430,7 +435,7 @@ def decode(self, s_span, s_label, mask):
430435
s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
431436
Scores of all constituents.
432437
s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
433-
Scores of all labels on each constituent.
438+
Scores of all constituent labels.
434439
mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
435440
The mask for covering the unpadded tokens in each chart.
436441

supar/structs/crf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,10 @@ def __repr__(self):
289289

290290
@lazy_property
291291
def argmax(self):
292-
return [sorted(i.nonzero().tolist(), key=lambda x:(x[0], -x[1])) for i in self.backward(self.max.sum())]
292+
return [sorted(torch.nonzero(i).tolist(), key=lambda x:(x[0], -x[1])) for i in self.backward(self.max.sum())]
293293

294294
def topk(self, k):
295-
return list(zip(*[[sorted(i.nonzero().tolist(), key=lambda x:(x[0], -x[1])) for i in self.backward(i)]
295+
return list(zip(*[[sorted(torch.nonzero(i).tolist(), key=lambda x:(x[0], -x[1])) for i in self.backward(i)]
296296
for i in self.kmax(k).sum(0)]))
297297

298298
def score(self, value):

0 commit comments

Comments (0)