
Commit 91e8420

top k viterbi
1 parent d55b6f3 commit 91e8420

4 files changed: +88 -7 lines changed

neural_ner/config.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ class Config(object):
 
 config.verbose = False
 config.is_caps=True
-config.is_structural_perceptron_loss=False
+config.is_structural_perceptron_loss=True
 
 config.input_format='conll2003' #crfsuite
 
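Flipping this flag presumably switches training from the CRF log-likelihood objective to the structural perceptron loss defined in crf.py. A minimal sketch of the branch such a flag typically drives; the `get_loss` helper name and signature below are illustrative assumptions, not this repo's actual wiring:

```python
# Illustrative only: how a flag like config.is_structural_perceptron_loss might
# pick between the two objectives defined in neural_ner/crf.py. The helper name
# and signature are assumptions, not the repo's real get_loss.
def get_loss(crf, emissions, tags, config):
    if config.is_structural_perceptron_loss:
        # margin-based objective that decodes with Viterbi and penalizes mistakes
        return crf.structural_perceptron_loss(emissions, tags)
    # default: CRF negative log-likelihood
    return -crf.log_likelihood(emissions, tags)
```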

neural_ner/crf.py

Lines changed: 51 additions & 3 deletions
@@ -3,7 +3,12 @@
 import torch
 import torch.nn as nn
 from data_utils.constant import Constants
+from model_utils import get_mask
+
 import numpy as np
+import logging
+
+logging.basicConfig(level=logging.INFO)
 
 class CRF_Loss(nn.Module):
     def __init__(self, tagset_size, config):
@@ -66,7 +71,8 @@ def log_likelihood(self, emissions, tags):
         log_p_y_x = self.get_log_p_Y_X(emissions, mask, tags)
         return log_p_y_x - log_z
 
-    def viterbi_decode(self, emissions, mask):
+    def viterbi_decode_batch(self, emissions, lengths):
+        mask = get_mask(lengths, self.config)
         seq_len = emissions.shape[1]
 
         log_prob = emissions[:, 0].clone()
@@ -133,10 +139,52 @@ def viterbi_decode(self, emissions, mask):
 
         return sentence_score, torch.flip(all_labels, [1])
 
+    def viterbi_decode(self, emissions, lengths):
+        bsz = emissions.shape[0]
+        all_path_indices = []
+        all_path_scores = []
+
+        for i in range(bsz):
+            viterbi_path, viterbi_score = self.viterbi_decode_single(lengths[i], emissions[i])
+            all_path_indices.append(viterbi_path)
+            all_path_scores.append(viterbi_score)
+
+        return all_path_indices, all_path_scores
+
+    def viterbi_decode_single(self, sequence_length, emission, top_k=1):
+        num_tags = emission.shape[0]
+        path_scores, path_indices= [], []
+        path_scores.append(emission[0, :].unsqueeze(0))
+        for timestep in range(1, sequence_length):
+            summed_potentials = path_scores[timestep - 1].unsqueeze(-1) + self.transitions
+            scores, paths = torch.topk(summed_potentials, k=top_k, dim=0)
+            path_scores.append(emission[timestep, :] + scores.squeeze())
+            path_indices.append(paths.squeeze())
+
+        viterbi_score, best_paths = torch.topk(path_scores[-1], k=top_k, dim=0)
+        viterbi_paths = []
+        for i in range(top_k):
+            viterbi_path = [best_paths[i]]
+            for backward_timestep in reversed(path_indices):
+                viterbi_path.append(int(backward_timestep.view(-1)[viterbi_path[-1]]))
+            viterbi_path.reverse()
+
+            viterbi_path = [j % num_tags for j in viterbi_path]
+            viterbi_paths.append(viterbi_path)
+        '''
+        viterbi_path = [int(best_path.numpy())]
+        for backward_timestep in reversed(path_indices):
+            viterbi_path.append(int(backward_timestep[viterbi_path[-1]]))
+        # Reverse the backward path.
+        viterbi_path.reverse()
+        '''
+        return viterbi_paths, viterbi_score
+
+
     def structural_perceptron_loss(self, emissions, tags):
         mask = tags.ne(Constants.TAG_PAD_ID).float()
-
-        best_scores, pred = self.viterbi_decode(emissions, mask)
+        sequence_lnegths = mask.sum(dim=1)
+        best_scores, pred = self.viterbi_decode(emissions, sequence_lnegths)
         log_p_y_x = self.get_log_p_Y_X(emissions, mask, tags)
 
         delta = torch.sum(tags.ne(pred).float()*mask, 1)
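The new viterbi_decode_single keeps, for every tag at each timestep, the k best partial-path scores; backpointers are stored as flat indices over (beam slot, previous tag), which is why backtracking reads them through .view(-1) and recovers tag ids with % num_tags. Below is a self-contained sketch of that k-best recursion, not the commit's exact code: the name topk_viterbi is mine, it assumes transitions[i, j] scores a move from tag i to tag j, and it makes the flatten-before-topk step explicit.

```python
import torch

def topk_viterbi(emissions, transitions, top_k=2):
    """Return the top_k highest-scoring tag sequences for one sentence.

    emissions:   (seq_len, num_tags) per-token tag scores
    transitions: (num_tags, num_tags), transitions[i, j] = score of tag i -> tag j
    """
    seq_len, num_tags = emissions.shape
    # path_scores[t]: (k_t, num_tags) best partial-path scores ending in each tag
    # at step t; path_indices[t-1]: flat backpointers into the previous step's
    # (k_{t-1} * num_tags) candidates.
    path_scores = [emissions[0].unsqueeze(0)]
    path_indices = []

    for t in range(1, seq_len):
        # (k, num_tags, 1) + (num_tags, num_tags) -> (k, from_tag, to_tag)
        summed = path_scores[t - 1].unsqueeze(-1) + transitions
        # flatten (beam, from_tag) so topk ranks every way of reaching each to_tag
        summed = summed.view(-1, num_tags)
        k = min(top_k, summed.shape[0])
        scores, backpointers = summed.topk(k, dim=0)
        path_scores.append(emissions[t] + scores)
        path_indices.append(backpointers)

    # rank the k * num_tags final candidates, then follow backpointers
    final = path_scores[-1].view(-1)
    k = min(top_k, final.shape[0])
    best_scores, best_flat = final.topk(k, dim=0)

    paths = []
    for i in range(k):
        flat = [int(best_flat[i])]
        for bp in reversed(path_indices):
            flat.append(int(bp.view(-1)[flat[-1]]))
        flat.reverse()
        # a flat index encodes (beam slot, tag); the tag id is index % num_tags
        paths.append([idx % num_tags for idx in flat])
    return paths, best_scores


if __name__ == '__main__':
    torch.manual_seed(0)
    emissions = torch.randn(6, 5)     # 6 tokens, 5 tags
    transitions = torch.randn(5, 5)
    paths, scores = topk_viterbi(emissions, transitions, top_k=3)
    print(paths)    # 3 tag sequences, best first
    print(scores)
```

The structural_perceptron_loss change above feeds mask.sum(dim=1) into this per-sentence decoder as lengths, so the margin term delta can compare the Viterbi prediction against the gold tags token by token.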

neural_ner/model.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 import logging
 
 from crf import CRF_Loss
-from model_utils import get_mask, init_lstm_wt, init_linear_wt, get_word_embd
+from model_utils import init_lstm_wt, init_linear_wt, get_word_embd
 
 logging.basicConfig(level=logging.INFO)
 
@@ -158,6 +158,6 @@ def get_loss(self, logits, y, s_lens):
         return loss
 
     def predict(self, emissions, lengths):
-        mask = get_mask(lengths, self.config)
+
         best_scores, pred = self.crf.viterbi_decode(emissions, mask)
         return pred
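With get_mask dropped from predict, decoding is driven by per-sentence lengths rather than a precomputed mask. A small hedged illustration of where such lengths come from, mirroring the mask.sum(dim=1) line in crf.py; TAG_PAD_ID = 0 and the crf handle are placeholders, not the repo's actual objects:

```python
import torch

TAG_PAD_ID = 0  # placeholder pad id; the repo takes it from Constants.TAG_PAD_ID
tags = torch.tensor([[3, 1, 2, TAG_PAD_ID],
                     [2, 2, TAG_PAD_ID, TAG_PAD_ID]])
mask = tags.ne(TAG_PAD_ID).float()
lengths = mask.sum(dim=1)                 # tensor([3., 2.]): tokens per sentence
# paths, scores = crf.viterbi_decode(emissions, lengths)   # per-sentence decode
```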

neural_ner/old/sanity_checks.py

Lines changed: 34 additions & 1 deletion
@@ -1,5 +1,6 @@
 import torch
 import torch.nn.functional as F
+import numpy as np
 
 def check_ignore_pads_in_loss():
     logits = torch.randn(2, 3, 4, requires_grad=True)
@@ -20,6 +21,38 @@ def check_ignore_pads_in_loss():
     print(loss)
 
 if __name__ == '__main__':
-    check_ignore_pads_in_loss()
+    #check_ignore_pads_in_loss()
+    '''
+    a_ = np.random.randint(1, 100, 5)
+    b_ = np.random.randint(1, 100, 5)
+
+    ab = sorted([aa*bb for bb in b_ for aa in a_], reverse=True)
+
+    print (ab[:5], a_, b_)
+
+    import heapq
+    a = sorted(a_, reverse=True)
+    b = sorted(b_, reverse=True)
+
+    pQueue = []
+    heapq.heappush(pQueue, (-a[0]*b[0], 0, 0))
+    topk = []
+    for _ in range(5):
+        v, ia, ib = heapq.heappop(pQueue)
+        topk.append(-v)
+        if ia + 1 < len(a):
+            heapq.heappush(pQueue, (-a[ia + 1]*b[ib], ia+1, ib))
+        if ib + 1 < len(b):
+            heapq.heappush(pQueue, (-a[ia] * b[ib+1], ia, ib+1))
+
+    print (topk)
+    '''
+    tx = np.random.randint(1, 100, size=(5,5))
+    e = np.random.randint(1, 100, size=5)
+
+    print (tx)
+    print(e)
+    print(np.expand_dims(e, -1) + tx)
+
 
 
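The commented-out block sketches a lazy, heap-based way to pull the k largest pairwise products of two score arrays, the same best-first idea behind k-best decoding, while the live lines at the end exercise the np.expand_dims(e, -1) + tx broadcast that mirrors path_scores[timestep - 1].unsqueeze(-1) + self.transitions in crf.py. Below is a hedged, runnable variant of the heap check, verified against the brute-force sort; the function name topk_products and the visited set (which keeps an index pair from being pushed twice via different parents) are mine:

```python
import heapq
import numpy as np

def topk_products(a, b, k):
    """Lazily enumerate the k largest products a[i] * b[j] of two arrays.

    Sort both arrays descending, treat (i, j) index pairs as a grid whose values
    only decrease to the right and downwards, pop cells from a max-heap, and push
    each popped cell's right/down neighbours.
    """
    a = sorted(a, reverse=True)
    b = sorted(b, reverse=True)
    heap = [(-a[0] * b[0], 0, 0)]
    seen = {(0, 0)}
    out = []
    while heap and len(out) < k:
        v, i, j = heapq.heappop(heap)
        out.append(int(-v))
        for ni, nj in ((i + 1, j), (i, j + 1)):
            if ni < len(a) and nj < len(b) and (ni, nj) not in seen:
                seen.add((ni, nj))
                heapq.heappush(heap, (-a[ni] * b[nj], ni, nj))
    return out

if __name__ == '__main__':
    rng = np.random.default_rng(0)
    a_ = rng.integers(1, 100, 5)
    b_ = rng.integers(1, 100, 5)
    brute = sorted((int(x) * int(y) for x in a_ for y in b_), reverse=True)[:5]
    assert topk_products(a_, b_, 5) == brute
    print(brute)
```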
