@@ -76,8 +76,8 @@ def batch_update(self, batch, pred):
76
76
s_lengths = batch ['words_lens' ]
77
77
y = batch ['tags' ]
78
78
raw_sentence = batch ['raw_sentence' ]
79
- y_reals = y .data .numpy ()
80
- y_preds = pred .data .numpy ()
79
+ y_reals = y .cpu (). data .numpy ()
80
+ y_preds = pred .cpu (). data .numpy ()
81
81
for i , s_len in enumerate (s_lengths ):
82
82
r_tags = []
83
83
p_tags = []
@@ -94,7 +94,7 @@ def batch_update(self, batch, pred):
94
94
self .predictions .append (new_line )
95
95
self .predictions .append ("" )
96
96
97
- def get_metric (self , log_dir ):
97
+ def get_metric (self , log_dir , is_cf = False ):
98
98
# Write predictions to disk and run CoNLL script externally
99
99
eval_id = np .random .randint (1000000 , 2000000 )
100
100
output_path = os .path .join (log_dir , "eval.%i.output" % eval_id )
@@ -106,25 +106,26 @@ def get_metric(self, log_dir):
106
106
107
107
# CoNLL evaluation results
108
108
eval_lines = [l .rstrip () for l in codecs .open (scores_path , 'r' , 'utf8' )]
109
- for line in eval_lines :
110
- print line
111
-
112
- # Confusion matrix with accuracy for each tag
113
- print ("{: >2}{: >7}{: >7}%s{: >9}" % ("{: >7}" * self .n_tags )).format (
114
- "ID" , "NE" , "Total" ,
115
- * ([self .vocab .id_to_tag [i ] for i in xrange (self .n_tags )] + ["Percent" ])
116
- )
117
- for i in xrange (self .n_tags ):
109
+ if is_cf :
110
+ for line in eval_lines :
111
+ print line
112
+
113
+ # Confusion matrix with accuracy for each tag
118
114
print ("{: >2}{: >7}{: >7}%s{: >9}" % ("{: >7}" * self .n_tags )).format (
119
- str (i ), self .vocab .id_to_tag [i ], str (self .count [i ].sum ()),
120
- * ([self .count [i ][j ] for j in xrange (self .n_tags )] +
121
- ["%.3f" % (self .count [i ][i ] * 100. / max (1 , self .count [i ].sum ()))])
115
+ "ID" , "NE" , "Total" ,
116
+ * ([self .vocab .id_to_tag [i ] for i in xrange (self .n_tags )] + ["Percent" ])
117
+ )
118
+ for i in xrange (self .n_tags ):
119
+ print ("{: >2}{: >7}{: >7}%s{: >9}" % ("{: >7}" * self .n_tags )).format (
120
+ str (i ), self .vocab .id_to_tag [i ], str (self .count [i ].sum ()),
121
+ * ([self .count [i ][j ] for j in xrange (self .n_tags )] +
122
+ ["%.3f" % (self .count [i ][i ] * 100. / max (1 , self .count [i ].sum ()))])
122
123
)
123
124
124
- # Global accuracy
125
- print "%i/%i (%.5f%%)" % (
126
- self .count .trace (), self .count .sum (), 100. * self .count .trace () / max (1 , self .count .sum ())
127
- )
125
+ # Global accuracy
126
+ print "%i/%i (%.5f%%)" % (
127
+ self .count .trace (), self .count .sum (), 100. * self .count .trace () / max (1 , self .count .sum ())
128
+ )
128
129
129
130
# F1 on all entities
130
131
return float (eval_lines [1 ].strip ().split ()[- 1 ])
0 commit comments