@@ -203,15 +203,21 @@ cdef class Parser(TrainablePipe):
203
203
)
204
204
205
205
def greedy_parse (self , docs , drop = 0. ):
206
- cdef vector[StateC* ] states
207
- cdef StateClass state
208
206
set_dropout_rate(self .model, drop)
209
- batch = self .moves.init_batch(docs)
210
207
# This is pretty dirty, but the NER can resize itself in init_batch,
211
208
# if labels are missing. We therefore have to check whether we need to
212
209
# expand our model output.
213
210
self ._resize()
214
211
model = self .model.predict(docs)
212
+ batch = self .moves.init_batch(docs)
213
+ states = self ._predict_states(model, batch)
214
+ model.clear_memory()
215
+ del model
216
+ return states
217
+
218
+ def _predict_states (self , model , batch ):
219
+ cdef vector[StateC* ] states
220
+ cdef StateClass state
215
221
weights = get_c_weights(model)
216
222
for state in batch:
217
223
if not state.is_final():
@@ -220,8 +226,6 @@ cdef class Parser(TrainablePipe):
220
226
with nogil:
221
227
self ._parseC(& states[0 ],
222
228
weights, sizes)
223
- model.clear_memory()
224
- del model
225
229
return batch
226
230
227
231
def beam_parse (self , docs , int beam_width , float drop = 0. , beam_density = 0. ):
@@ -306,6 +310,7 @@ cdef class Parser(TrainablePipe):
306
310
else :
307
311
action = self .moves.c[guess]
308
312
action.do(states[i], action.label)
313
+ states[i].history.push_back(guess)
309
314
free(is_valid)
310
315
311
316
def update (self , examples , *, drop = 0. , sgd = None , losses = None ):
@@ -319,7 +324,7 @@ cdef class Parser(TrainablePipe):
319
324
# We need to take care to act on the whole batch, because we might be
320
325
# getting vectors via a listener.
321
326
n_examples = len ([eg for eg in examples if self .moves.has_gold(eg)])
322
- if len (examples) == 0 :
327
+ if n_examples == 0 :
323
328
return losses
324
329
set_dropout_rate(self .model, drop)
325
330
# The probability we use beam update, instead of falling back to
@@ -333,23 +338,25 @@ cdef class Parser(TrainablePipe):
333
338
losses = losses,
334
339
beam_density = self .cfg[" beam_density" ]
335
340
)
336
- oracle_histories = [self .moves.get_oracle_sequence(eg) for eg in examples]
341
+ model, backprop_tok2vec = self .model.begin_update([eg.x for eg in examples])
342
+ final_states = self .moves.init_batch([eg.x for eg in examples])
343
+ self ._predict_states(model, final_states)
344
+ histories = [list (state.history) for state in final_states]
345
+ # oracle_histories = [self.moves.get_oracle_sequence(eg) for eg in examples]
337
346
max_moves = self .cfg[" update_with_oracle_cut_size" ]
338
347
if max_moves >= 1 :
339
348
# Chop sequences into lengths of this many words, to make the
340
349
# batch uniform length.
341
350
max_moves = int (random.uniform(max_moves // 2 , max_moves * 2 ))
342
351
states, golds, _ = self ._init_gold_batch(
343
352
examples,
344
- oracle_histories ,
353
+ histories ,
345
354
max_length = max_moves
346
355
)
347
356
else :
348
357
states, golds, _ = self .moves.init_gold_batch(examples)
349
358
if not states:
350
359
return losses
351
- docs = [eg.predicted for eg in examples]
352
- model, backprop_tok2vec = self .model.begin_update(docs)
353
360
354
361
all_states = list (states)
355
362
states_golds = list (zip (states, golds))
@@ -373,15 +380,7 @@ cdef class Parser(TrainablePipe):
373
380
backprop_tok2vec(golds)
374
381
if sgd not in (None , False ):
375
382
self .finish_update(sgd)
376
- # If we want to set the annotations based on predictions, it's really
377
- # hard to avoid parsing the data twice :(.
378
- # The issue is that we cut up the gold batch into sub-states, and that
379
- # means there's no one predicted sequence during the update.
380
- gold_states = [
381
- self .moves.follow_history(doc, history)
382
- for doc, history in zip (docs, oracle_histories)
383
- ]
384
- self .set_annotations(docs, gold_states)
383
+ self .set_annotations([eg.x for eg in examples], final_states)
385
384
# Ugh, this is annoying. If we're working on GPU, we want to free the
386
385
# memory ASAP. It seems that Python doesn't necessarily get around to
387
386
# removing these in time if we don't explicitly delete? It's confusing.
@@ -599,6 +598,7 @@ cdef class Parser(TrainablePipe):
599
598
StateClass state
600
599
Transition action
601
600
all_states = self .moves.init_batch([eg.predicted for eg in examples])
601
+ assert len (all_states) == len (examples) == len (oracle_histories)
602
602
states = []
603
603
golds = []
604
604
for state, eg, history in zip (all_states, examples, oracle_histories):
@@ -616,6 +616,7 @@ cdef class Parser(TrainablePipe):
616
616
for clas in history[i:i+ max_length]:
617
617
action = self .moves.c[clas]
618
618
action.do(state.c, action.label)
619
+ state.c.history.push_back(clas)
619
620
if state.is_final():
620
621
break
621
622
if self .moves.has_gold(eg, start_state.B(0 ), state.B(0 )):
0 commit comments