from __future__ import unicode_literals, print_function, division

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from data_utils.constant import Constants
import logging

from crf import CRF_Loss
from model_utils import init_lstm_wt, init_linear_wt, get_word_embd, get_mask

from transformer import Encoder

logging.basicConfig(level=logging.INFO)

class NER_SOFTMAX_CHAR(nn.Module):
    """BiLSTM tagger with character-level BiLSTM features and a per-token softmax output."""

    def __init__(self, vocab, config):
        super(NER_SOFTMAX_CHAR, self).__init__()
        word_emb_matrix = get_word_embd(vocab, config)
        embd_vector = torch.from_numpy(word_emb_matrix).float()

        self.word_embeds = nn.Embedding.from_pretrained(embd_vector, freeze=False)
        self.char_embeds = nn.Embedding(len(vocab.char_to_id), config.char_embd_dim,
                                        padding_idx=Constants.PAD_ID)
        if config.is_caps:
            self.caps_embeds = nn.Embedding(vocab.get_caps_cardinality(),
                                            config.caps_embd_dim, padding_idx=Constants.PAD_ID)

        self.lstm_char = nn.LSTM(self.char_embeds.embedding_dim,
                                 config.char_lstm_dim,
                                 num_layers=1, bidirectional=True, batch_first=True)

        # Each word's char feature is the concatenation of the forward and backward
        # char-LSTM states, so its width is 2 * char_lstm_dim (not char_embd_dim).
        input_size = self.word_embeds.embedding_dim + config.char_lstm_dim * 2

        if config.is_caps:
            input_size += config.caps_embd_dim

        self.lstm = nn.LSTM(input_size,
                            config.word_lstm_dim,
                            num_layers=1, bidirectional=True, batch_first=True)

        self.dropout = nn.Dropout(config.dropout_rate)
        self.hidden_layer = nn.Linear(config.word_lstm_dim * 2, config.word_lstm_dim)
        self.tanh_layer = torch.nn.Tanh()

        self.hidden2tag = nn.Linear(config.word_lstm_dim, len(vocab.id_to_tag))

        self.config = config

        init_lstm_wt(self.lstm_char)
        init_lstm_wt(self.lstm)
        init_linear_wt(self.hidden_layer)
        init_linear_wt(self.hidden2tag)
        self.char_embeds.weight.data.uniform_(-1., 1.)
        if config.is_caps:
            self.caps_embeds.weight.data.uniform_(-1., 1.)

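    # Expected batch layout, inferred from forward() below; the actual collate code
    # lives elsewhere in the repo, so treat this as a sketch rather than a spec:
    #   batch['words']       LongTensor [B, T]   padded word ids
    #   batch['words_lens']  LongTensor [B]      true sentence lengths
    #   batch['chars']       list of B (chars, char_len) pairs, one per sentence:
    #                        chars is [T, max_word_len] char ids, char_len is [T] word lengths
    #   batch['caps']        LongTensor [B, T]   capitalization ids (only if config.is_caps)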
    def forward(self, batch):
        sentence = batch['words']
        lengths = batch['words_lens']

        if self.config.is_caps:
            caps = batch['caps']
        max_length = torch.max(lengths).item()  # plain int so F.pad below gets int pad sizes
        char_emb = []
        word_embed = self.word_embeds(sentence)
        for chars, char_len in batch['chars']:
            seq_embed = self.char_embeds(chars)
            seq_lengths, sort_idx = torch.sort(char_len, descending=True)
            _, unsort_idx = torch.sort(sort_idx)
            seq_embed = seq_embed[sort_idx]
            packed = pack_padded_sequence(seq_embed, seq_lengths, batch_first=True)
            output, hidden = self.lstm_char(packed)
            lstm_feats, _ = pad_packed_sequence(output, batch_first=True)
            lstm_feats = lstm_feats.contiguous()
            b, t_k, d = list(lstm_feats.size())

            seq_rep = lstm_feats.view(b, t_k, 2, -1)  # 0 is fwd and 1 is bwd

            # The forward direction summarizes a word at its last real character and
            # the backward direction at character 0; index back into original word order.
            last_idx = char_len - 1
            seq_rep_fwd = seq_rep[unsort_idx, last_idx, 0]
            seq_rep_bwd = seq_rep[unsort_idx, 0, 1]

            seq_out = torch.cat([seq_rep_fwd, seq_rep_bwd], 1)
            # fill up with dummy char embeddings for the padding positions
            seq_out = F.pad(seq_out, (0, 0, 0, max_length - seq_out.size(0)))
            char_emb.append(seq_out.unsqueeze(0))

        char_emb = torch.cat(char_emb, 0)  # b x n x c_dim

        if self.config.is_caps:
            caps_embd = self.caps_embeds(caps)
            word_embed = torch.cat([char_emb, word_embed, caps_embd], 2)
        else:
            word_embed = torch.cat([char_emb, word_embed], 2)
        word_embed = self.dropout(word_embed)

        lengths = lengths.view(-1).tolist()
        packed = pack_padded_sequence(word_embed, lengths, batch_first=True)
        output, hidden = self.lstm(packed)

        lstm_feats, _ = pad_packed_sequence(output, batch_first=True)  # h dim = B x t_k x n
        lstm_feats = lstm_feats.contiguous()

        b, t_k, d = list(lstm_feats.size())

        h = self.hidden_layer(lstm_feats.view(-1, d))
        h = self.tanh_layer(h)
        logits = self.hidden2tag(h)
        logits = logits.view(b, t_k, -1)

        return logits

    def neg_log_likelihood(self, logits, y, s_lens):
        # Token-level NLL with padded tags ignored, normalized by sentence length,
        # then averaged over the batch.
        log_smx = F.log_softmax(logits, dim=2)
        loss = F.nll_loss(log_smx.transpose(1, 2), y, ignore_index=Constants.TAG_PAD_ID, reduction='none')
        loss = loss.sum(dim=1) / s_lens.float()
        loss = loss.mean()
        return loss

    def get_loss(self, logits, y, s_lens):
        loss = self.neg_log_likelihood(logits, y, s_lens)
        if self.config.is_l2_loss:
            loss += self.get_l2_loss()
        return loss

    def get_l2_loss(self):
        l2_reg = sum(p.norm(2) for p in self.parameters() if p.requires_grad)
        return self.config.reg_lambda * l2_reg

    def predict(self, logit, lengths):
        _, pred = torch.max(logit, dim=2)
        return pred

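# NER_SOFTMAX_CHAR_CRF reuses the BiLSTM featurizer above but replaces the
# per-token softmax objective and argmax decoding with a CRF layer: training
# minimizes either the negative CRF log-likelihood or a structural perceptron
# loss, and prediction runs batched Viterbi decoding over the emission scores.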
class NER_SOFTMAX_CHAR_CRF(nn.Module):
    def __init__(self, vocab, config):
        super(NER_SOFTMAX_CHAR_CRF, self).__init__()

        self.featurizer = NER_SOFTMAX_CHAR(vocab, config)
        self.crf = CRF_Loss(len(vocab.id_to_tag), config)
        self.config = config

    def get_l2_loss(self):
        l2_reg = sum(p.norm(2) for p in self.parameters() if p.requires_grad)
        return self.config.reg_lambda * l2_reg

    def forward(self, batch):
        emissions = self.featurizer(batch)
        return emissions

    def get_loss(self, logits, y, s_lens):
        if self.config.is_structural_perceptron_loss:
            loss = self.crf.structural_perceptron_loss(logits, y)
        else:
            loss = -1 * self.crf.log_likelihood(logits, y)

        loss = loss / s_lens.float()
        loss = loss.mean()
        if self.config.is_l2_loss:
            loss += self.get_l2_loss()
        return loss

    def predict(self, emissions, lengths):
        mask = get_mask(lengths, self.config.is_cuda)
        best_scores, pred = self.crf.viterbi_decode_batch(emissions, mask)
        return pred
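
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the call pattern is taken from the methods
# above, but the `vocab`/`config` objects, the `batch['tags']` key, and the
# optimizer come from the surrounding repo and are assumptions here.
#
#   model = NER_SOFTMAX_CHAR_CRF(vocab, config)
#   optimizer = torch.optim.Adam(model.parameters())
#
#   logits = model(batch)                               # emission scores [B, T, num_tags]
#   loss = model.get_loss(logits, batch['tags'], batch['words_lens'])
#   loss.backward()
#   optimizer.step()
#
#   preds = model.predict(logits, batch['words_lens'])  # Viterbi-decoded tag ids
# ---------------------------------------------------------------------------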