@@ -337,21 +337,22 @@ cdef class Parser(TrainablePipe):
             # Chop sequences into lengths of this many words, to make the
             # batch uniform length.
             max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
-            states, golds, _ = self._init_gold_batch(
+            states, golds, max_moves, state2doc = self._init_gold_batch(
                 examples,
                 max_length=max_moves
             )
         else:
-            states, golds, _ = self.moves.init_gold_batch(examples)
+            states, golds, state2doc = self.moves.init_gold_batch(examples)
         if not states:
             return losses
         model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
 
+        histories = [[] for example in examples]
         all_states = list(states)
-        states_golds = list(zip(states, golds))
+        states_golds = list(zip(states, golds, state2doc))
         n_moves = 0
         while states_golds:
-            states, golds = zip(*states_golds)
+            states, golds, state2doc = zip(*states_golds)
             scores, backprop = model.begin_update(states)
             d_scores = self.get_batch_loss(states, golds, scores, losses)
             # Note that the gradient isn't normalized by the batch size
@@ -360,8 +361,13 @@ cdef class Parser(TrainablePipe):
             # be getting smaller gradients for states in long sequences.
             backprop(d_scores)
             # Follow the predicted action
-            self.transition_states(states, scores)
-            states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
+            actions = self.transition_states(states, scores)
+            for i, action in enumerate(actions):
+                histories[i].append(action)
+            states_golds = [
+                s for s in zip(states, golds, state2doc)
+                if not s[0].is_final()
+            ]
             if max_moves >= 1 and n_moves >= max_moves:
                 break
             n_moves += 1
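The loop above records one action per live state at each step, then drops finished states from the batch. A minimal runnable sketch of that pattern follows, with a toy state class standing in for spaCy's StateClass (ToyState, parse_batch, and the constant action are illustrative assumptions, not spaCy API); it keeps each state's original batch index via enumerate so the histories stay aligned as the live set shrinks:

# Toy version of the history-recording loop: step every live state,
# append the chosen action to that state's history, drop finished states.
class ToyState:
    def __init__(self, n_words):
        self.n_words = n_words
        self.n_moves = 0

    def apply(self, action):
        # Pretend every action consumes one word.
        self.n_moves += 1

    def is_final(self):
        return self.n_moves >= self.n_words


def parse_batch(states):
    histories = [[] for _ in states]
    live = list(enumerate(states))
    while live:
        for i, state in live:
            action = 0  # stand-in for the argmax over model scores
            state.apply(action)
            histories[i].append(action)
        live = [(i, s) for i, s in live if not s.is_final()]
    return histories


print(parse_batch([ToyState(2), ToyState(3)]))  # [[0, 0], [0, 0, 0]]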
@@ -370,11 +376,11 @@ cdef class Parser(TrainablePipe):
         if sgd not in (None, False):
             self.finish_update(sgd)
         docs = [eg.predicted for eg in examples]
-        # TODO: Refactor so we don't have to parse twice like this (ugh)
-        # The issue is that we cut up the gold batch into sub-states, and that
-        # makes it hard to get the actual predicted transition sequence.
-        predicted_states = self.predict(docs)
-        self.set_annotations(docs, predicted_states)
+        states = [
+            self.moves.follow_history(doc, history)
+            for doc, history in zip(docs, histories)
+        ]
+        self.set_annotations(docs, self._get_states(docs, states))
         # Ugh, this is annoying. If we're working on GPU, we want to free the
         # memory ASAP. It seems that Python doesn't necessarily get around to
         # removing these in time if we don't explicitly delete? It's confusing.
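Rather than re-parsing the docs with predict just to set annotations (the old TODO), the update now replays the recorded transition histories. The idea is that transitions are deterministic: applying the same action sequence to a fresh initial state reproduces the final parse state. A toy sketch with a made-up shift/reduce move set (this follow_history is a stand-in, not the method used in the diff):

# Replay a recorded action history against a fresh state and return
# the resulting configuration instead of parsing from scratch.
def follow_history(words, history):
    stack, buffer = [], list(words)
    for action in history:
        if action == "SHIFT":
            stack.append(buffer.pop(0))
        elif action == "REDUCE":
            stack.pop()
    return stack, buffer


print(follow_history(["a", "b"], ["SHIFT", "SHIFT", "REDUCE"]))
# (['a'], [])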
@@ -435,13 +441,16 @@ cdef class Parser(TrainablePipe):
 
     def update_beam(self, examples, *, beam_width,
                     drop=0., sgd=None, losses=None, beam_density=0.0):
-        states, golds, _ = self.moves.init_gold_batch(examples)
+        if losses is None:
+            losses = {}
+        losses.setdefault(self.name, 0.0)
+        states, golds, docs = self.moves.init_gold_batch(examples)
         if not states:
             return losses
         # Prepare the stepwise model, and get the callback for finishing the batch
         model, backprop_tok2vec = self.model.begin_update(
             [eg.predicted for eg in examples])
-        loss = _beam_utils.update_beam(
+        predicted_histories, loss = _beam_utils.update_beam(
             self.moves,
             states,
             golds,
@@ -453,6 +462,12 @@ cdef class Parser(TrainablePipe):
         backprop_tok2vec(golds)
         if sgd is not None:
             self.finish_update(sgd)
+        states = [
+            self.moves.follow_history(doc, history)
+            for doc, history in zip(docs, predicted_histories)
+        ]
+        self.set_annotations(docs, states)
+        return losses
 
     def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
         cdef StateClass state
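update_beam now follows the usual spaCy losses convention before doing any work: accept an optional shared dict, make sure the component's own key exists, and return the dict so a caller can accumulate losses across components. A minimal sketch of that convention (the name default and the 0.5 loss value are placeholders):

# Components share one losses dict; each ensures its key and adds to it.
def update(examples, losses=None, name="parser"):
    if losses is None:
        losses = {}
    losses.setdefault(name, 0.0)
    losses[name] += 0.5  # stand-in for the real batch loss
    return losses


shared = {}
update([], losses=shared)
update([], losses=shared)
print(shared)  # {'parser': 1.0}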
@@ -595,18 +610,24 @@ cdef class Parser(TrainablePipe):
         states = []
         golds = []
         to_cut = []
+        # Return a list indicating the position in the batch that each state
+        # refers to. This lets us put together the full list of predicted
+        # histories.
+        state2doc = []
+        doc2i = {eg.x: i for i, eg in enumerate(examples)}
         for state, eg in zip(all_states, examples):
             if self.moves.has_gold(eg) and not state.is_final():
                 gold = self.moves.init_gold(state, eg)
                 if len(eg.x) < max_length:
                     states.append(state)
                     golds.append(gold)
+                    state2doc.append(doc2i[eg.x])
                 else:
                     oracle_actions = self.moves.get_oracle_sequence_from_state(
                         state.copy(), gold)
                     to_cut.append((eg, state, gold, oracle_actions))
         if not to_cut:
-            return states, golds, 0
+            return states, golds, 0, state2doc
         cdef int clas
         for eg, state, gold, oracle_actions in to_cut:
             for i in range(0, len(oracle_actions), max_length):
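Because _init_gold_batch can cut one long example into several sub-states, the flat states list no longer lines up one-to-one with examples; state2doc records which input doc each state came from. A toy illustration of the bookkeeping (strings stand in for the Doc objects used as doc2i keys, and the cutting rule is invented):

# Map each (possibly cut) state back to the index of its source doc.
examples = ["doc0", "doc1", "doc2"]      # stand-ins for Example.x
doc2i = {doc: i for i, doc in enumerate(examples)}

states = []
state2doc = []
for doc in examples:
    n_cuts = 2 if doc == "doc1" else 1   # pretend doc1 was cut in two
    for cut in range(n_cuts):
        states.append((doc, cut))
        state2doc.append(doc2i[doc])

print(state2doc)  # [0, 1, 1, 2]: doc1 contributes two sub-states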
@@ -619,6 +640,7 @@ cdef class Parser(TrainablePipe):
                 if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
                     states.append(start_state)
                     golds.append(gold)
+                    state2doc.append(doc2i[eg.x])
                 if state.is_final():
                     break
-        return states, golds, max_length
+        return states, golds, max_length, state2doc
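With state2doc returned alongside the states, a caller can stitch the per-sub-state histories back into one action sequence per document. The diff doesn't show that regrouping step; the sketch below is one plausible way to consume the mapping, with invented action data:

# Regroup a flat list of sub-state histories by source-doc index.
from collections import defaultdict

histories = [["S"], ["S", "R"], ["S"], ["L"]]   # one per sub-state
state2doc = [0, 1, 1, 2]

per_doc = defaultdict(list)
for history, doc_i in zip(histories, state2doc):
    per_doc[doc_i].extend(history)

print(dict(per_doc))  # {0: ['S'], 1: ['S', 'R', 'S'], 2: ['L']}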