
Commit e76c499

crf
1 parent 110bfce commit e76c499

File tree: 7 files changed (+56, -97 lines)


neural_ner/crf.py

Lines changed: 24 additions & 66 deletions
@@ -3,7 +3,7 @@
 import torch
 import torch.nn as nn
 from data_utils.sentence_utils import Constants
-
+import numpy as np
 
 def get_mask(lengths):
     seq_lens = lengths.view(-1, 1)
@@ -16,40 +16,35 @@ def get_mask(lengths):
 class CRF_Loss(nn.Module):
     def __init__(self, tagset_size):
         super(CRF_Loss, self).__init__()
-        self.start_tag_idx = tagset_size
-        self.stop_tag_idx = tagset_size + 1
+        self.start_tag = tagset_size
+        self.end_tag = tagset_size + 1
         self.num_tags = tagset_size + 2
 
-        #transition from y_i-1 to y_i, T[y_i, y_j] = y_i <= y_j
-        #+2 added for start and end indices
         self.transitions = nn.Parameter(torch.Tensor(self.num_tags, self.num_tags))
-        nn.init.uniform_(self.transitions, -0.1, 0.1)
+        nn.init.constant_(self.transitions, -np.log(self.num_tags))
 
-        #no transition to start_tag, not transition from end tag
-        self.transitions.data[self.start_tag_idx, :] = -10000
-        self.transitions.data[:, self.stop_tag_idx] = -10000
+        self.transitions.data[self.end_tag, :] = -10000
+        self.transitions.data[:, self.start_tag] = -10000
 
     def get_log_p_z(self, emissions, mask, seq_len):
         log_alpha = emissions[:, 0].clone()
-        log_alpha += self.transitions[self.start_tag_idx, : self.start_tag_idx].unsqueeze(0)
+        log_alpha += self.transitions[self.start_tag, : self.start_tag].unsqueeze(0)
 
         for idx in range(1, seq_len):
             broadcast_emissions = emissions[:, idx].unsqueeze(1)
-            broadcast_transitions = self.transitions[
-                : self.start_tag, : self.start_tag
-            ].unsqueeze(0)
+            broadcast_transitions = self.transitions[: self.start_tag, : self.start_tag].unsqueeze(0)
             broadcast_logprob = log_alpha.unsqueeze(2)
             score = broadcast_logprob + broadcast_emissions + broadcast_transitions
 
             score = torch.logsumexp(score, 1)
-            log_alpha = score * mask[:, idx].unsqueeze(1) + log_alpha.squeeze(1) * (
-                1.0 - mask[:, idx].unsqueeze(1)
-            )
+            log_alpha = score * mask[:, idx].unsqueeze(1) + log_alpha.squeeze(1) * (1.0 - mask[:, idx].unsqueeze(1))
 
-        log_alpha += self.transitions[: self.start_tag, self.end_tag].unsqueeze(0)
+        log_alpha += self.transitions[: self.start_tag, self.end_tag].unsqueeze(0)
         return torch.logsumexp(log_alpha.squeeze(1), 1)
 
     def get_log_p_Y_X(self, emissions, mask, seq_len, tags):
+        tags[tags < 0] = 0 # clone and then set
+
         llh = self.transitions[self.start_tag, tags[:, 0]].unsqueeze(1)
         llh += emissions[:, 0, :].gather(1, tags[:, 0].view(-1, 1)) * mask[:, 0].unsqueeze(1)
 
@@ -79,20 +74,14 @@ def log_likelihood(self, emissions, tags):
     def forward(self, emissions, tags):
         return self.log_likelihood(emissions, tags)
 
-    def inference(self, emissions, lengths):
-        return self.viterbi_decode(emissions, lengths)
-
     def viterbi_decode(self, emissions, lengths):
         mask = get_mask(lengths)
         seq_len = emissions.shape[1]
 
         log_prob = emissions[:, 0].clone()
         log_prob += self.transitions[self.start_tag, : self.start_tag].unsqueeze(0)
 
-
-        end_scores = log_prob + self.transitions[
-            : self.start_tag, self.end_tag
-        ].unsqueeze(0)
+        end_scores = log_prob + self.transitions[: self.start_tag, self.end_tag].unsqueeze(0)
 
         best_scores_list = []
         best_scores_list.append(end_scores.unsqueeze(1))
@@ -101,20 +90,12 @@ def viterbi_decode(self, emissions, lengths):
 
         for idx in range(1, seq_len):
             broadcast_emissions = emissions[:, idx].unsqueeze(1)
-            broadcast_transmissions = self.transitions[
-                : self.start_tag, : self.start_tag
-            ].unsqueeze(0)
+            broadcast_transmissions = self.transitions[: self.start_tag, : self.start_tag].unsqueeze(0)
             broadcast_log_prob = log_prob.unsqueeze(2)
-
             score = broadcast_emissions + broadcast_transmissions + broadcast_log_prob
-
             max_scores, max_score_indices = torch.max(score, 1)
-
             best_paths_list.append(max_score_indices.unsqueeze(1))
-
-            end_scores = max_scores + self.transitions[
-                : self.start_tag, self.end_tag
-            ].unsqueeze(0)
+            end_scores = max_scores + self.transitions[: self.start_tag, self.end_tag].unsqueeze(0)
 
             best_scores_list.append(end_scores.unsqueeze(1))
             log_prob = max_scores
@@ -128,48 +109,25 @@ def viterbi_decode(self, emissions, lengths):
         padding_tensor = torch.tensor(Constants.PAD_ID).long()
 
         labels = max_indices_from_scores[:, seq_len - 1]
-        labels = self._mask_tensor(labels, 1.0 - mask[:, seq_len - 1], padding_tensor)
-
+        labels = torch.where(1.0 - mask[:, seq_len - 1], padding_tensor, labels)
         all_labels = labels.unsqueeze(1).long()
 
         for idx in range(seq_len - 2, -1, -1):
             indices_for_lookup = all_labels[:, -1].clone()
-            indices_for_lookup = torch.where(
-                indices_for_lookup == self.ignore_index,
-                valid_index_tensor,
-                indices_for_lookup
-            )
+            indices_for_lookup = torch.where(indices_for_lookup == self.ignore_index, valid_index_tensor,
+                                             indices_for_lookup)
 
-            indices_from_prev_pos = (
-                best_paths[:, idx, :]
-                .gather(1, indices_for_lookup.view(-1, 1).long())
-                .squeeze(1)
-            )
-            indices_from_prev_pos = torch.where(
-                (1.0 - mask[:, idx + 1]),
-                padding_tensor,
-                indices_from_prev_pos
-            )
+            indices_from_prev_pos = best_paths[:, idx, :].gather(1, indices_for_lookup.view(-1, 1).long()).squeeze(1)
+            indices_from_prev_pos = torch.where((1.0 - mask[:, idx + 1]), padding_tensor, indices_from_prev_pos)
 
             indices_from_max_scores = max_indices_from_scores[:, idx]
-            indices_from_max_scores = torch.where(
-                mask[:, idx + 1],
-                padding_tensor,
-                indices_from_max_scores
-            )
+            indices_from_max_scores = torch.where(mask[:, idx + 1], padding_tensor, indices_from_max_scores)
 
-            labels = torch.where(
-                indices_from_max_scores == self.ignore_index,
-                indices_from_prev_pos,
-                indices_from_max_scores,
-            )
+            labels = torch.where(indices_from_max_scores == self.ignore_index, indices_from_prev_pos,
                                  indices_from_max_scores)
 
             # Set to ignore_index if present state is not valid.
-            labels = torch.where(
-                (1 - mask[:, idx]),
-                padding_tensor,
-                labels
-            )
+            labels = torch.where((1 - mask[:, idx]), padding_tensor, labels)
             all_labels = torch.cat((all_labels, labels.view(-1, 1).long()), 1)
 
         return best_scores, torch.flip(all_labels, [1])
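A note on the reworked class: get_log_p_z implements the CRF forward algorithm, accumulating log Z(X) with logsumexp over transition plus emission scores, and viterbi_decode runs the same recursion with max in place of logsumexp to recover the best tag path. Below is a minimal standalone sketch of the forward recursion under the same -log(num_tags) transition init (toy shapes, no padding mask; not the repository's API).

import math
import torch

batch, seq_len, num_tags = 2, 5, 4                      # toy sizes for illustration
emissions = torch.randn(batch, seq_len, num_tags)       # per-token tag scores from a featurizer
transitions = torch.full((num_tags, num_tags), -math.log(num_tags))

# log_alpha[b, j] = log-sum of scores over all tag prefixes ending in tag j
log_alpha = emissions[:, 0].clone()
for t in range(1, seq_len):
    # score[b, i, j] = log_alpha[b, i] + transitions[i, j] + emissions[b, t, j]
    score = log_alpha.unsqueeze(2) + transitions.unsqueeze(0) + emissions[:, t].unsqueeze(1)
    log_alpha = torch.logsumexp(score, dim=1)

log_Z = torch.logsumexp(log_alpha, dim=1)               # log partition function, shape (batch,)
print(log_Z)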

neural_ner/data_utils/batcher.py

Lines changed: 4 additions & 4 deletions
@@ -2,9 +2,9 @@
 
 import os
 import numpy as np
-from utils import load_sentences, prepare_dataset
-from vocab import Vocab
-from sentence_utils import pad_items, pad_chars
+from .utils import load_sentences, prepare_dataset
+from .vocab import Vocab
+from .sentence_utils import pad_items, pad_chars
 import torch
 
 class DatasetConll2003(object):
@@ -39,7 +39,7 @@ def get_data_file(data_type, config):
     def __iter__(self):
         return self
 
-    def next(self):
+    def __next__(self):
         self.iterations += 1
 
         if self.is_train and self.i >= len(self.data):
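The next to __next__ rename is what the Python 3 iterator protocol requires: the built-in next() and for loops call __next__, not next. A tiny illustration with a made-up counter class (not the repository's batcher):

class Counter(object):
    def __init__(self, limit):
        self.i, self.limit = 0, limit

    def __iter__(self):
        return self

    def __next__(self):          # was spelled next() under Python 2
        if self.i >= self.limit:
            raise StopIteration
        self.i += 1
        return self.i

print(list(Counter(3)))          # [1, 2, 3]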

neural_ner/data_utils/sentence_utils.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def pad_chars(items, max_word_len):
     for item in items:
         padding = [pad_id] * (max_length - len(item))
         padded_items.append(item + padding)
-    for i in xrange(len(items), max_word_len):
+    for i in range(len(items), max_word_len):
         padding = [pad_id] * max_length
         padded_items.append(padding)
         padded_items_len.append(1)

neural_ner/data_utils/utils.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 import codecs
-from sentence_utils import prepare_sentence
-from tag_scheme_utils import update_tag_scheme
+from .sentence_utils import prepare_sentence
+from .tag_scheme_utils import update_tag_scheme
 
 def load_sentences(path, tag_scheme):
     sentences = []
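The import rewrites here and in the other data_utils modules reflect Python 3's requirement for explicit relative imports: a bare from sentence_utils import ... is treated as an absolute import and fails inside the package. An illustration of the two equivalent forms (the absolute path assumes the package root is importable; adjust to the actual layout):

# inside neural_ner/data_utils/utils.py, running under Python 3
from .sentence_utils import prepare_sentence               # explicit relative import
# equivalently, as an absolute import:
# from neural_ner.data_utils.sentence_utils import prepare_sentence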

neural_ner/data_utils/vocab.py

Lines changed: 8 additions & 8 deletions
@@ -5,8 +5,8 @@
 import re
 import numpy as np
 
-from sentence_utils import get_char_word_seq, Constants
-from utils import load_sentences
+from .sentence_utils import get_char_word_seq, Constants
+from .utils import load_sentences
 from collections import Counter
 
 def create_freq_map(item_list):
@@ -39,9 +39,9 @@ def word_mapping(self, sentences):
         start_vocab_len = len(Constants._START_VOCAB)
         word_freq_map = create_freq_map(words)
         self.orig_word_freq_map = word_freq_map.copy()
-        print "Found %i unique words (%i in total)" % (
+        print ("Found {} unique words ({} in total)".format(
             len(word_freq_map), sum(len(x) for x in words)
-        )
+        ))
         '''
         self.config.vocab_size = min(self.config.vocab_size, len(word_freq_map))
         sorted_items = word_freq_map.most_common(self.config.vocab_size)
@@ -67,7 +67,7 @@ def word_mapping(self, sentences):
         for i, v in enumerate(Constants._START_VOCAB):
             id_to_char[i] = v
 
-        print "Found %i unique characters" % len(char_freq_map)
+        print("Found {} unique characters".format(len(char_freq_map)))
 
         self.char_to_id = {v: k for k, v in id_to_char.items()}
         self.id_to_char = id_to_char
@@ -76,13 +76,13 @@ def tag_mapping(self, sentences):
         tags = [[word[-1] for word in s] for s in sentences]
         freq_map = create_freq_map(tags)
         id_to_tag = {i: v for i, v in enumerate(freq_map)}
-        print "Found %i unique named entity tags" % len(freq_map)
+        print ("Found {} unique named entity tags" .format(len(freq_map)))
 
         self.tag_to_id = {v: k for k, v in id_to_tag.items()}
         self.id_to_tag = id_to_tag
 
     def get_glove(self):
-        print "Loading GLoVE vectors from file: %s" % self.config.glove_path
+        print ("Loading GLoVE vectors from file: {}".format(self.config.glove_path))
         vocab_size = int(4e5)
         word_to_vector = {}
 
@@ -103,6 +103,6 @@ def get_glove(self):
 if __name__ == '__main__':
     from config import config
     vocab = Vocab(config)
-    print len(vocab.word_to_id)
+    print (len(vocab.word_to_id))
     #emb_matrix = vocab.get_word_embd()
     #print len(emb_matrix)

neural_ner/model.py

Lines changed: 7 additions & 6 deletions
@@ -47,7 +47,7 @@ def get_word_embd(vocab, config):
     word_emb_matrix = np.random.uniform(low=-1.0, high=1.0,
                                         size=(len(vocab.word_to_id), config.word_emdb_dim))
     pretrained_init = 0
-    for w, wid in vocab.word_to_id.iteritems():
+    for w, wid in vocab.word_to_id.items():
         if w in word_to_vector:
             word_emb_matrix[wid, :] = word_to_vector[w]
             pretrained_init += 1
@@ -65,7 +65,7 @@ def test_one_batch(batch, model):
     return logits, pred
 
 def get_model(vocab, config, model_file_path, is_eval=False):
-    model = NER_SOFTMAX_CHAR(vocab, config)
+    model = NER_SOFTMAX_CHAR_CRF(vocab, config)
 
     if is_eval:
         model = model.eval()
@@ -188,19 +188,20 @@ def __init__(self, vocab, config):
 
         self.featurizer = NER_SOFTMAX_CHAR(vocab, config)
         self.crf = CRF_Loss(len(vocab.id_to_tag))
+        self.config = config
 
     def forward(self, batch):
         emissions = self.featurizer(batch)
         return emissions
 
-    def crf_loss(self, emissions, target, s_lens):
-        loss = -1 * self.crf(emissions, target)
-        loss = loss.squeeze(1).sum(dim=1) / s_lens.float()
+    def crf_loss(self, emissions, target):
+        a = self.crf(emissions, target)
+        loss = -1 * a
         loss = loss.mean()
         return loss
 
     def get_loss(self, logits, y, s_lens):
-        loss = self.crf_loss(logits, y, s_lens)
+        loss = self.crf_loss(logits, y)
         if self.config.is_l2_loss:
             loss += self.get_l2_loss()
         return loss
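With this change the training objective becomes the batch mean of the negated log-likelihood returned by CRF_Loss.forward, with no per-length normalisation (the old division by s_lens is gone). A small sketch of the resulting reduction, assuming one log P(y|x) value per sentence (made-up numbers, not the repository's exact API):

import torch

def crf_nll(log_likelihood):                 # log_likelihood: shape (batch,)
    loss = -1 * log_likelihood               # negate the per-sentence log-likelihood
    return loss.mean()                       # average over the batch only

print(crf_nll(torch.tensor([-3.2, -1.7, -4.9])))   # tensor(3.2667)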

neural_ner/train_utils.py

Lines changed: 10 additions & 10 deletions
@@ -108,24 +108,24 @@ def get_metric(self, log_dir, is_cf=False):
         eval_lines = [l.rstrip() for l in codecs.open(scores_path, 'r', 'utf8')]
         if is_cf:
             for line in eval_lines:
-                print line
+                print (line)
 
             # Confusion matrix with accuracy for each tag
-            print ("{: >2}{: >7}{: >7}%s{: >9}" % ("{: >7}" * self.n_tags)).format(
+            print (("{: >2}{: >7}{: >7}%s{: >9}" % ("{: >7}" * self.n_tags)) % (
                 "ID", "NE", "Total",
-                *([self.vocab.id_to_tag[i] for i in xrange(self.n_tags)] + ["Percent"])
-            )
-            for i in xrange(self.n_tags):
-                print ("{: >2}{: >7}{: >7}%s{: >9}" % ("{: >7}" * self.n_tags)).format(
+                *([self.vocab.id_to_tag[i] for i in range(self.n_tags)] + ["Percent"])
+            ))
+            for i in range(self.n_tags):
+                print (("{: >2}{: >7}{: >7}%s{: >9}" % ("{: >7}" * self.n_tags)) % (
                     str(i), self.vocab.id_to_tag[i], str(self.count[i].sum()),
-                    *([self.count[i][j] for j in xrange(self.n_tags)] +
+                    *([self.count[i][j] for j in range(self.n_tags)] +
                       ["%.3f" % (self.count[i][i] * 100. / max(1, self.count[i].sum()))])
-                )
+                ))
 
             # Global accuracy
-            print "%i/%i (%.5f%%)" % (
+            print ("%i/%i (%.5f%%)" % (
                 self.count.trace(), self.count.sum(), 100. * self.count.trace() / max(1, self.count.sum())
-            )
+            ))
 
         # F1 on all entities
         return float(eval_lines[1].strip().split()[-1])
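The confusion-matrix rows are built by splicing one fixed-width {: >7} field per tag into a header template and then filling it in. A short illustration of that pattern with toy tag names; note the spliced template still contains {}-style fields, so it is filled with str.format (as in the pre-change code) rather than a second % substitution:

n_tags = 3
tags = ["PER", "LOC", "ORG"]

template = "{: >2}{: >7}{: >7}%s{: >9}" % ("{: >7}" * n_tags)   # one {: >7} slot per tag
print(template.format("ID", "NE", "Total", *(tags + ["Percent"])))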
