@@ -158,7 +158,7 @@ def _train(self, loader):
158
158
bar , metric = progress_bar (loader ), AttachmentMetric ()
159
159
160
160
for i , batch in enumerate (bar , 1 ):
161
- words , * feats , arcs , rels = batch
161
+ words , texts , * feats , arcs , rels = batch
162
162
word_mask = words .ne (self .args .pad_index )
163
163
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
164
164
# ignore the first token of each sentence
@@ -178,7 +178,7 @@ def _train(self, loader):
178
178
mask &= arcs .ge (0 )
179
179
# ignore all punctuation if not specified
180
180
if not self .args .punct :
181
- mask .masked_scatter_ (mask , ~ mask .new_tensor ([ispunct (w ) for s in batch . sentences for w in s . words ]))
181
+ mask .masked_scatter_ (mask , ~ mask .new_tensor ([ispunct (w ) for s in texts for w in s ]))
182
182
metric (arc_preds , rel_preds , arcs , rels , mask )
183
183
bar .set_postfix_str (f"lr: { self .scheduler .get_last_lr ()[0 ]:.4e} - loss: { loss :.4f} - { metric } " )
184
184
logger .info (f"{ bar .postfix } " )
@@ -190,7 +190,7 @@ def _evaluate(self, loader):
190
190
total_loss , metric = 0 , AttachmentMetric ()
191
191
192
192
for batch in loader :
193
- words , * feats , arcs , rels = batch
193
+ words , texts , * feats , arcs , rels = batch
194
194
word_mask = words .ne (self .args .pad_index )
195
195
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
196
196
# ignore the first token of each sentence
@@ -202,7 +202,7 @@ def _evaluate(self, loader):
202
202
mask &= arcs .ge (0 )
203
203
# ignore all punctuation if not specified
204
204
if not self .args .punct :
205
- mask .masked_scatter_ (mask , ~ mask .new_tensor ([ispunct (w ) for s in batch . sentences for w in s . words ]))
205
+ mask .masked_scatter_ (mask , ~ mask .new_tensor ([ispunct (w ) for s in texts for w in s ]))
206
206
total_loss += loss .item ()
207
207
metric (arc_preds , rel_preds , arcs , rels , mask )
208
208
total_loss /= len (loader )
@@ -215,7 +215,7 @@ def _predict(self, loader):
215
215
216
216
preds = {'arcs' : [], 'rels' : [], 'probs' : [] if self .args .prob else None }
217
217
for batch in progress_bar (loader ):
218
- words , * feats = batch
218
+ words , texts , * feats = batch
219
219
word_mask = words .ne (self .args .pad_index )
220
220
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
221
221
# ignore the first token of each sentence
@@ -470,7 +470,7 @@ def _train(self, loader):
470
470
bar , metric = progress_bar (loader ), AttachmentMetric ()
471
471
472
472
for i , batch in enumerate (bar , 1 ):
473
- words , * feats , arcs , rels = batch
473
+ words , texts , * feats , arcs , rels = batch
474
474
word_mask = words .ne (self .args .pad_index )
475
475
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
476
476
# ignore the first token of each sentence
@@ -490,7 +490,7 @@ def _train(self, loader):
490
490
mask &= arcs .ge (0 )
491
491
# ignore all punctuation if not specified
492
492
if not self .args .punct :
493
- mask .masked_scatter_ (mask , ~ mask .new_tensor ([ispunct (w ) for s in batch . sentences for w in s . words ]))
493
+ mask .masked_scatter_ (mask , ~ mask .new_tensor ([ispunct (w ) for s in texts for w in s ]))
494
494
metric (arc_preds , rel_preds , arcs , rels , mask )
495
495
bar .set_postfix_str (f"lr: { self .scheduler .get_last_lr ()[0 ]:.4e} - loss: { loss :.4f} - { metric } " )
496
496
logger .info (f"{ bar .postfix } " )
@@ -502,7 +502,7 @@ def _evaluate(self, loader):
502
502
total_loss , metric = 0 , AttachmentMetric ()
503
503
504
504
for batch in loader :
505
- words , * feats , arcs , rels = batch
505
+ words , texts , * feats , arcs , rels = batch
506
506
word_mask = words .ne (self .args .pad_index )
507
507
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
508
508
# ignore the first token of each sentence
@@ -514,7 +514,7 @@ def _evaluate(self, loader):
514
514
mask &= arcs .ge (0 )
515
515
# ignore all punctuation if not specified
516
516
if not self .args .punct :
517
- mask .masked_scatter_ (mask , ~ mask .new_tensor ([ispunct (w ) for s in batch . sentences for w in s . words ]))
517
+ mask .masked_scatter_ (mask , ~ mask .new_tensor ([ispunct (w ) for s in texts for w in s ]))
518
518
total_loss += loss .item ()
519
519
metric (arc_preds , rel_preds , arcs , rels , mask )
520
520
total_loss /= len (loader )
@@ -527,7 +527,7 @@ def _predict(self, loader):
527
527
528
528
preds = {'arcs' : [], 'rels' : [], 'probs' : [] if self .args .prob else None }
529
529
for batch in progress_bar (loader ):
530
- words , * feats = batch
530
+ words , texts , * feats = batch
531
531
word_mask = words .ne (self .args .pad_index )
532
532
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
533
533
# ignore the first token of each sentence
@@ -688,7 +688,7 @@ def _train(self, loader):
688
688
bar , metric = progress_bar (loader ), AttachmentMetric ()
689
689
690
690
for i , batch in enumerate (bar , 1 ):
691
- words , * feats , arcs , sibs , rels = batch
691
+ words , texts , * feats , arcs , sibs , rels = batch
692
692
word_mask = words .ne (self .args .pad_index )
693
693
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
694
694
# ignore the first token of each sentence
@@ -708,7 +708,7 @@ def _train(self, loader):
708
708
mask &= arcs .ge (0 )
709
709
# ignore all punctuation if not specified
710
710
if not self .args .punct :
711
- mask .masked_scatter_ (mask , ~ mask .new_tensor ([ispunct (w ) for s in batch . sentences for w in s . words ]))
711
+ mask .masked_scatter_ (mask , ~ mask .new_tensor ([ispunct (w ) for s in texts for w in s ]))
712
712
metric (arc_preds , rel_preds , arcs , rels , mask )
713
713
bar .set_postfix_str (f"lr: { self .scheduler .get_last_lr ()[0 ]:.4e} - loss: { loss :.4f} - { metric } " )
714
714
logger .info (f"{ bar .postfix } " )
@@ -720,7 +720,7 @@ def _evaluate(self, loader):
720
720
total_loss , metric = 0 , AttachmentMetric ()
721
721
722
722
for batch in loader :
723
- words , * feats , arcs , sibs , rels = batch
723
+ words , texts , * feats , arcs , sibs , rels = batch
724
724
word_mask = words .ne (self .args .pad_index )
725
725
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
726
726
# ignore the first token of each sentence
@@ -732,7 +732,7 @@ def _evaluate(self, loader):
732
732
mask &= arcs .ge (0 )
733
733
# ignore all punctuation if not specified
734
734
if not self .args .punct :
735
- mask .masked_scatter_ (mask , ~ mask .new_tensor ([ispunct (w ) for s in batch . sentences for w in s . words ]))
735
+ mask .masked_scatter_ (mask , ~ mask .new_tensor ([ispunct (w ) for s in texts for w in s ]))
736
736
total_loss += loss .item ()
737
737
metric (arc_preds , rel_preds , arcs , rels , mask )
738
738
total_loss /= len (loader )
@@ -745,7 +745,7 @@ def _predict(self, loader):
745
745
746
746
preds = {'arcs' : [], 'rels' : [], 'probs' : [] if self .args .prob else None }
747
747
for batch in progress_bar (loader ):
748
- words , * feats = batch
748
+ words , texts , * feats = batch
749
749
word_mask = words .ne (self .args .pad_index )
750
750
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
751
751
# ignore the first token of each sentence
@@ -995,7 +995,7 @@ def _train(self, loader):
995
995
bar , metric = progress_bar (loader ), AttachmentMetric ()
996
996
997
997
for i , batch in enumerate (bar , 1 ):
998
- words , * feats , arcs , rels = batch
998
+ words , texts , * feats , arcs , rels = batch
999
999
word_mask = words .ne (self .args .pad_index )
1000
1000
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
1001
1001
# ignore the first token of each sentence
@@ -1015,7 +1015,7 @@ def _train(self, loader):
1015
1015
mask &= arcs .ge (0 )
1016
1016
# ignore all punctuation if not specified
1017
1017
if not self .args .punct :
1018
- mask .masked_scatter_ (mask , ~ mask .new_tensor ([ispunct (w ) for s in batch . sentences for w in s . words ]))
1018
+ mask .masked_scatter_ (mask , ~ mask .new_tensor ([ispunct (w ) for s in texts for w in s ]))
1019
1019
metric (arc_preds , rel_preds , arcs , rels , mask )
1020
1020
bar .set_postfix_str (f"lr: { self .scheduler .get_last_lr ()[0 ]:.4e} - loss: { loss :.4f} - { metric } " )
1021
1021
logger .info (f"{ bar .postfix } " )
@@ -1027,7 +1027,7 @@ def _evaluate(self, loader):
1027
1027
total_loss , metric = 0 , AttachmentMetric ()
1028
1028
1029
1029
for batch in loader :
1030
- words , * feats , arcs , rels = batch
1030
+ words , texts , * feats , arcs , rels = batch
1031
1031
word_mask = words .ne (self .args .pad_index )
1032
1032
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
1033
1033
# ignore the first token of each sentence
@@ -1039,7 +1039,7 @@ def _evaluate(self, loader):
1039
1039
mask &= arcs .ge (0 )
1040
1040
# ignore all punctuation if not specified
1041
1041
if not self .args .punct :
1042
- mask .masked_scatter_ (mask , ~ mask .new_tensor ([ispunct (w ) for s in batch . sentences for w in s . words ]))
1042
+ mask .masked_scatter_ (mask , ~ mask .new_tensor ([ispunct (w ) for s in texts for w in s ]))
1043
1043
total_loss += loss .item ()
1044
1044
metric (arc_preds , rel_preds , arcs , rels , mask )
1045
1045
total_loss /= len (loader )
@@ -1052,7 +1052,7 @@ def _predict(self, loader):
1052
1052
1053
1053
preds = {'arcs' : [], 'rels' : [], 'probs' : [] if self .args .prob else None }
1054
1054
for batch in progress_bar (loader ):
1055
- words , * feats = batch
1055
+ words , texts , * feats = batch
1056
1056
word_mask = words .ne (self .args .pad_index )
1057
1057
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
1058
1058
# ignore the first token of each sentence
0 commit comments