@@ -212,7 +212,6 @@ def _evaluate(self, loader):
212
212
def _predict (self , loader ):
213
213
self .model .eval ()
214
214
215
- preds = {'arcs' : [], 'rels' : [], 'probs' : [] if self .args .prob else None }
216
215
for batch in progress_bar (loader ):
217
216
words , texts , * feats = batch
218
217
word_mask = words .ne (self .args .pad_index )
@@ -222,14 +221,10 @@ def _predict(self, loader):
222
221
lens = mask .sum (1 ).tolist ()
223
222
s_arc , s_rel = self .model (words , feats )
224
223
arc_preds , rel_preds = self .model .decode (s_arc , s_rel , mask , self .args .tree , self .args .proj )
225
- preds [ ' arcs' ]. extend ( arc_preds [mask ].split (lens ))
226
- preds [ ' rels' ]. extend ( rel_preds [mask ].split (lens ))
224
+ batch . arcs = [ i . tolist () for i in arc_preds [mask ].split (lens )]
225
+ batch . rels = [ self . REL . vocab [ i . tolist ()] for i in rel_preds [mask ].split (lens )]
227
226
if self .args .prob :
228
- preds ['probs' ].extend ([prob [1 :i + 1 , :i + 1 ].cpu () for i , prob in zip (lens , s_arc .softmax (- 1 ).unbind ())])
229
- preds ['arcs' ] = [seq .tolist () for seq in preds ['arcs' ]]
230
- preds ['rels' ] = [self .REL .vocab [seq .tolist ()] for seq in preds ['rels' ]]
231
-
232
- return preds
227
+ batch .probs = [prob [1 :i + 1 , :i + 1 ].cpu () for i , prob in zip (lens , s_arc .softmax (- 1 ).unbind ())]
233
228
234
229
@classmethod
235
230
def build (cls , path , min_freq = 2 , fix_len = 20 , ** kwargs ):
@@ -526,7 +521,6 @@ def _predict(self, loader):
526
521
self .model .eval ()
527
522
528
523
CRF = DependencyCRF if self .args .proj else MatrixTree
529
- preds = {'arcs' : [], 'rels' : [], 'probs' : [] if self .args .prob else None }
530
524
for batch in progress_bar (loader ):
531
525
words , _ , * feats = batch
532
526
word_mask = words .ne (self .args .pad_index )
@@ -538,15 +532,11 @@ def _predict(self, loader):
538
532
s_arc = CRF (s_arc , lens ).marginals if self .args .mbr else s_arc
539
533
arc_preds , rel_preds = self .model .decode (s_arc , s_rel , mask , self .args .tree , self .args .proj )
540
534
lens = lens .tolist ()
541
- preds [ ' arcs' ]. extend ( arc_preds [mask ].split (lens ))
542
- preds [ ' rels' ]. extend ( rel_preds [mask ].split (lens ))
535
+ batch . arcs = [ i . tolist () for i in arc_preds [mask ].split (lens )]
536
+ batch . rels = [ self . REL . vocab [ i . tolist ()] for i in rel_preds [mask ].split (lens )]
543
537
if self .args .prob :
544
538
arc_probs = s_arc if self .args .mbr else s_arc .softmax (- 1 )
545
- preds ['probs' ].extend ([prob [1 :i + 1 , :i + 1 ].cpu () for i , prob in zip (lens , arc_probs .unbind ())])
546
- preds ['arcs' ] = [seq .tolist () for seq in preds ['arcs' ]]
547
- preds ['rels' ] = [self .REL .vocab [seq .tolist ()] for seq in preds ['rels' ]]
548
-
549
- return preds
539
+ batch .probs = [prob [1 :i + 1 , :i + 1 ].cpu () for i , prob in zip (lens , arc_probs .unbind ())]
550
540
551
541
552
542
class CRF2oDependencyParser (BiaffineDependencyParser ):
@@ -745,7 +735,6 @@ def _evaluate(self, loader):
745
735
def _predict (self , loader ):
746
736
self .model .eval ()
747
737
748
- preds = {'arcs' : [], 'rels' : [], 'probs' : [] if self .args .prob else None }
749
738
for batch in progress_bar (loader ):
750
739
words , texts , * feats = batch
751
740
word_mask = words .ne (self .args .pad_index )
@@ -757,15 +746,11 @@ def _predict(self, loader):
757
746
s_arc , s_sib = Dependency2oCRF ((s_arc , s_sib ), lens ).marginals if self .args .mbr else (s_arc , s_sib )
758
747
arc_preds , rel_preds = self .model .decode (s_arc , s_sib , s_rel , mask , self .args .tree , self .args .mbr , self .args .proj )
759
748
lens = lens .tolist ()
760
- preds [ ' arcs' ]. extend ( arc_preds [mask ].split (lens ))
761
- preds [ ' rels' ]. extend ( rel_preds [mask ].split (lens ))
749
+ batch . arcs = [ i . tolist () for i in arc_preds [mask ].split (lens )]
750
+ batch . rels = [ self . REL . vocab [ i . tolist ()] for i in rel_preds [mask ].split (lens )]
762
751
if self .args .prob :
763
752
arc_probs = s_arc if self .args .mbr else s_arc .softmax (- 1 )
764
- preds ['probs' ].extend ([prob [1 :i + 1 , :i + 1 ].cpu () for i , prob in zip (lens , arc_probs .unbind ())])
765
- preds ['arcs' ] = [seq .tolist () for seq in preds ['arcs' ]]
766
- preds ['rels' ] = [self .REL .vocab [seq .tolist ()] for seq in preds ['rels' ]]
767
-
768
- return preds
753
+ batch .probs = [prob [1 :i + 1 , :i + 1 ].cpu () for i , prob in zip (lens , arc_probs .unbind ())]
769
754
770
755
@classmethod
771
756
def build (cls , path , min_freq = 2 , fix_len = 20 , ** kwargs ):
@@ -1054,7 +1039,6 @@ def _evaluate(self, loader):
1054
1039
def _predict (self , loader ):
1055
1040
self .model .eval ()
1056
1041
1057
- preds = {'arcs' : [], 'rels' : [], 'probs' : [] if self .args .prob else None }
1058
1042
for batch in progress_bar (loader ):
1059
1043
words , texts , * feats = batch
1060
1044
word_mask = words .ne (self .args .pad_index )
@@ -1065,11 +1049,7 @@ def _predict(self, loader):
1065
1049
s_arc , s_sib , s_rel = self .model (words , feats )
1066
1050
s_arc = self .model .inference ((s_arc , s_sib ), mask )
1067
1051
arc_preds , rel_preds = self .model .decode (s_arc , s_rel , mask , self .args .tree , self .args .proj )
1068
- preds [ ' arcs' ]. extend ( arc_preds [mask ].split (lens ))
1069
- preds [ ' rels' ]. extend ( rel_preds [mask ].split (lens ))
1052
+ batch . arcs = [ i . tolist () for i in arc_preds [mask ].split (lens )]
1053
+ batch . rels = [ self . REL . vocab [ i . tolist ()] for i in rel_preds [mask ].split (lens )]
1070
1054
if self .args .prob :
1071
- preds ['probs' ].extend ([prob [1 :i + 1 , :i + 1 ].cpu () for i , prob in zip (lens , s_arc .unbind ())])
1072
- preds ['arcs' ] = [seq .tolist () for seq in preds ['arcs' ]]
1073
- preds ['rels' ] = [self .REL .vocab [seq .tolist ()] for seq in preds ['rels' ]]
1074
-
1075
- return preds
1055
+ batch .probs = [prob [1 :i + 1 , :i + 1 ].cpu () for i , prob in zip (lens , s_arc .unbind ())]
0 commit comments