Commit 48ee4b1

refactor
1 parent 7d4b17e commit 48ee4b1

4 files changed: +18 -14 lines changed

neural_ner/config.py

Lines changed: 4 additions & 0 deletions
@@ -2,6 +2,8 @@
 import torch
 import numpy as np
 
+print('pytorch version', torch.__version__)
+
 np.random.seed(123)
 torch.manual_seed(123)
 if torch.cuda.is_available():
@@ -54,6 +56,8 @@ class Config(object):
 config.model_name = 'model.NER_SOFTMAX_CHAR'
 config.optimizer = 'sgd'
 
+config.use_pretrain_embd = True
+
 # config postprocess
 config.is_cuda = config.is_cuda and torch.cuda.is_available()
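The new use_pretrain_embd flag is consumed in vocab.py below. As a hedged usage sketch (the import path for Vocab and the False value are assumptions for illustration, not part of this commit), the flag can be flipped before the vocabulary is built, mirroring the from config import config / Vocab(config) pattern in vocab.py's __main__ block:

    from config import config
    from data_utils.vocab import Vocab   # import path assumed from the repo layout

    config.use_pretrain_embd = False     # example only: skip the GloVe lookup, e.g. for a smoke test
    vocab = Vocab(config)                # word_mapping() then falls back to an empty glove_vectors dict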

neural_ner/data_utils/vocab.py

Lines changed: 6 additions & 1 deletion
@@ -47,7 +47,11 @@ def word_mapping(self, sentences):
         id_to_word = {i + start_vocab_len: v[0] for i, v in enumerate(sorted_items)}
         '''
         #augmnet with pretrained words
-        self.glove_vectors = self.get_glove()
+        if self.config.use_pretrain_embd:
+            self.glove_vectors = self.get_glove()
+        else:
+            self.glove_vectors = {}
+
         word_freq_map.update(self.glove_vectors.keys())
 
         id_to_word = {}
@@ -102,6 +106,7 @@ def get_glove(self):
 
         return word_to_vector
 
+
 if __name__ == '__main__':
     from config import config
     vocab = Vocab(config)
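The effect of the new branch: word_freq_map is only augmented with GloVe keys when use_pretrain_embd is set; otherwise glove_vectors is an empty dict and the update is a no-op, leaving a corpus-only vocabulary. A small standalone sketch (the example counts and words are made up):

    from collections import Counter

    word_freq_map = Counter({'the': 10, 'cat': 3})   # stand-in for corpus counts
    use_pretrain_embd = False

    if use_pretrain_embd:
        glove_vectors = {'dog': [0.1, 0.2], 'the': [0.3, 0.4]}   # stand-in for get_glove()
    else:
        glove_vectors = {}

    # Same call as in word_mapping(): a no-op when the flag is off.
    word_freq_map.update(glove_vectors.keys())
    print(word_freq_map)   # Counter({'the': 10, 'cat': 3})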

neural_ner/model.py

Lines changed: 0 additions & 2 deletions
@@ -11,8 +11,6 @@
 from crf import CRF_Loss
 from model_utils import get_mask, init_lstm_wt, init_linear_wt, get_word_embd
 
-print('pytorch version', torch.__version__)
-
 logging.basicConfig(level=logging.INFO)
 
 class NER_SOFTMAX_CHAR(nn.Module):

neural_ner/train_utils.py

Lines changed: 8 additions & 11 deletions
@@ -111,23 +111,20 @@ def get_metric(self, log_dir, is_cf=False):
             print (line)
 
         # Confusion matrix with accuracy for each tag
-        print (("{: >2}{: >7}{: >7}%s{: >9}" % ("{: >7}" * self.n_tags)) % (
-            "ID", "NE", "Total",
-            *([self.vocab.id_to_tag[i] for i in range(self.n_tags)] + ["Percent"])
-        ))
+        format_str = "{: >2}{: >7}{: >7}%s{: >9}" % ("{: >7}" * self.n_tags)
+        values = [self.vocab.id_to_tag[i] for i in range(self.n_tags)]
+        print(format_str.format("ID", "NE", "Total", *values, "Percent"))
+
         for i in range(self.n_tags):
-            print (("{: >2}{: >7}{: >7}%s{: >9}" % ("{: >7}" * self.n_tags)) % (
-                str(i), self.vocab.id_to_tag[i], str(self.count[i].sum()),
-                *([self.count[i][j] for j in range(self.n_tags)] +
-                  ["%.3f" % (self.count[i][i] * 100. / max(1, self.count[i].sum()))])
-            ))
+            percent = "{:.3f}".format(self.count[i][i] * 100. / max(1, self.count[i].sum()))
+            values = [self.count[i][j] for j in range(self.n_tags)]
+            print (format_str.format(str(i), self.vocab.id_to_tag[i], str(self.count[i].sum()), *values, percent))
 
         # Global accuracy
-        print ("%i/%i (%.5f%%)" % (
+        print ("{}/{} ({:.3f} %)" .format (
             self.count.trace(), self.count.sum(), 100. * self.count.trace() / max(1, self.count.sum())
         ))
 
         # F1 on all entities
         return float(eval_lines[1].strip().split()[-1])
 
-
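The refactor in train_utils.py builds the str.format template once and reuses it for the header row and each per-tag row; the old code applied the % operator a second time to a template that no longer contained % conversion specifiers, which is what the rewrite appears to fix. A minimal sketch of the new formatting path with a made-up tag set (the tag names and n_tags value are illustrative, not from the repo):

    n_tags = 3
    id_to_tag = {0: 'O', 1: 'PER', 2: 'LOC'}   # hypothetical tag inventory

    # As in the diff: one right-aligned 7-char column per tag,
    # bracketed by the ID/NE/Total columns and a trailing Percent column.
    format_str = "{: >2}{: >7}{: >7}%s{: >9}" % ("{: >7}" * n_tags)

    values = [id_to_tag[i] for i in range(n_tags)]
    print(format_str.format("ID", "NE", "Total", *values, "Percent"))
    # -> a right-aligned header row: ID, NE, Total, O, PER, LOC, Percent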