
Commit 110bfce

crf implementation
1 parent d8e3bde

3 files changed: 45 additions, 142 deletions

neural_ner/config.py

Lines changed: 1 addition & 3 deletions

@@ -42,6 +42,4 @@ class Config(object):
 
 config.is_cuda = False
 
-config.is_l2_loss = False
-
-config.optimizer='sdg' #'adam'
+config.is_l2_loss = False

neural_ner/model.py

Lines changed: 40 additions & 131 deletions

@@ -8,6 +8,8 @@
 import logging
 import numpy as np
 
+from crf import CRF_Loss
+
 print('pytorch version', torch.__version__)
 
 logging.basicConfig(level=logging.INFO)
@@ -58,7 +60,8 @@ def get_word_embd(vocab, config):
 def test_one_batch(batch, model):
     model.eval()
     logits = model(batch)
-    _, pred = model.get_argmax(logits)
+    lengths = batch['words_lens']
+    pred = model.predict(logits, lengths)
     return logits, pred
 
 def get_model(vocab, config, model_file_path, is_eval=False):
@@ -80,7 +83,6 @@ def __init__(self, vocab, config):
         super(NER_SOFTMAX_CHAR, self).__init__()
         word_emb_matrix = get_word_embd(vocab, config)
         embd_vector = torch.from_numpy(word_emb_matrix).float()
-        tagset_size = len(vocab.id_to_tag)
 
         self.word_embeds = nn.Embedding.from_pretrained(embd_vector, freeze=False)
         self.char_embeds = nn.Embedding(len(vocab.char_to_id), config.char_embd_dim, padding_idx=Constants.PAD_ID)
@@ -97,7 +99,8 @@ def __init__(self, vocab, config):
         self.dropout = nn.Dropout(config.dropout_rate)
         self.hidden_layer = nn.Linear(config.word_lstm_dim * 2, config.word_lstm_dim)
         self.tanh_layer = torch.nn.Tanh()
-        self.hidden2tag = nn.Linear(config.word_lstm_dim, tagset_size)
+
+        self.hidden2tag = nn.Linear(config.word_lstm_dim, len(vocab.id_to_tag))
 
         self.config = config
 
@@ -158,144 +161,50 @@ def forward(self, batch):
 
         return logits
 
-    def neg_log_likelihood(self, logits, y, s_len):
+    def neg_log_likelihood(self, logits, y, s_lens):
         log_smx = F.log_softmax(logits, dim=2)
         loss = F.nll_loss(log_smx.transpose(1, 2), y, ignore_index=Constants.TAG_PAD_ID, reduce=False)
-        loss = loss.squeeze(1).sum(dim=1) / s_len.float()
+        loss = loss.squeeze(1).sum(dim=1) / s_lens.float()
         loss = loss.mean()
-        if self.config.is_l2_loss:
-            l2_reg = sum(p.norm(2) for p in self.parameters() if p.requires_grad)
-            loss += self.config.reg_lambda * l2_reg
         return loss
 
-    def get_argmax(self, logits):
-        max_value, max_idx = torch.max(logits, dim=2)
-        return max_value, max_idx
-
-class NER_CRF(nn.Module):
-    def __init__(self, embd_vector, hidden_dim, tagset_size,
-                 reg_lambda):
-        super(NER_CRF, self).__init__()
-
-        self.start_tag_idx = tagset_size
-        self.stop_tag_idx = tagset_size + 1
-        self.all_tagset_size = tagset_size + 2
-
-        self.word_embeds = nn.Embedding.from_pretrained(embd_vector)
-        embedding_dim = self.word_embeds.embedding_dim
-
-        self.lstm = nn.LSTM(embedding_dim, hidden_dim//2,
-                            num_layers=1, bidirectional=True, batch_first=True)
-        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
-
-        #transition from y_i-1 to y_i, T[y_i, y_j] = y_i <= y_j
-        #+2 added for start and end indices
-        self.transitions = nn.Parameter(torch.randn(self.all_tagset_size, self.all_tagset_size))
-        nn.init.uniform_(self.transitions, -0.1, 0.1)
-
-        #no transition to start_tag, not transition from end tag
-        self.transitions.data[self.start_tag_idx, :] = -10000
-        self.transitions.data[:, self.stop_tag_idx] = -10000
-
-        self.hidden = self.init_hidden()
-
-    def init_hidden(self):
-        return (torch.randn(2, 1, self.hidden_dim // 2),
-                torch.randn(2, 1, self.hidden_dim // 2))
+    def get_loss(self, logits, y, s_lens):
+        loss = self.neg_log_likelihood(logits, y, s_lens)
+        if self.config.is_l2_loss:
+            loss += self.get_l2_loss()
+        return loss
 
-    def get_emission_prob(self, sentence, lengths):
-        embedded = self.word_embeds(sentence)
-        lengths = lengths.view(-1).tolist()
-        packed = pack_padded_sequence(embedded, lengths, batch_first=True)
+    def get_l2_loss(self):
+        l2_reg = sum(p.norm(2) for p in self.parameters() if p.requires_grad)
+        return self.config.reg_lambda * l2_reg
 
-        self.hidden = self.init_hidden()
-        output, self.hidden = self.lstm(packed, self.hidden)
+    def predict(self, logit, lengths):
+        max_value, pred = torch.max(logit, dim=2)
+        return pred
 
-        lstm_feats, _ = pad_packed_sequence(output, batch_first=True) # h dim = B x t_k x n
-        lstm_feats = lstm_feats.contiguous()
+class NER_SOFTMAX_CHAR_CRF(nn.Module):
+    def __init__(self, vocab, config):
+        super(NER_SOFTMAX_CHAR_CRF, self).__init__()
 
-        b, t_k, d = list(lstm_feats.size())
+        self.featurizer = NER_SOFTMAX_CHAR(vocab, config)
+        self.crf = CRF_Loss(len(vocab.id_to_tag))
 
-        logits = self.hidden2tag(lstm_feats.view(-1, d))
-        logits = logits.view(b, t_k, -1)
+    def forward(self, batch):
+        emissions = self.featurizer(batch)
+        return emissions
 
-        return logits
+    def crf_loss(self, emissions, target, s_lens):
+        loss = -1 * self.crf(emissions, target)
+        loss = loss.squeeze(1).sum(dim=1) / s_lens.float()
+        loss = loss.mean()
+        return loss
 
-    def get_argmax(self, logits):
-        max_value, max_idx = torch.max(logits, dim=2)
-        return max_value, max_idx
-
-    def log_sum_exp(self, vec):
-        max_score, _ = self.get_argmax(vec)
-        max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
-        return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))
-
-    def get_log_z(self, emission_prob):
-        init_alphas = torch.full((1, self.all_tagset_size), -10000.)
-        init_alphas[0][self.start_tag_idx] = 0.
-        forward_var = init_alphas
-        for e_i in emission_prob:
-            alphas_t = []
-            for next_tag in range(self.all_tagset_size):
-                emit_score = e_i[next_tag].view(
-                    1, -1).expand(1, self.all_tagset_size)
-                trans_score = self.transitions[next_tag].view(1, -1)
-
-                next_tag_var = forward_var + trans_score + emit_score
-                alphas_t.append(self.log_sum_exp(next_tag_var).view(1))
-            forward_var = torch.cat(alphas_t).view(1, -1)
-        terminal_var = forward_var + self.transitions[self.stop_tag_idx]
-        alpha = self.log_sum_exp(terminal_var)
-        return alpha
-
-    def get_log_p_y_x(self, feats, lengths, tags):
-        score = torch.zeros(1)
-        tags = torch.cat([torch.tensor([self.start_tag_idx], dtype=torch.long), tags])
-        for i, feat in enumerate(feats):
-            score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
-        score = score + self.transitions[self.stop_tag_idx, tags[-1]]
-        return score
-
-    def neg_log_likelihood(self, sentence, lengths, tags):
-        feats = self.get_emission_prob(sentence, lengths)
-        log_z = self.get_log_z(feats, lengths)
-        log_p_y_x = self.get_log_p_y_x(feats, tags, lengths)
-        return -(log_p_y_x - log_z)
-
-    def forward(self, sentence, lengths):
-        feats = self.get_emission_prob(sentence, lengths)
-        score, tag_seq = self.viterbi_decode(feats, lengths)
-        return score, tag_seq
-
-    def viterbi_decode(self, feats, lengths):
-        backpointers = []
-
-        init_vvars = torch.full((1, self.all_tagset_size), -10000.)
-        init_vvars[0][self.start_tag_idx] = 0
-
-        forward_var = init_vvars
-        for feat in feats:
-            bptrs_t = []
-            viterbivars_t = []
-
-            for next_tag in range(self.all_tagset_size):
-                next_tag_var = forward_var + self.transitions[next_tag]
-                _, best_tag_id = self.get_argmax(next_tag_var)
-                bptrs_t.append(best_tag_id)
-                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
-            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
-            backpointers.append(bptrs_t)
-
-        terminal_var = forward_var + self.transitions[self.stop_tag_idx]
-        _, best_tag_id = self.get_argmax(terminal_var)
-        path_score = terminal_var[0][best_tag_id]
-
-        best_path = [best_tag_id]
-        for bptrs_t in reversed(backpointers):
-            best_tag_id = bptrs_t[best_tag_id]
-            best_path.append(best_tag_id)
-        start = best_path.pop()
-        assert start == self.start_tag_idx
-        best_path.reverse()
-        return path_score, best_path
+    def get_loss(self, logits, y, s_lens):
+        loss = self.crf_loss(logits, y, s_lens)
+        if self.config.is_l2_loss:
+            loss += self.get_l2_loss()
+        return loss
 
+    def predict(self, emissions, lengths):
+        best_scores, pred = self.crf.viterbi_decode(emissions, lengths)
+        return pred
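
Note: crf.py itself is not among the three files in this diff, so the CRF_Loss interface the new wrapper relies on (a callable that scores a tag sequence plus a viterbi_decode method) is only visible indirectly. The following is a minimal sketch of such a linear-chain CRF for orientation; the tensor shapes, the per-sequence return value of forward, and the lack of length masking are assumptions, not the repository's actual implementation.

import torch
import torch.nn as nn

class CRF_Loss(nn.Module):
    """Sketch of a linear-chain CRF over emissions of shape (batch, seq_len, num_tags)."""

    def __init__(self, tagset_size):
        super(CRF_Loss, self).__init__()
        self.num_tags = tagset_size
        # transitions[i, j] = score of moving from tag j to tag i (assumed convention)
        self.transitions = nn.Parameter(0.1 * torch.randn(tagset_size, tagset_size))

    def _log_partition(self, emissions):
        # Forward algorithm; padding/masking by length is omitted in this sketch.
        batch, seq_len, _ = emissions.size()
        alpha = emissions[:, 0]                                   # (batch, num_tags)
        for t in range(1, seq_len):
            # scores[b, i, j] = alpha[b, j] + transitions[i, j] + emissions[b, t, i]
            scores = (alpha.unsqueeze(1)
                      + self.transitions.unsqueeze(0)
                      + emissions[:, t].unsqueeze(2))
            alpha = torch.logsumexp(scores, dim=2)
        return torch.logsumexp(alpha, dim=1)                      # (batch,)

    def _gold_score(self, emissions, tags):
        # Score of the reference tag sequence.
        batch, seq_len, _ = emissions.size()
        score = emissions[:, 0].gather(1, tags[:, 0:1]).squeeze(1)
        for t in range(1, seq_len):
            score = (score
                     + self.transitions[tags[:, t], tags[:, t - 1]]
                     + emissions[:, t].gather(1, tags[:, t:t + 1]).squeeze(1))
        return score                                              # (batch,)

    def forward(self, emissions, tags):
        # Per-sequence log-likelihood; the wrapper in model.py negates it to form the loss.
        return self._gold_score(emissions, tags) - self._log_partition(emissions)

    def viterbi_decode(self, emissions, lengths):
        # Best score and best tag path per sentence; `lengths` is accepted for
        # interface parity with model.py but padding is not masked here.
        batch, seq_len, _ = emissions.size()
        score = emissions[:, 0]
        history = []
        for t in range(1, seq_len):
            candidates = score.unsqueeze(1) + self.transitions.unsqueeze(0)
            best_prev_score, best_prev = candidates.max(dim=2)
            score = best_prev_score + emissions[:, t]
            history.append(best_prev)
        best_score, best_last = score.max(dim=1)
        paths = [best_last]
        for best_prev in reversed(history):
            best_last = best_prev.gather(1, best_last.unsqueeze(1)).squeeze(1)
            paths.append(best_last)
        paths.reverse()
        return best_score, torch.stack(paths, dim=1)              # (batch,), (batch, seq_len)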

neural_ner/process_training.py

Lines changed: 4 additions & 8 deletions

@@ -9,8 +9,7 @@
 
 from data_utils.batcher import DatasetConll2003
 from data_utils.vocab import Vocab
-from model import get_model
-from model import test_one_batch
+from model import get_model, test_one_batch
 from train_utils import setup_train_dir, save_model, write_summary, \
     get_param_norm, get_grad_norm, Evaluter
 
@@ -29,7 +28,7 @@ def train_one_batch(self, batch, optimizer, params):
         s_lengths = batch['words_lens']
         y = batch['tags']
         logits = self.model(batch)
-        loss = self.model.neg_log_likelihood(logits, y, s_lengths)
+        loss = self.model.get_loss(logits, y, s_lengths)
 
         loss.backward()
 
@@ -45,10 +44,7 @@ def train(self):
         train_dir, summary_writer = setup_train_dir(self.config)
 
         params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
-        if self.config.optimizer == 'adam':
-            optimizer = Adam(params, lr=0.001, amsgrad=True)
-        elif self.config.optimizer == 'sdg':
-            optimizer = SGD(params, lr=0.01)
+        optimizer = Adam(params, lr=0.001, amsgrad=True)
 
         num_params = sum(p.numel() for p in params)
         logging.info("Number of params: %d" % num_params)
@@ -131,7 +127,7 @@ def evaluate(self, data_type, num_samples=None):
             y = batch['tags']
 
            logits, pred = test_one_batch(batch, self.model)
-            loss = self.model.neg_log_likelihood(logits, y, s_lengths)
+            loss = self.model.get_loss(logits, y, s_lengths)
 
             curr_batch_size = len(batch['raw_sentence'])
             loss_per_batch += loss * curr_batch_size
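
For orientation only, not part of the commit: after this change both NER_SOFTMAX_CHAR and NER_SOFTMAX_CHAR_CRF expose the same get_loss/predict surface, so train_one_batch and evaluate above stay model-agnostic. A hypothetical single training step against either variant could look like the sketch below; the batch keys are the ones used in the diff, everything else is illustrative.

def training_step(model, batch, optimizer):
    # One gradient step that works for either NER model variant.
    model.train()
    optimizer.zero_grad()
    logits = model(batch)   # emission/tag scores, shape (batch, seq_len, num_tags)
    loss = model.get_loss(logits, batch['tags'], batch['words_lens'])
    loss.backward()
    optimizer.step()
    return loss.item()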
