
Commit c6df0ea

Fix set_annotations in parser.update
1 parent bb15d5b commit c6df0ea

File tree

  spacy/pipeline/_parser_internals/_beam_utils.pyx
  spacy/pipeline/_parser_internals/arc_eager.pyx
  spacy/pipeline/_parser_internals/transition_system.pyx
  spacy/pipeline/transition_parser.pyx

4 files changed: +55 −18 lines

spacy/pipeline/_parser_internals/_beam_utils.pyx

Lines changed: 5 additions & 1 deletion

@@ -193,7 +193,11 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de
     for i, (d_scores, bp_scores) in enumerate(zip(states_d_scores, backprops)):
         loss += (d_scores**2).mean()
         bp_scores(d_scores)
-    return loss
+    # Return the predicted sequence for each doc.
+    predicted_histories = []
+    for i in range(len(pbeam)):
+        predicted_histories.append(pbeam[i].histories[0])
+    return predicted_histories, loss


 def collect_states(beams, docs):
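The new return value assumes each beam exposes its candidate action histories ranked best-first, which is what the `histories[0]` access implies: the top-scoring candidate's history is the predicted sequence for that doc. A toy illustration of that invariant (not spaCy's Beam API):

# Toy stand-in for a beam: candidate (score, history) pairs, with
# histories exposed best-first so histories[0] is the predicted sequence.
class ToyBeam:
    def __init__(self, candidates):
        ranked = sorted(candidates, key=lambda c: c[0], reverse=True)
        self.histories = [history for _, history in ranked]

pbeam = [ToyBeam([(0.4, [0, 1, 1]), (0.9, [0, 2, 1])])]
predicted_histories = [pbeam[i].histories[0] for i in range(len(pbeam))]
assert predicted_histories == [[0, 2, 1]]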

spacy/pipeline/_parser_internals/arc_eager.pyx

Lines changed: 3 additions & 2 deletions

@@ -638,16 +638,17 @@ cdef class ArcEager(TransitionSystem):
         return gold

     def init_gold_batch(self, examples):
-        # TODO: Projectivity?
         all_states = self.init_batch([eg.predicted for eg in examples])
         golds = []
         states = []
+        docs = []
         for state, eg in zip(all_states, examples):
             if self.has_gold(eg) and not state.is_final():
                 golds.append(self.init_gold(state, eg))
                 states.append(state)
+                docs.append(eg.x)
         n_steps = sum([len(s.queue) for s in states])
-        return states, golds, n_steps
+        return states, golds, docs

     def _replace_unseen_labels(self, ArcEagerGold gold):
         backoff_label = self.strings["dep"]
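Returning `docs` instead of `n_steps` keeps a third list aligned with `states` and `golds`, so callers still know which doc produced each surviving state after examples without gold annotations are skipped. The pattern in plain Python, with toy data:

# Filtering a batch while keeping parallel lists aligned: states[k]
# always came from docs[k], even when some examples are dropped.
examples = [("doc0", True), ("doc1", False), ("doc2", True)]  # (doc, has_gold)
states, docs = [], []
for i, (doc, has_gold) in enumerate(examples):
    if has_gold:
        states.append("state%d" % i)
        docs.append(doc)
assert list(zip(states, docs)) == [("state0", "doc0"), ("state2", "doc2")]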

spacy/pipeline/_parser_internals/transition_system.pyx

Lines changed: 10 additions & 0 deletions

@@ -120,6 +120,16 @@ cdef class TransitionSystem:
             raise ValueError(Errors.E024)
         return history

+    def follow_history(self, doc, history):
+        """Get the state that results from following a sequence of actions."""
+        cdef int clas
+        cdef StateClass state
+        state = self.init_batch([doc])[0]
+        for clas in history:
+            action = self.c[clas]
+            action.do(state.c, action.label)
+        return state
+
     def apply_transition(self, StateClass state, name):
         if not self.is_valid(state, name):
             raise ValueError(Errors.E170.format(name=name))
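Since each transition is a deterministic function of the current state, replaying a recorded sequence of action IDs reconstructs the final state exactly. A pure-Python analogue of `follow_history` with a toy transition system (not spaCy's API):

# Each "action" mutates the state in place, mirroring action.do(state.c, ...).
def follow_history(doc, history, actions):
    state = {"doc": doc, "stack": [], "arcs": []}
    for clas in history:
        actions[clas](state)
    return state

actions = {
    0: lambda s: s["stack"].append(len(s["arcs"])),  # toy SHIFT
    1: lambda s: s["arcs"].append(s["stack"].pop()),  # toy REDUCE
}
state = follow_history("my doc", [0, 0, 1], actions)
assert state["stack"] == [0] and state["arcs"] == [0]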

spacy/pipeline/transition_parser.pyx

Lines changed: 37 additions & 15 deletions

@@ -337,21 +337,22 @@ cdef class Parser(TrainablePipe):
             # Chop sequences into lengths of this many words, to make the
             # batch uniform length.
             max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
-            states, golds, _ = self._init_gold_batch(
+            states, golds, max_moves, state2doc = self._init_gold_batch(
                 examples,
                 max_length=max_moves
             )
         else:
-            states, golds, _ = self.moves.init_gold_batch(examples)
+            states, golds, state2doc = self.moves.init_gold_batch(examples)
         if not states:
             return losses
         model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])

+        histories = [[] for example in examples]
         all_states = list(states)
-        states_golds = list(zip(states, golds))
+        states_golds = list(zip(states, golds, state2doc))
         n_moves = 0
         while states_golds:
-            states, golds = zip(*states_golds)
+            states, golds, state2doc = zip(*states_golds)
             scores, backprop = model.begin_update(states)
             d_scores = self.get_batch_loss(states, golds, scores, losses)
             # Note that the gradient isn't normalized by the batch size
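For context on the unchanged line above the first change: `max_moves` is resampled per update from `[max_moves // 2, max_moves * 2)`, so the chop points fall at different offsets across updates rather than always cutting sequences at the same position:

import random

max_moves = 100
sampled = int(random.uniform(max_moves // 2, max_moves * 2))
assert 50 <= sampled <= 200  # chop length varies around the configured value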
@@ -360,8 +361,13 @@ cdef class Parser(TrainablePipe):
             # be getting smaller gradients for states in long sequences.
             backprop(d_scores)
             # Follow the predicted action
-            self.transition_states(states, scores)
-            states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
+            actions = self.transition_states(states, scores)
+            for i, action in enumerate(actions):
+                histories[i].append(action)
+            states_golds = [
+                s for s in zip(states, golds, state2doc)
+                if not s[0].is_final()
+            ]
             if max_moves >= 1 and n_moves >= max_moves:
                 break
             n_moves += 1
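In toy form, the bookkeeping this hunk adds: each step records the greedy action taken by every still-active state, then drops finished states while keeping the (state, gold, doc-index) triples aligned. All names below are hypothetical:

# One training step over a shrinking batch of toy "states" (counters
# that finish at zero): record the chosen action, then filter the triples.
def step(states_golds, histories, choose_action, apply_action, is_final):
    actions = [choose_action(s) for s, _, _ in states_golds]
    for i, action in enumerate(actions):
        histories[i].append(action)
    next_batch = [(apply_action(s, a), g, d)
                  for (s, g, d), a in zip(states_golds, actions)]
    return [t for t in next_batch if not is_final(t[0])]

histories = [[], []]
batch = [(2, "gold0", 0), (1, "gold1", 1)]  # (state, gold, doc index)
while batch:
    batch = step(batch, histories, choose_action=lambda s: 1,
                 apply_action=lambda s, a: s - a, is_final=lambda s: s == 0)
assert histories == [[1, 1], [1]]  # doc 1 finished after one move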
@@ -370,11 +376,11 @@ cdef class Parser(TrainablePipe):
         if sgd not in (None, False):
             self.finish_update(sgd)
         docs = [eg.predicted for eg in examples]
-        # TODO: Refactor so we don't have to parse twice like this (ugh)
-        # The issue is that we cut up the gold batch into sub-states, and that
-        # makes it hard to get the actual predicted transition sequence.
-        predicted_states = self.predict(docs)
-        self.set_annotations(docs, predicted_states)
+        states = [
+            self.moves.follow_history(doc, history)
+            for doc, history in zip(docs, histories)
+        ]
+        self.set_annotations(docs, self._get_states(docs, states))
         # Ugh, this is annoying. If we're working on GPU, we want to free the
         # memory ASAP. It seems that Python doesn't necessarily get around to
         # removing these in time if we don't explicitly delete? It's confusing.
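This is the core of the fix: the predicted transition sequence is captured during the update itself, so the old TODO about parsing twice goes away. A toy check of the underlying idea, that replaying a recorded history reproduces what a fresh greedy parse would compute (illustrative names only):

# Greedy "parser" over a toy state (a list that shrinks to empty).
def greedy_parse(doc, choose):
    state, history = list(doc), []
    while state:
        action = choose(state)
        history.append(action)
        state = state[action:]  # toy transition: consume `action` tokens
    return state, history

def follow(doc, history):
    state = list(doc)
    for action in history:
        state = state[action:]
    return state

doc = ["a", "b", "c", "d"]
final, history = greedy_parse(doc, choose=lambda s: 1)
assert follow(doc, history) == final  # replay == reparse, no second pass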
@@ -435,13 +441,16 @@ cdef class Parser(TrainablePipe):

     def update_beam(self, examples, *, beam_width,
                     drop=0., sgd=None, losses=None, beam_density=0.0):
-        states, golds, _ = self.moves.init_gold_batch(examples)
+        if losses is None:
+            losses = {}
+        losses.setdefault(self.name, 0.0)
+        states, golds, docs = self.moves.init_gold_batch(examples)
         if not states:
             return losses
         # Prepare the stepwise model, and get the callback for finishing the batch
         model, backprop_tok2vec = self.model.begin_update(
             [eg.predicted for eg in examples])
-        loss = _beam_utils.update_beam(
+        predicted_histories, loss = _beam_utils.update_beam(
             self.moves,
             states,
             golds,
@@ -453,6 +462,12 @@ cdef class Parser(TrainablePipe):
         backprop_tok2vec(golds)
         if sgd is not None:
             self.finish_update(sgd)
+        states = [
+            self.moves.follow_history(doc, history)
+            for doc, history in zip(docs, predicted_histories)
+        ]
+        self.set_annotations(docs, states)
+        return losses

     def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
         cdef StateClass state
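The `losses` handling added at the top of `update_beam` is worth seeing in isolation: without it, the early `return losses` on an empty batch would hand back `None`, and a caller reading `losses[self.name]` would fail. A minimal sketch of the idiom (hypothetical function):

def update(name, losses=None):
    if losses is None:
        losses = {}
    losses.setdefault(name, 0.0)  # key exists even on the early-exit path
    # ... accumulate the beam loss into losses[name] ...
    return losses

assert update("parser") == {"parser": 0.0}
shared = {"parser": 1.5}
assert update("parser", shared)["parser"] == 1.5  # existing totals kept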
@@ -595,18 +610,24 @@ cdef class Parser(TrainablePipe):
         states = []
         golds = []
         to_cut = []
+        # Return a list indicating the position in the batch that each state
+        # refers to. This lets us put together the full list of predicted
+        # histories.
+        state2doc = []
+        doc2i = {eg.x: i for i, eg in enumerate(examples)}
         for state, eg in zip(all_states, examples):
             if self.moves.has_gold(eg) and not state.is_final():
                 gold = self.moves.init_gold(state, eg)
                 if len(eg.x) < max_length:
                     states.append(state)
                     golds.append(gold)
+                    state2doc.append(doc2i[eg.x])
                 else:
                     oracle_actions = self.moves.get_oracle_sequence_from_state(
                         state.copy(), gold)
                     to_cut.append((eg, state, gold, oracle_actions))
         if not to_cut:
-            return states, golds, 0
+            return states, golds, 0, state2doc
         cdef int clas
         for eg, state, gold, oracle_actions in to_cut:
             for i in range(0, len(oracle_actions), max_length):
@@ -619,6 +640,7 @@ cdef class Parser(TrainablePipe):
                 if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
                     states.append(start_state)
                     golds.append(gold)
+                    state2doc.append(doc2i[eg.x])
                 if state.is_final():
                     break
-        return states, golds, max_length
+        return states, golds, max_length, state2doc
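Putting `_init_gold_batch`'s bookkeeping in plain Python: when a long example is chopped into several sub-states, each sub-state records the batch index of its source doc via `doc2i`, so per-state histories can later be regrouped per document. Toy data, hypothetical chopping rule:

docs = ["short doc", "a much longer doc"]
doc2i = {doc: i for i, doc in enumerate(docs)}
states, state2doc = [], []
for doc in docs:
    n_chunks = 1 if len(doc) < 12 else 2  # stand-in for max_length chopping
    for chunk in range(n_chunks):
        states.append((doc, chunk))
        state2doc.append(doc2i[doc])
assert state2doc == [0, 1, 1]  # both sub-states of doc 1 map back to index 1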
