 import logging
 import numpy as np

+from crf import CRF_Loss
+
 print('pytorch version', torch.__version__)

 logging.basicConfig(level=logging.INFO)
@@ -58,7 +60,8 @@ def get_word_embd(vocab, config):
 def test_one_batch(batch, model):
     model.eval()
     logits = model(batch)
-    _, pred = model.get_argmax(logits)
+    lengths = batch['words_lens']
+    pred = model.predict(logits, lengths)
     return logits, pred

 def get_model(vocab, config, model_file_path, is_eval=False):
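
For context, a minimal sketch of how the updated test_one_batch could be used to turn predictions back into tag strings during evaluation. It is not part of this commit: the batch key 'words_lens' and vocab.id_to_tag come from the surrounding code, while the helper name decode_batch and the assumption that id_to_tag is indexable by tag id are mine.

def decode_batch(batch, model, vocab):
    # run the model and get one predicted tag id per token (shape: batch x seq_len)
    logits, pred = test_one_batch(batch, model)
    lengths = batch['words_lens']
    decoded = []
    for sent_pred, n in zip(pred.tolist(), lengths.tolist()):
        # keep only the first n positions; everything past the sentence length is padding
        decoded.append([vocab.id_to_tag[t] for t in sent_pred[:n]])
    return decoded
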
@@ -80,7 +83,6 @@ def __init__(self, vocab, config):
         super(NER_SOFTMAX_CHAR, self).__init__()
         word_emb_matrix = get_word_embd(vocab, config)
         embd_vector = torch.from_numpy(word_emb_matrix).float()
-        tagset_size = len(vocab.id_to_tag)

         self.word_embeds = nn.Embedding.from_pretrained(embd_vector, freeze=False)
         self.char_embeds = nn.Embedding(len(vocab.char_to_id), config.char_embd_dim, padding_idx=Constants.PAD_ID)
@@ -97,7 +99,8 @@ def __init__(self, vocab, config):
         self.dropout = nn.Dropout(config.dropout_rate)
         self.hidden_layer = nn.Linear(config.word_lstm_dim * 2, config.word_lstm_dim)
         self.tanh_layer = torch.nn.Tanh()
-        self.hidden2tag = nn.Linear(config.word_lstm_dim, tagset_size)
+
+        self.hidden2tag = nn.Linear(config.word_lstm_dim, len(vocab.id_to_tag))

         self.config = config

@@ -158,144 +161,50 @@ def forward(self, batch):

         return logits

-    def neg_log_likelihood(self, logits, y, s_len):
+    def neg_log_likelihood(self, logits, y, s_lens):
         log_smx = F.log_softmax(logits, dim=2)
         loss = F.nll_loss(log_smx.transpose(1, 2), y, ignore_index=Constants.TAG_PAD_ID, reduce=False)
-        loss = loss.squeeze(1).sum(dim=1) / s_len.float()
+        loss = loss.squeeze(1).sum(dim=1) / s_lens.float()
         loss = loss.mean()
-        if self.config.is_l2_loss:
-            l2_reg = sum(p.norm(2) for p in self.parameters() if p.requires_grad)
-            loss += self.config.reg_lambda * l2_reg
         return loss

-    def get_argmax(self, logits):
-        max_value, max_idx = torch.max(logits, dim=2)
-        return max_value, max_idx
-
-class NER_CRF(nn.Module):
-    def __init__(self, embd_vector, hidden_dim, tagset_size,
-                 reg_lambda):
-        super(NER_CRF, self).__init__()
-
-        self.start_tag_idx = tagset_size
-        self.stop_tag_idx = tagset_size + 1
-        self.all_tagset_size = tagset_size + 2
-
-        self.word_embeds = nn.Embedding.from_pretrained(embd_vector)
-        embedding_dim = self.word_embeds.embedding_dim
-
-        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
-                            num_layers=1, bidirectional=True, batch_first=True)
-        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
-
-        # transition from y_i-1 to y_i, T[y_i, y_j] = y_i <= y_j
-        # +2 added for start and end indices
-        self.transitions = nn.Parameter(torch.randn(self.all_tagset_size, self.all_tagset_size))
-        nn.init.uniform_(self.transitions, -0.1, 0.1)
-
-        # no transition to start_tag, no transition from end tag
-        self.transitions.data[self.start_tag_idx, :] = -10000
-        self.transitions.data[:, self.stop_tag_idx] = -10000
-
-        self.hidden = self.init_hidden()
-
-    def init_hidden(self):
-        return (torch.randn(2, 1, self.hidden_dim // 2),
-                torch.randn(2, 1, self.hidden_dim // 2))
+    def get_loss(self, logits, y, s_lens):
+        loss = self.neg_log_likelihood(logits, y, s_lens)
+        if self.config.is_l2_loss:
+            loss += self.get_l2_loss()
+        return loss

-    def get_emission_prob(self, sentence, lengths):
-        embedded = self.word_embeds(sentence)
-        lengths = lengths.view(-1).tolist()
-        packed = pack_padded_sequence(embedded, lengths, batch_first=True)
+    def get_l2_loss(self):
+        l2_reg = sum(p.norm(2) for p in self.parameters() if p.requires_grad)
+        return self.config.reg_lambda * l2_reg

-        self.hidden = self.init_hidden()
-        output, self.hidden = self.lstm(packed, self.hidden)
+    def predict(self, logit, lengths):
+        max_value, pred = torch.max(logit, dim=2)
+        return pred

-        lstm_feats, _ = pad_packed_sequence(output, batch_first=True)  # h dim = B x t_k x n
-        lstm_feats = lstm_feats.contiguous()
+class NER_SOFTMAX_CHAR_CRF(nn.Module):
+    def __init__(self, vocab, config):
+        super(NER_SOFTMAX_CHAR_CRF, self).__init__()

-        b, t_k, d = list(lstm_feats.size())
+        self.featurizer = NER_SOFTMAX_CHAR(vocab, config)
+        self.crf = CRF_Loss(len(vocab.id_to_tag))

-        logits = self.hidden2tag(lstm_feats.view(-1, d))
-        logits = logits.view(b, t_k, -1)
+    def forward(self, batch):
+        emissions = self.featurizer(batch)
+        return emissions

-        return logits
+    def crf_loss(self, emissions, target, s_lens):
+        loss = -1 * self.crf(emissions, target)
+        loss = loss.squeeze(1).sum(dim=1) / s_lens.float()
+        loss = loss.mean()
+        return loss

-    def get_argmax(self, logits):
-        max_value, max_idx = torch.max(logits, dim=2)
-        return max_value, max_idx
-
-    def log_sum_exp(self, vec):
-        max_score, _ = self.get_argmax(vec)
-        max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
-        return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))
-
-    def get_log_z(self, emission_prob):
-        init_alphas = torch.full((1, self.all_tagset_size), -10000.)
-        init_alphas[0][self.start_tag_idx] = 0.
-        forward_var = init_alphas
-        for e_i in emission_prob:
-            alphas_t = []
-            for next_tag in range(self.all_tagset_size):
-                emit_score = e_i[next_tag].view(
-                    1, -1).expand(1, self.all_tagset_size)
-                trans_score = self.transitions[next_tag].view(1, -1)
-
-                next_tag_var = forward_var + trans_score + emit_score
-                alphas_t.append(self.log_sum_exp(next_tag_var).view(1))
-            forward_var = torch.cat(alphas_t).view(1, -1)
-        terminal_var = forward_var + self.transitions[self.stop_tag_idx]
-        alpha = self.log_sum_exp(terminal_var)
-        return alpha
-
-    def get_log_p_y_x(self, feats, lengths, tags):
-        score = torch.zeros(1)
-        tags = torch.cat([torch.tensor([self.start_tag_idx], dtype=torch.long), tags])
-        for i, feat in enumerate(feats):
-            score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
-        score = score + self.transitions[self.stop_tag_idx, tags[-1]]
-        return score
-
-    def neg_log_likelihood(self, sentence, lengths, tags):
-        feats = self.get_emission_prob(sentence, lengths)
-        log_z = self.get_log_z(feats, lengths)
-        log_p_y_x = self.get_log_p_y_x(feats, tags, lengths)
-        return -(log_p_y_x - log_z)
-
-    def forward(self, sentence, lengths):
-        feats = self.get_emission_prob(sentence, lengths)
-        score, tag_seq = self.viterbi_decode(feats, lengths)
-        return score, tag_seq
-
-    def viterbi_decode(self, feats, lengths):
-        backpointers = []
-
-        init_vvars = torch.full((1, self.all_tagset_size), -10000.)
-        init_vvars[0][self.start_tag_idx] = 0
-
-        forward_var = init_vvars
-        for feat in feats:
-            bptrs_t = []
-            viterbivars_t = []
-
-            for next_tag in range(self.all_tagset_size):
-                next_tag_var = forward_var + self.transitions[next_tag]
-                _, best_tag_id = self.get_argmax(next_tag_var)
-                bptrs_t.append(best_tag_id)
-                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
-            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
-            backpointers.append(bptrs_t)
-
-        terminal_var = forward_var + self.transitions[self.stop_tag_idx]
-        _, best_tag_id = self.get_argmax(terminal_var)
-        path_score = terminal_var[0][best_tag_id]
-
-        best_path = [best_tag_id]
-        for bptrs_t in reversed(backpointers):
-            best_tag_id = bptrs_t[best_tag_id]
-            best_path.append(best_tag_id)
-        start = best_path.pop()
-        assert start == self.start_tag_idx
-        best_path.reverse()
-        return path_score, best_path
+    def get_loss(self, logits, y, s_lens):
+        loss = self.crf_loss(logits, y, s_lens)
+        if self.config.is_l2_loss:
+            loss += self.get_l2_loss()
+        return loss
+
+    def predict(self, emissions, lengths):
+        best_scores, pred = self.crf.viterbi_decode(emissions, lengths)
+        return pred
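
crf.py itself is not part of this commit. Judging only from the call sites above, CRF_Loss would need roughly the interface below; this is an inferred sketch with placeholder bodies, not the actual module: the constructor takes the tag count, calling the module with (emissions, target) returns log-likelihood terms that the wrapper negates, sums over time, and length-normalizes, and viterbi_decode(emissions, lengths) returns best scores plus predicted tag ids.

import torch
import torch.nn as nn

class CRF_Loss(nn.Module):
    """Interface inferred from how it is used above; bodies are placeholders."""

    def __init__(self, tagset_size):
        super(CRF_Loss, self).__init__()
        # a linear-chain CRF typically learns tag-to-tag transition scores,
        # often with two extra states for start/stop
        self.transitions = nn.Parameter(torch.randn(tagset_size + 2, tagset_size + 2))

    def forward(self, emissions, target):
        # expected to return log-likelihood terms shaped so that
        # (-1 * out).squeeze(1).sum(dim=1) / lengths yields a per-sentence loss
        raise NotImplementedError

    def viterbi_decode(self, emissions, lengths):
        # expected to return (best_scores, predicted_tag_ids)
        raise NotImplementedError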
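
Both model classes now expose the same forward/get_loss/predict surface, so a training step can stay model-agnostic. A minimal sketch, assuming batches also carry the gold labels under a 'tags' key (only 'words_lens' is visible in this diff) and that a standard PyTorch optimizer is passed in:

def train_step(batch, model, optimizer):
    model.train()
    optimizer.zero_grad()
    logits = model(batch)  # per-token tag scores / emissions
    # 'tags' as the gold-label key is an assumption; adapt to the real batch layout
    loss = model.get_loss(logits, batch['tags'], batch['words_lens'])
    loss.backward()
    optimizer.step()
    return loss.item()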