
Commit d55b6f3

fix char embedding
1 parent 057aa20 commit d55b6f3

File tree

4 files changed, +22 -14 lines changed


neural_ner/data_utils/batcher.py

Lines changed: 1 addition & 3 deletions
@@ -66,8 +66,6 @@ def prepare_batch(self, batch):
         all_chars = []
         raw_sentences = []
 
-        max_length = max(words_len)
-
         for i in idx:
             datum = batch[i]
             for v in features:
@@ -76,7 +74,7 @@ def prepare_batch(self, batch):
             raw_sentences.append(datum['raw_sentence'])
 
             chars = datum['chars']
-            chars_padded, chars_padded_lens = pad_chars(chars, max_length)
+            chars_padded, chars_padded_lens = pad_chars(chars)
             chars_padded = torch.Tensor(chars_padded).long()
             chars_padded_lens = torch.Tensor(chars_padded_lens).long()
             if self.config.is_cuda:

neural_ner/data_utils/sentence_utils.py

Lines changed: 16 additions & 7 deletions
@@ -15,7 +15,7 @@ def pad_items(items, is_tag=False):
 
     return np.array(padded_items), np.array(padded_items_len)
 
-def pad_chars(items, max_word_len):
+def pad_chars(items):
     padded_items = []
     padded_items_len = [len(item) for item in items]
     max_length = max(padded_items_len)
@@ -24,10 +24,6 @@ def pad_chars(items, max_word_len):
     for item in items:
         padding = [pad_id] * (max_length - len(item))
         padded_items.append(item + padding)
-    for i in range(len(items), max_word_len):
-        padding = [pad_id] * max_length
-        padded_items.append(padding)
-        padded_items_len.append(1)
 
     return np.array(padded_items), np.array(padded_items_len)
 
@@ -62,8 +58,8 @@ def prepare_sentence(s, vocab, config):
     str_words = [w[0] for w in s]
     word_seq, word_char_seq = get_char_word_seq(str_words, config.lower, config.zeros)
 
-    words = [vocab.word_to_id[w] if w in vocab.word_to_id else Constants.UNK_ID for w in word_seq]
-    chars = [[vocab.char_to_id[c] for c in char_seq if c in vocab.char_to_id] for char_seq in word_char_seq]
+    words = [vocab.word_to_id.get(w, Constants.UNK_ID) for w in word_seq]
+    chars = [[vocab.char_to_id.get(c, Constants.UNK_ID) for c in char_seq] for char_seq in word_char_seq]
     caps = [cap_feature(w) for w in str_words]
 
     return {
@@ -74,3 +70,16 @@ def prepare_sentence(s, vocab, config):
         'raw_sentence':str_words
     }
 
+if __name__ == '__main__':
+    from config import config
+    from data_utils.vocab import Vocab
+    from data_utils.utils import prepare_dataset
+
+    vocab = Vocab(config)
+
+    sentences = [['a O', 'b O', 'c O', '| O']]
+
+    data = prepare_dataset(sentences, vocab, config)
+    datum = data[0]
+    chars = datum['chars']
+    chars_padded, chars_padded_lens = pad_chars(chars)
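For reference, a minimal standalone sketch of what the new pad_chars does for one sentence; PAD_ID = 0 here is an assumed stand-in for the repo's pad id, which this diff does not show:

import numpy as np

PAD_ID = 0  # assumption: placeholder for the repo's padding id

def pad_chars(items):
    # items: one list of character ids per word in a single sentence
    padded_items_len = [len(item) for item in items]
    max_length = max(padded_items_len)
    padded_items = [item + [PAD_ID] * (max_length - len(item)) for item in items]
    return np.array(padded_items), np.array(padded_items_len)

chars = [[3, 4], [5], [6, 7, 8]]        # e.g. ids for the words "ab", "c", "def"
padded, lens = pad_chars(chars)
print(padded.shape)                     # (3, 3): one row per real word, no dummy rows
print(lens)                             # [2 1 3]

Each sentence now yields exactly one row of character ids per real word; padding up to the longest sentence in the batch is handled later in the model (see the F.pad change below).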

neural_ner/model.py

Lines changed: 5 additions & 2 deletions
@@ -55,17 +55,17 @@ def __init__(self, vocab, config):
     def forward(self, batch):
         sentence = batch['words']
         lengths = batch['words_lens']
+
         if self.config.is_caps:
             caps = batch['caps']
-
+        max_length = torch.max(lengths)
         char_emb = []
         word_embed = self.word_embeds(sentence)
         for chars, char_len in batch['chars']:
             seq_embed = self.char_embeds(chars)
             seq_lengths, sort_idx = torch.sort(char_len, descending=True)
             _, unsort_idx = torch.sort(sort_idx)
             seq_embed = seq_embed[sort_idx]
-
             packed = pack_padded_sequence(seq_embed, seq_lengths, batch_first=True)
             output, hidden = self.lstm_char(packed)
             lstm_feats, _ = pad_packed_sequence(output, batch_first=True)
@@ -79,7 +79,10 @@ def forward(self, batch):
             seq_rep_bwd = seq_rep[unsort_idx, last_idx, 1]
 
             seq_out = torch.cat([seq_rep_fwd, seq_rep_bwd], 1)
+            # fill up the dummy char embedding for padding
+            seq_out = F.pad(seq_out, (0, 0, 0, max_length - seq_out.size(0)))
             char_emb.append(seq_out.unsqueeze(0))
+
         char_emb = torch.cat(char_emb, 0) #b x n x c_dim
 
         if self.config.is_caps:
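To make the new padding step concrete, here is a small self-contained sketch of the F.pad call on a stand-in tensor; the shapes are illustrative only and not taken from the repo's config:

import torch
import torch.nn.functional as F

max_length = 5                          # longest sentence in the batch, as torch.max(lengths)
char_repr = torch.randn(3, 8)           # char-LSTM output for a 3-word sentence
# F.pad reads the pad tuple from the last dim backwards: (left, right, top, bottom),
# so this appends (max_length - 3) zero rows, one per padding word.
padded = F.pad(char_repr, (0, 0, 0, max_length - char_repr.size(0)))
print(padded.shape)                     # torch.Size([5, 8])

Padding the per-sentence char representation up to the batch-wide sentence length is what lets pad_chars drop its old max_word_len argument.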

neural_ner/process_training.py

Lines changed: 0 additions & 2 deletions
@@ -108,8 +108,6 @@ def evalute_test_dev(self, summary_writer, epoch, global_step, exp_loss):
         logging.info("Dev: Epoch %d, Iter %d, loss: %f, F1: %f" % (epoch, global_step, dev_loss, dev_f1))
         write_summary(dev_f1, "dev/F1", summary_writer, global_step)
 
-        self.evaluate_test()
-
         return dev_f1
 
     def inference(self, str):
