Commit 34881ca

fix gpu run
1 parent c0aeecd commit 34881ca

2 files changed: +21 -20 lines changed

neural_ner/process_training.py

Lines changed: 1 addition & 1 deletion

@@ -155,7 +155,7 @@ def evaluate(self, data_type, num_samples=None):
 
 if __name__ == "__main__":
     from config import config
-    config.gpu = config.gpu and torch.cuda.is_available()
+    config.is_cuda = config.is_cuda and torch.cuda.is_available()
     mode = sys.argv[1]
     model_file_path = None
     if len(sys.argv) > 2:
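
Note on the change above: the assignment makes the CUDA flag self-correcting, so a config that asks for the GPU quietly falls back to CPU when torch.cuda.is_available() reports no usable device. A minimal sketch of the pattern, assuming a bare Config object; the to_device helper is hypothetical, not part of this repo:

import torch

class Config(object):
    is_cuda = True  # user preference, may be overridden at startup

config = Config()
# Downgrade the request when no CUDA driver/device is present.
config.is_cuda = config.is_cuda and torch.cuda.is_available()

def to_device(tensor):
    # Hypothetical helper: route every tensor placement through the
    # guarded flag so the same script runs on CPU-only machines.
    return tensor.cuda() if config.is_cuda else tensor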

neural_ner/train_utils.py

Lines changed: 20 additions & 19 deletions

@@ -76,8 +76,8 @@ def batch_update(self, batch, pred):
         s_lengths = batch['words_lens']
         y = batch['tags']
         raw_sentence = batch['raw_sentence']
-        y_reals = y.data.numpy()
-        y_preds = pred.data.numpy()
+        y_reals = y.cpu().data.numpy()
+        y_preds = pred.cpu().data.numpy()
         for i, s_len in enumerate(s_lengths):
             r_tags = []
             p_tags = []
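
The .cpu() calls are the substance of the "fix gpu run" commit: NumPy arrays live in host memory, so calling .numpy() on a CUDA tensor raises an error. Since .cpu() is a no-op for tensors already on the CPU, one code path now serves both devices. A small illustration, with pred standing in for a model output:

import torch

pred = torch.zeros(4, 10)      # stand-in for a batch of tag scores
if torch.cuda.is_available():
    pred = pred.cuda()         # here, pred.data.numpy() alone would raise

# Device-to-host copy on GPU, a no-op on CPU, so this is safe either way.
y_preds = pred.cpu().data.numpy()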
@@ -94,7 +94,7 @@ def batch_update(self, batch, pred):
             self.predictions.append(new_line)
         self.predictions.append("")
 
-    def get_metric(self, log_dir):
+    def get_metric(self, log_dir, is_cf=False):
         # Write predictions to disk and run CoNLL script externally
         eval_id = np.random.randint(1000000, 2000000)
         output_path = os.path.join(log_dir, "eval.%i.output" % eval_id)
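
Defaulting is_cf to False keeps existing call sites working unchanged while making the verbose CoNLL report and confusion matrix opt-in; the F1 return value is unaffected. Hypothetical call sites, assuming an evaluator instance built as elsewhere in this repo:

f1 = evaluator.get_metric(log_dir)              # quiet: just the F1 score
f1 = evaluator.get_metric(log_dir, is_cf=True)  # also print the confusion matrix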
@@ -106,25 +106,26 @@ def get_metric(self, log_dir):
 
         # CoNLL evaluation results
         eval_lines = [l.rstrip() for l in codecs.open(scores_path, 'r', 'utf8')]
-        for line in eval_lines:
-            print line
-
-        # Confusion matrix with accuracy for each tag
-        print ("{: >2}{: >7}{: >7}%s{: >9}" % ("{: >7}" * self.n_tags)).format(
-            "ID", "NE", "Total",
-            *([self.vocab.id_to_tag[i] for i in xrange(self.n_tags)] + ["Percent"])
-        )
-        for i in xrange(self.n_tags):
+        if is_cf:
+            for line in eval_lines:
+                print line
+
+            # Confusion matrix with accuracy for each tag
             print ("{: >2}{: >7}{: >7}%s{: >9}" % ("{: >7}" * self.n_tags)).format(
-                str(i), self.vocab.id_to_tag[i], str(self.count[i].sum()),
-                *([self.count[i][j] for j in xrange(self.n_tags)] +
-                  ["%.3f" % (self.count[i][i] * 100. / max(1, self.count[i].sum()))])
+                "ID", "NE", "Total",
+                *([self.vocab.id_to_tag[i] for i in xrange(self.n_tags)] + ["Percent"])
+            )
+            for i in xrange(self.n_tags):
+                print ("{: >2}{: >7}{: >7}%s{: >9}" % ("{: >7}" * self.n_tags)).format(
+                    str(i), self.vocab.id_to_tag[i], str(self.count[i].sum()),
+                    *([self.count[i][j] for j in xrange(self.n_tags)] +
+                      ["%.3f" % (self.count[i][i] * 100. / max(1, self.count[i].sum()))])
             )
 
-        # Global accuracy
-        print "%i/%i (%.5f%%)" % (
-            self.count.trace(), self.count.sum(), 100. * self.count.trace() / max(1, self.count.sum())
-        )
+            # Global accuracy
+            print "%i/%i (%.5f%%)" % (
+                self.count.trace(), self.count.sum(), 100. * self.count.trace() / max(1, self.count.sum())
+            )
 
         # F1 on all entities
         return float(eval_lines[1].strip().split()[-1])
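
For reference, the arithmetic now gated behind is_cf: per-tag accuracy is the diagonal cell of the count matrix over its row sum, and global accuracy is the trace over the grand total, with max(1, ...) guarding against empty rows. A standalone sketch with made-up counts, where count[i][j] is the number of tokens with gold tag i predicted as tag j:

import numpy as np

count = np.array([[50, 2],   # illustrative 2-tag confusion matrix
                  [3, 45]])

for i in range(count.shape[0]):
    # Diagonal over row sum: accuracy for gold tag i.
    print("tag %i: %.3f%%" % (i, count[i][i] * 100. / max(1, count[i].sum())))

# Trace = correctly tagged tokens, sum = all tokens.
print("%i/%i (%.5f%%)" % (count.trace(), count.sum(),
                          100. * count.trace() / max(1, count.sum())))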
