Skip to content

Commit a4964b5

Browse files
committed
crf
1 parent e76c499 commit a4964b5

File tree

5 files changed

+16
-12
lines changed

5 files changed

+16
-12
lines changed

neural_ner/crf.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ def get_log_p_z(self, emissions, mask, seq_len):
4242
log_alpha += self.transitions[: self.start_tag, self.end_tag].unsqueeze(0)
4343
return torch.logsumexp(log_alpha.squeeze(1), 1)
4444

45-
def get_log_p_Y_X(self, emissions, mask, seq_len, tags):
46-
tags[tags < 0] = 0 # clone and then set
45+
def get_log_p_Y_X(self, emissions, mask, seq_len, orig_tags):
46+
tags = orig_tags.clone()
47+
tags[tags < 0] = 0
4748

4849
llh = self.transitions[self.start_tag, tags[:, 0]].unsqueeze(1)
4950
llh += emissions[:, 0, :].gather(1, tags[:, 0].view(-1, 1)) * mask[:, 0].unsqueeze(1)
@@ -106,28 +107,28 @@ def viterbi_decode(self, emissions, lengths):
106107
_, max_indices_from_scores = torch.max(best_scores, 2)
107108

108109
valid_index_tensor = torch.tensor(0).long()
109-
padding_tensor = torch.tensor(Constants.PAD_ID).long()
110+
padding_tensor = torch.tensor(Constants.TAG_PAD_ID).long()
110111

111112
labels = max_indices_from_scores[:, seq_len - 1]
112-
labels = torch.where(1.0 - mask[:, seq_len - 1], padding_tensor, labels)
113+
labels = torch.where(mask[:, seq_len - 1] != 1.0, padding_tensor, labels)
113114
all_labels = labels.unsqueeze(1).long()
114115

115116
for idx in range(seq_len - 2, -1, -1):
116117
indices_for_lookup = all_labels[:, -1].clone()
117-
indices_for_lookup = torch.where(indices_for_lookup == self.ignore_index, valid_index_tensor,
118+
indices_for_lookup = torch.where(indices_for_lookup == Constants.TAG_PAD_ID, valid_index_tensor,
118119
indices_for_lookup)
119120

120121
indices_from_prev_pos = best_paths[:, idx, :].gather(1, indices_for_lookup.view(-1, 1).long()).squeeze(1)
121-
indices_from_prev_pos = torch.where((1.0 - mask[:, idx + 1]), padding_tensor, indices_from_prev_pos)
122+
indices_from_prev_pos = torch.where(mask[:, idx + 1] != 1.0, padding_tensor, indices_from_prev_pos)
122123

123124
indices_from_max_scores = max_indices_from_scores[:, idx]
124-
indices_from_max_scores = torch.where(mask[:, idx + 1], padding_tensor, indices_from_max_scores)
125+
indices_from_max_scores = torch.where(mask[:, idx + 1] == 1.0, padding_tensor, indices_from_max_scores)
125126

126-
labels = torch.where(indices_from_max_scores == self.ignore_index, indices_from_prev_pos,
127+
labels = torch.where(indices_from_max_scores == Constants.TAG_PAD_ID, indices_from_prev_pos,
127128
indices_from_max_scores)
128129

129130
# Set to ignore_index if present state is not valid.
130-
labels = torch.where((1 - mask[:, idx]),padding_tensor, labels)
131+
labels = torch.where(mask[:, idx] != 1.0, padding_tensor, labels)
131132
all_labels = torch.cat((all_labels, labels.view(-1, 1).long()), 1)
132133

133134
return best_scores, torch.flip(all_labels, [1])

neural_ner/data_utils/vocab.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,15 @@ def word_mapping(self, sentences):
6161
self.word_to_id = {v: k for k, v in id_to_word.items()}
6262
self.id_to_word = id_to_word
6363

64-
char_freq_map = create_freq_map(chars)
65-
66-
id_to_char = {i+start_vocab_len: v for i, v in enumerate(char_freq_map)}
64+
id_to_char = {}
6765
for i, v in enumerate(Constants._START_VOCAB):
6866
id_to_char[i] = v
6967

68+
char_freq_map = create_freq_map(chars)
69+
70+
for v in char_freq_map:
71+
id_to_char[len(id_to_char)] = v
72+
7073
print("Found {} unique characters".format(len(char_freq_map)))
7174

7275
self.char_to_id = {v: k for k, v in id_to_char.items()}

semantic_parsing/__init__.py

Whitespace-only changes.

semantic_parsing/data_utils/__init__.py

Whitespace-only changes.

semantic_parsing/data_utils/batcher.py

Whitespace-only changes.

0 commit comments

Comments (0)