@@ -190,7 +190,7 @@ def loss(self, s_span, s_label, charts, mask, mbr=True):
190
190
"""
191
191
192
192
span_mask = charts .ge (0 ) & mask
193
- span_dist = ConstituencyCRF (s_span , mask [:, 0 ].sum ())
193
+ span_dist = ConstituencyCRF (s_span , mask [:, 0 ].sum (- 1 ))
194
194
span_loss = - span_dist .log_prob (span_mask ).sum () / mask [:, 0 ].sum ()
195
195
span_probs = span_dist .marginals if mbr else s_span
196
196
label_loss = self .criterion (s_label [span_mask ], charts [span_mask ])
@@ -213,7 +213,7 @@ def decode(self, s_span, s_label, mask):
213
213
Sequences of factorized labeled trees traversed in pre-order.
214
214
"""
215
215
216
- span_preds = ConstituencyCRF (s_span , mask [:, 0 ].sum ()).argmax
216
+ span_preds = ConstituencyCRF (s_span , mask [:, 0 ].sum (- 1 )).argmax
217
217
label_preds = s_label .argmax (- 1 ).tolist ()
218
218
return [[(i , j , labels [i ][j ]) for i , j in spans ] for spans , labels in zip (span_preds , label_preds )]
219
219
@@ -439,6 +439,6 @@ def decode(self, s_span, s_label, mask):
439
439
Sequences of factorized labeled trees traversed in pre-order.
440
440
"""
441
441
442
- span_preds = ConstituencyCRF (s_span , mask [:, 0 ].sum ()).argmax
442
+ span_preds = ConstituencyCRF (s_span , mask [:, 0 ].sum (- 1 )).argmax
443
443
label_preds = s_label .argmax (- 1 ).tolist ()
444
444
return [[(i , j , labels [i ][j ]) for i , j in spans ] for spans , labels in zip (span_preds , label_preds )]
0 commit comments