Skip to content

Commit 9f6c4d0

Browse files
committed
crf implementation
1 parent f2e1ba5 commit 9f6c4d0

File tree

1 file changed

+176
-0
lines changed

1 file changed

+176
-0
lines changed

neural_ner/crf.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
from __future__ import unicode_literals, print_function, division
2+
3+
import torch
4+
import torch.nn as nn
5+
from data_utils.sentence_utils import Constants
6+
7+
8+
def get_mask(lengths):
9+
seq_lens = lengths.view(-1, 1)
10+
max_len = torch.max(seq_lens)
11+
range_tensor = torch.arange(max_len).unsqueeze(0)
12+
range_tensor = range_tensor.expand(seq_lens.size(0), range_tensor.size(1))
13+
mask = (range_tensor < seq_lens).float()
14+
return mask
15+
16+
class CRF_Loss(nn.Module):
17+
def __init__(self, tagset_size):
18+
super(CRF_Loss, self).__init__()
19+
self.start_tag_idx = tagset_size
20+
self.stop_tag_idx = tagset_size + 1
21+
self.num_tags = tagset_size + 2
22+
23+
#transition from y_i-1 to y_i, T[y_i, y_j] = y_i <= y_j
24+
#+2 added for start and end indices
25+
self.transitions = nn.Parameter(torch.Tensor(self.num_tags, self.num_tags))
26+
nn.init.uniform_(self.transitions, -0.1, 0.1)
27+
28+
#no transition to start_tag, not transition from end tag
29+
self.transitions.data[self.start_tag_idx, :] = -10000
30+
self.transitions.data[:, self.stop_tag_idx] = -10000
31+
32+
def get_log_p_z(self, emissions, mask, seq_len):
33+
log_alpha = emissions[:, 0].clone()
34+
log_alpha += self.transitions[self.start_tag_idx, : self.start_tag_idx].unsqueeze(0)
35+
36+
for idx in range(1, seq_len):
37+
broadcast_emissions = emissions[:, idx].unsqueeze(1)
38+
broadcast_transitions = self.transitions[
39+
: self.start_tag, : self.start_tag
40+
].unsqueeze(0)
41+
broadcast_logprob = log_alpha.unsqueeze(2)
42+
score = broadcast_logprob + broadcast_emissions + broadcast_transitions
43+
44+
score = torch.logsumexp(score, 1)
45+
log_alpha = score * mask[:, idx].unsqueeze(1) + log_alpha.squeeze(1) * (
46+
1.0 - mask[:, idx].unsqueeze(1)
47+
)
48+
49+
log_alpha += self.transitions[: self.start_tag, self.end_tag].unsqueeze(0)
50+
return torch.logsumexp(log_alpha.squeeze(1), 1)
51+
52+
def get_log_p_Y_X(self, emissions, mask, seq_len, tags):
53+
llh = self.transitions[self.start_tag, tags[:, 0]].unsqueeze(1)
54+
llh += emissions[:, 0, :].gather(1, tags[:, 0].view(-1, 1)) * mask[:, 0].unsqueeze(1)
55+
56+
for idx in range(1, seq_len):
57+
old_state, new_state = (
58+
tags[:, idx - 1].view(-1, 1),
59+
tags[:, idx].view(-1, 1),
60+
)
61+
emission_scores = emissions[:, idx, :].gather(1, new_state)
62+
transition_scores = self.transitions[old_state, new_state]
63+
llh += (emission_scores + transition_scores) * mask[:, idx].unsqueeze(1)
64+
65+
last_tag_indices = mask.sum(1, dtype=torch.long) - 1
66+
last_tags = tags.gather(1, last_tag_indices.view(-1, 1))
67+
68+
llh += self.transitions[last_tags.squeeze(1), self.end_tag].unsqueeze(1)
69+
70+
return llh.squeeze(1)
71+
72+
def log_likelihood(self, emissions, tags):
73+
mask = tags.ne(Constants.TAG_PAD_ID).float()
74+
seq_len = emissions.shape[1]
75+
log_z = self.get_log_p_z(emissions, mask, seq_len)
76+
log_p_y_x = self.get_log_p_Y_X(emissions, mask, seq_len, tags)
77+
return log_p_y_x - log_z
78+
79+
def forward(self, emissions, tags):
80+
return self.log_likelihood(emissions, tags)
81+
82+
def inference(self, emissions, lengths):
83+
return self.viterbi_decode(emissions, lengths)
84+
85+
def viterbi_decode(self, emissions, lengths):
86+
mask = get_mask(lengths)
87+
seq_len = emissions.shape[1]
88+
89+
log_prob = emissions[:, 0].clone()
90+
log_prob += self.transitions[self.start_tag, : self.start_tag].unsqueeze(0)
91+
92+
93+
end_scores = log_prob + self.transitions[
94+
: self.start_tag, self.end_tag
95+
].unsqueeze(0)
96+
97+
best_scores_list = []
98+
best_scores_list.append(end_scores.unsqueeze(1))
99+
100+
best_paths_list = [torch.Tensor().long()]
101+
102+
for idx in range(1, seq_len):
103+
broadcast_emissions = emissions[:, idx].unsqueeze(1)
104+
broadcast_transmissions = self.transitions[
105+
: self.start_tag, : self.start_tag
106+
].unsqueeze(0)
107+
broadcast_log_prob = log_prob.unsqueeze(2)
108+
109+
score = broadcast_emissions + broadcast_transmissions + broadcast_log_prob
110+
111+
max_scores, max_score_indices = torch.max(score, 1)
112+
113+
best_paths_list.append(max_score_indices.unsqueeze(1))
114+
115+
end_scores = max_scores + self.transitions[
116+
: self.start_tag, self.end_tag
117+
].unsqueeze(0)
118+
119+
best_scores_list.append(end_scores.unsqueeze(1))
120+
log_prob = max_scores
121+
122+
best_scores = torch.cat(best_scores_list, 1).float()
123+
best_paths = torch.cat(best_paths_list, 1)
124+
125+
_, max_indices_from_scores = torch.max(best_scores, 2)
126+
127+
valid_index_tensor = torch.tensor(0).long()
128+
padding_tensor = torch.tensor(Constants.PAD_ID).long()
129+
130+
labels = max_indices_from_scores[:, seq_len - 1]
131+
labels = self._mask_tensor(labels, 1.0 - mask[:, seq_len - 1], padding_tensor)
132+
133+
all_labels = labels.unsqueeze(1).long()
134+
135+
for idx in range(seq_len - 2, -1, -1):
136+
indices_for_lookup = all_labels[:, -1].clone()
137+
indices_for_lookup = torch.where(
138+
indices_for_lookup == self.ignore_index,
139+
valid_index_tensor,
140+
indices_for_lookup
141+
)
142+
143+
indices_from_prev_pos = (
144+
best_paths[:, idx, :]
145+
.gather(1, indices_for_lookup.view(-1, 1).long())
146+
.squeeze(1)
147+
)
148+
indices_from_prev_pos = torch.where(
149+
(1.0 - mask[:, idx + 1]),
150+
padding_tensor,
151+
indices_from_prev_pos
152+
)
153+
154+
indices_from_max_scores = max_indices_from_scores[:, idx]
155+
indices_from_max_scores = torch.where(
156+
mask[:, idx + 1],
157+
padding_tensor,
158+
indices_from_max_scores
159+
)
160+
161+
labels = torch.where(
162+
indices_from_max_scores == self.ignore_index,
163+
indices_from_prev_pos,
164+
indices_from_max_scores,
165+
)
166+
167+
# Set to ignore_index if present state is not valid.
168+
labels = torch.where(
169+
(1 - mask[:, idx]),
170+
padding_tensor,
171+
labels
172+
)
173+
all_labels = torch.cat((all_labels, labels.view(-1, 1).long()), 1)
174+
175+
return best_scores, torch.flip(all_labels, [1])
176+

0 commit comments

Comments
 (0)