@@ -317,8 +317,8 @@ cdef class Parser(TrainablePipe):
         for multitask in self._multitasks:
             multitask.update(examples, drop=drop, sgd=sgd)

-        n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
-        if n_examples == 0:
+        examples = [eg for eg in examples if self.moves.has_gold(eg)]
+        if len(examples) == 0:
             return losses
         set_dropout_rate(self.model, drop)
         # The probability we use beam update, instead of falling back to
@@ -332,13 +332,15 @@ cdef class Parser(TrainablePipe):
                 losses=losses,
                 beam_density=self.cfg["beam_density"]
             )
+        oracle_histories = [self.moves.get_oracle_sequence(eg) for eg in examples]
         max_moves = self.cfg["update_with_oracle_cut_size"]
         if max_moves >= 1:
             # Chop sequences into lengths of this many words, to make the
             # batch uniform length.
             max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
             states, golds, _ = self._init_gold_batch(
                 examples,
+                oracle_histories,
                 max_length=max_moves
             )
         else:
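
Note: the new `oracle_histories` list is built once per batch and reused twice, here to chop long transition sequences into roughly uniform pieces, and later to set annotations without re-parsing. Filtering `examples` in place in the first hunk (rather than just counting gold examples, as before) is what keeps the later `zip(...)` calls over examples, docs, and histories aligned. A minimal, self-contained sketch of the chopping idea, with an invented helper name and toy history (not spaCy's API):

    import random

    def chop_history(history, max_moves):
        # Jitter the cut size, as the diff does, so batches are not all
        # chopped at identical boundaries.
        max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
        return [history[i:i + max_moves]
                for i in range(0, len(history), max_moves)]

    toy_history = list(range(23))  # stand-in for a sequence of transition IDs
    for segment in chop_history(toy_history, max_moves=8):
        print(len(segment), segment[:3])
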
@@ -370,11 +372,15 @@ cdef class Parser(TrainablePipe):
         if sgd not in (None, False):
             self.finish_update(sgd)
         docs = [eg.predicted for eg in examples]
-        # TODO: Refactor so we don't have to parse twice like this (ugh)
+        # If we want to set the annotations based on predictions, it's really
+        # hard to avoid parsing the data twice :(.
         # The issue is that we cut up the gold batch into sub-states, and that
-        # makes it hard to get the actual predicted transition sequence.
-        predicted_states = self.predict(docs)
-        self.set_annotations(docs, predicted_states)
+        # means there's no one predicted sequence during the update.
+        gold_states = [
+            self.moves.follow_history(doc, history)
+            for doc, history in zip(docs, oracle_histories)
+        ]
+        self.set_annotations(docs, gold_states)
         # Ugh, this is annoying. If we're working on GPU, we want to free the
         # memory ASAP. It seems that Python doesn't necessarily get around to
         # removing these in time if we don't explicitly delete? It's confusing.
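
The deleted `self.predict(docs)` call was a full second parse, done only so `set_annotations` had states to work from. Because the oracle histories are already in hand, `follow_history` can rebuild the states deterministically instead; note the behavioral change that annotations set during `update` now follow the gold transition sequence rather than the model's own predictions. A toy, pure-Python illustration of why replaying a recorded action sequence is sufficient; the arc-standard-style system below is entirely invented for this sketch and is not spaCy's `TransitionSystem`:

    def replay(history, n_tokens):
        # Replay a recorded action sequence over a fresh state; the result
        # depends only on the history, so no re-parse is needed.
        stack, buffer, arcs = [], list(range(n_tokens)), []
        for action in history:
            if action == "SHIFT":
                stack.append(buffer.pop(0))
            elif action == "LEFT":   # stack[-2] becomes child of stack[-1]
                child = stack.pop(-2)
                arcs.append((stack[-1], child))
            elif action == "RIGHT":  # stack[-1] becomes child of stack[-2]
                child = stack.pop()
                arcs.append((stack[-1], child))
        return arcs

    history = ["SHIFT", "SHIFT", "LEFT", "SHIFT", "RIGHT"]
    assert replay(history, 3) == replay(history, 3)  # deterministic
    print(replay(history, 3))  # [(1, 0), (1, 2)]
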
@@ -581,7 +587,7 @@ cdef class Parser(TrainablePipe):
             raise ValueError(Errors.E149) from None
         return self

-    def _init_gold_batch(self, examples, max_length):
+    def _init_gold_batch(self, examples, oracle_histories, max_length):
         """Make a square batch, of length equal to the shortest transition
         sequence or a cap. A long
         doc will get multiple states. Let's say we have a doc of length 2*N,
@@ -594,24 +600,17 @@ cdef class Parser(TrainablePipe):
         all_states = self.moves.init_batch([eg.predicted for eg in examples])
         states = []
         golds = []
-        to_cut = []
-        for state, eg in zip(all_states, examples):
-            if self.moves.has_gold(eg) and not state.is_final():
-                gold = self.moves.init_gold(state, eg)
-                if len(eg.x) < max_length:
-                    states.append(state)
-                    golds.append(gold)
-                else:
-                    oracle_actions = self.moves.get_oracle_sequence_from_state(
-                        state.copy(), gold)
-                    to_cut.append((eg, state, gold, oracle_actions))
-        if not to_cut:
-            return states, golds, 0
-        cdef int clas
-        for eg, state, gold, oracle_actions in to_cut:
-            for i in range(0, len(oracle_actions), max_length):
+        for state, eg, history in zip(all_states, examples, oracle_histories):
+            if state.is_final():
+                continue
+            gold = self.moves.init_gold(state, eg)
+            if len(history) < max_length:
+                states.append(state)
+                golds.append(gold)
+                continue
+            for i in range(0, len(history), max_length):
                 start_state = state.copy()
-                for clas in oracle_actions[i:i+max_length]:
+                for clas in history[i:i+max_length]:
                     action = self.moves.c[clas]
                     action.do(state.c, action.label)
                     if state.is_final():
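
The rewritten loop relies on two invariants established earlier in `update`: `examples` has already been filtered down to those with gold annotations (so the old `self.moves.has_gold(eg)` guard is redundant here), and `oracle_histories` lines up one-to-one with `all_states`. A quick back-of-envelope check of the docstring's "square batch" claim, with invented numbers (not spaCy code): a history of H moves contributes one state if H < max_length, and otherwise at most ceil(H / max_length) chopped states, so long docs are split up rather than dominating the batch:

    from math import ceil

    max_length = 8
    for h in [5, 23, 9, 40]:
        # Upper bound: chopped start states that are already final get skipped.
        n_states = 1 if h < max_length else ceil(h / max_length)
        print(f"history of {h} moves -> {n_states} state(s)")
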