
Commit be155ea

Fix set_annotations during parser update
1 parent c631c35 commit be155ea

2 files changed: 31 additions & 24 deletions

spacy/pipeline/_parser_internals/transition_system.pyx

Lines changed: 8 additions & 0 deletions
@@ -61,6 +61,14 @@ cdef class TransitionSystem:
             offset += len(doc)
         return states
 
+    def follow_history(self, doc, history):
+        cdef int clas
+        cdef StateClass state = StateClass(doc)
+        for clas in history:
+            action = self.c[clas]
+            action.do(state.c, action.label)
+        return state
+
     def get_oracle_sequence(self, Example example, _debug=False):
         states, golds, _ = self.init_gold_batch([example])
         if not states:
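
The new follow_history helper replays a recorded sequence of transition class IDs onto a fresh StateClass, so the state (and therefore the annotations) produced by a given history can be recovered without re-running the model. A minimal usage sketch, not part of this commit: it assumes a spaCy build that includes the change, and the pipeline name, sentence, and gold annotations are invented for illustration.

    # Usage sketch (hypothetical example data; requires a build with follow_history).
    import spacy
    from spacy.training import Example

    nlp = spacy.load("en_core_web_sm")
    parser = nlp.get_pipe("parser")

    doc = nlp.make_doc("She ate the pizza")
    example = Example.from_dict(
        doc,
        {"heads": [1, 1, 3, 1], "deps": ["nsubj", "ROOT", "det", "dobj"]},
    )

    # get_oracle_sequence gives the transition class IDs that reproduce the gold
    # parse; follow_history applies them to a fresh state for the same doc.
    history = parser.moves.get_oracle_sequence(example)
    state = parser.moves.follow_history(example.predicted, history)
    parser.set_annotations([example.predicted], [state])
    print([(t.text, t.dep_, t.head.text) for t in example.predicted])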

spacy/pipeline/transition_parser.pyx

Lines changed: 23 additions & 24 deletions
@@ -317,8 +317,8 @@ cdef class Parser(TrainablePipe):
         for multitask in self._multitasks:
             multitask.update(examples, drop=drop, sgd=sgd)
 
-        n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
-        if n_examples == 0:
+        examples = [eg for eg in examples if self.moves.has_gold(eg)]
+        if len(examples) == 0:
             return losses
         set_dropout_rate(self.model, drop)
         # The probability we use beam update, instead of falling back to
@@ -332,13 +332,15 @@ cdef class Parser(TrainablePipe):
                 losses=losses,
                 beam_density=self.cfg["beam_density"]
             )
+        oracle_histories = [self.moves.get_oracle_sequence(eg) for eg in examples]
         max_moves = self.cfg["update_with_oracle_cut_size"]
         if max_moves >= 1:
             # Chop sequences into lengths of this many words, to make the
             # batch uniform length.
             max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
             states, golds, _ = self._init_gold_batch(
                 examples,
+                oracle_histories,
                 max_length=max_moves
             )
         else:
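
Note that the surrounding (unchanged) lines draw the actual cut length from random.uniform(max_moves // 2, max_moves * 2) around the configured update_with_oracle_cut_size, so the points where long oracle sequences get chopped vary from update to update. A tiny standalone sketch of that draw, with a hypothetical value standing in for the config key:

    import random

    # Hypothetical stand-in for self.cfg["update_with_oracle_cut_size"].
    cut_size = 100

    # Each call to update() draws its own cut length from [cut_size // 2, cut_size * 2]
    # and truncates it to an int, so sub-state boundaries differ between updates.
    max_moves = int(random.uniform(cut_size // 2, cut_size * 2))
    assert 50 <= max_moves <= 200
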
@@ -370,11 +372,15 @@ cdef class Parser(TrainablePipe):
         if sgd not in (None, False):
             self.finish_update(sgd)
         docs = [eg.predicted for eg in examples]
-        # TODO: Refactor so we don't have to parse twice like this (ugh)
+        # If we want to set the annotations based on predictions, it's really
+        # hard to avoid parsing the data twice :(.
         # The issue is that we cut up the gold batch into sub-states, and that
-        # makes it hard to get the actual predicted transition sequence.
-        predicted_states = self.predict(docs)
-        self.set_annotations(docs, predicted_states)
+        # means there's no one predicted sequence during the update.
+        gold_states = [
+            self.moves.follow_history(doc, history)
+            for doc, history in zip(docs, oracle_histories)
+        ]
+        self.set_annotations(docs, gold_states)
         # Ugh, this is annoying. If we're working on GPU, we want to free the
         # memory ASAP. It seems that Python doesn't necessarily get around to
         # removing these in time if we don't explicitly delete? It's confusing.
@@ -581,7 +587,7 @@ cdef class Parser(TrainablePipe):
                 raise ValueError(Errors.E149) from None
         return self
 
-    def _init_gold_batch(self, examples, max_length):
+    def _init_gold_batch(self, examples, oracle_histories, max_length):
         """Make a square batch, of length equal to the shortest transition
         sequence or a cap. A long
         doc will get multiple states. Let's say we have a doc of length 2*N,
@@ -594,24 +600,17 @@ cdef class Parser(TrainablePipe):
         all_states = self.moves.init_batch([eg.predicted for eg in examples])
         states = []
         golds = []
-        to_cut = []
-        for state, eg in zip(all_states, examples):
-            if self.moves.has_gold(eg) and not state.is_final():
-                gold = self.moves.init_gold(state, eg)
-                if len(eg.x) < max_length:
-                    states.append(state)
-                    golds.append(gold)
-                else:
-                    oracle_actions = self.moves.get_oracle_sequence_from_state(
-                        state.copy(), gold)
-                    to_cut.append((eg, state, gold, oracle_actions))
-        if not to_cut:
-            return states, golds, 0
-        cdef int clas
-        for eg, state, gold, oracle_actions in to_cut:
-            for i in range(0, len(oracle_actions), max_length):
+        for state, eg, history in zip(all_states, examples, oracle_histories):
+            if state.is_final():
+                continue
+            gold = self.moves.init_gold(state, eg)
+            if len(history) < max_length:
+                states.append(state)
+                golds.append(gold)
+                continue
+            for i in range(0, len(history), max_length):
                 start_state = state.copy()
-                for clas in oracle_actions[i:i+max_length]:
+                for clas in history[i:i+max_length]:
                     action = self.moves.c[clas]
                     action.do(state.c, action.label)
                     if state.is_final():
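
As the _init_gold_batch docstring above says, the rewritten loop makes a "square" batch: each oracle history is either used whole (if shorter than max_length) or split into consecutive segments of at most max_length transitions, so every training state covers a bounded stretch of moves. A simplified sketch of just that segmentation, using a hypothetical helper and plain Python lists in place of transition histories:

    # Hypothetical helper, not part of spaCy: the segmentation _init_gold_batch
    # applies to each oracle history.
    def chop_history(history, max_length):
        # Short histories are kept whole; longer ones are split into consecutive
        # segments of max_length transitions (the last segment may be shorter).
        if len(history) < max_length:
            return [history]
        return [history[i:i + max_length] for i in range(0, len(history), max_length)]

    print(chop_history(list(range(10)), 4))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    print(chop_history([0, 1, 2], 4))        # [[0, 1, 2]]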
