@@ -175,7 +175,7 @@ def loss(self, s_span, s_label, charts, mask, mbr=True):
175
175
s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
176
176
Scores of all constituents.
177
177
s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
178
- Scores of all labels on each constituent.
178
+ Scores of all constituent labels.
179
179
charts (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
180
180
The tensor of gold-standard labels. Positions without labels are filled with -1.
181
181
mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
@@ -190,8 +190,13 @@ def loss(self, s_span, s_label, charts, mask, mbr=True):
190
190
"""
191
191
192
192
span_mask = charts .ge (0 ) & mask
193
- span_dist = CRFConstituency (s_span , mask )
194
- span_loss = - span_dist .log_prob (span_mask ).sum () / mask [:, 0 ].sum ()
193
+ if self .args .loss == 'crf' :
194
+ span_dist = CRFConstituency (s_span , mask )
195
+ span_loss = - span_dist .log_prob (span_mask ).sum ()
196
+ elif self .args .loss == 'max-margin' :
197
+ span_dist = CRFConstituency (s_span + torch .full_like (s_span , 1 ) - span_mask .float (), mask )
198
+ span_loss = span_dist .max .sum () - s_span [span_mask ].sum ()
199
+ span_loss = span_loss / mask [:, 0 ].sum ()
195
200
span_probs = span_dist .marginals if mbr else s_span
196
201
label_loss = self .criterion (s_label [span_mask ], charts [span_mask ])
197
202
loss = span_loss + label_loss
@@ -204,7 +209,7 @@ def decode(self, s_span, s_label, mask):
204
209
s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
205
210
Scores of all constituents.
206
211
s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
207
- Scores of all labels on each constituent.
212
+ Scores of all constituent labels.
208
213
mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
209
214
The mask for covering the unpadded tokens in each chart.
210
215
@@ -406,7 +411,7 @@ def loss(self, s_span, s_pair, s_label, charts, mask):
406
411
s_pair (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
407
412
Scores of second-order triples.
408
413
s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
409
- Scores of all labels on each constituent.
414
+ Scores of all constituent labels.
410
415
charts (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
411
416
The tensor of gold-standard labels. Positions without labels are filled with -1.
412
417
mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
@@ -430,7 +435,7 @@ def decode(self, s_span, s_label, mask):
430
435
s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
431
436
Scores of all constituents.
432
437
s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
433
- Scores of all labels on each constituent.
438
+ Scores of all constituent labels.
434
439
mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
435
440
The mask for covering the unpadded tokens in each chart.
436
441
0 commit comments