
Commit 5b2440a
Try to use real histories, not oracle
1 parent c3c462e commit 5b2440a

5 files changed: +30 -19 lines changed

spacy/pipeline/_parser_internals/_state.pxd

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,7 @@ cdef cppclass StateC:
     vector[ArcC] _left_arcs
     vector[ArcC] _right_arcs
     vector[libcpp.bool] _unshiftable
+    vector[int] history
     set[int] _sent_starts
     TokenC _empty_token
     int length
@@ -382,3 +383,4 @@ cdef cppclass StateC:
         this._b_i = src._b_i
         this.offset = src.offset
         this._empty_token = src._empty_token
+        this.history = src.history
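
Every call site in this commit follows the pattern introduced here: whenever a transition mutates a state, its class index is pushed onto `history`, and the copy hunk makes sure a cloned state carries its history along. A toy Python sketch of that pattern (illustration only, not the spaCy API; the shift action and token count are made up):

# Toy analogue of what this commit wires through the C++ state: record
# each applied transition index, and copy the history on clone.
class ToyState:
    def __init__(self, n_tokens=5):
        self.stack = []
        self.buffer = list(range(n_tokens))
        self.history = []                    # mirrors StateC.history

    def apply(self, clas, action):
        action(self)                         # mutate the state
        self.history.append(clas)            # mirrors history.push_back(clas)

    def clone(self):
        new = ToyState(0)
        new.stack = list(self.stack)
        new.buffer = list(self.buffer)
        new.history = list(self.history)     # mirrors this.history = src.history
        return new

def shift(state):
    state.stack.append(state.buffer.pop(0))

state = ToyState()
state.apply(0, shift)
assert state.clone().history == [0]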

spacy/pipeline/_parser_internals/arc_eager.pyx

Lines changed: 1 addition & 0 deletions
@@ -844,6 +844,7 @@ cdef class ArcEager(TransitionSystem):
                         state.print_state()
                     )))
                 action.do(state.c, action.label)
+                state.c.history.push_back(i)
                 break
             else:
                 failed = False

spacy/pipeline/_parser_internals/stateclass.pyx

Lines changed: 4 additions & 0 deletions
@@ -20,6 +20,10 @@ cdef class StateClass:
         if self._borrowed != 1:
             del self.c
 
+    @property
+    def history(self):
+        return list(self.c.history)
+
     @property
     def stack(self):
         return [self.S(i) for i in range(self.c.stack_depth())]
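
The new property copies the C++ vector into a fresh Python list on every access, so a caller's snapshot stays independent of later changes to the state (or to the returned list). A self-contained sketch of those copy semantics (toy class, not the spaCy API):

class _State:
    def __init__(self):
        self._history = [1, 4, 2]      # stand-in for the C++ vector[int]

    @property
    def history(self):
        return list(self._history)     # fresh list, like list(self.c.history)

s = _State()
snapshot = s.history
snapshot.append(99)                    # mutating the snapshot...
assert s.history == [1, 4, 2]          # ...leaves the state untouched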

spacy/pipeline/_parser_internals/transition_system.pyx

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,7 @@ cdef class TransitionSystem:
         for clas in history:
             action = self.c[clas]
             action.do(state.c, action.label)
+            state.c.history.push_back(clas)
         return state
 
     def get_oracle_sequence(self, Example example, _debug=False):
@@ -110,6 +111,7 @@ cdef class TransitionSystem:
                         "S0 head?", str(state.has_head(state.S(0))),
                     )))
                 action.do(state.c, action.label)
+                state.c.history.push_back(i)
                 break
             else:
                 if _debug:
@@ -137,6 +139,7 @@ cdef class TransitionSystem:
             raise ValueError(Errors.E170.format(name=name))
         action = self.lookup_transition(name)
         action.do(state.c, action.label)
+        state.c.history.push_back(action.clas)
 
     cdef Transition lookup_transition(self, object name) except *:
         raise NotImplementedError
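
All three hunks apply the same rule as `arc_eager.pyx` above: wherever `action.do` mutates the state, the transition index is recorded, whether it comes from a stored history, from the oracle search, or from a transition looked up by name. Reusing the `ToyState` sketch above, a replay helper under that rule keeps the invariant that a replayed history is also the resulting state's recorded history:

def follow_history(state, history, actions):
    # Mirrors the first hunk above: apply each stored transition index in
    # order; ToyState.apply records it, as push_back does in the C code.
    for clas in history:
        state.apply(clas, actions[clas])
    return state

replayed = follow_history(ToyState(), [0, 0], {0: shift})
assert replayed.history == [0, 0]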

spacy/pipeline/transition_parser.pyx

Lines changed: 20 additions & 19 deletions
@@ -203,15 +203,21 @@ cdef class Parser(TrainablePipe):
         )
 
     def greedy_parse(self, docs, drop=0.):
-        cdef vector[StateC*] states
-        cdef StateClass state
         set_dropout_rate(self.model, drop)
-        batch = self.moves.init_batch(docs)
         # This is pretty dirty, but the NER can resize itself in init_batch,
         # if labels are missing. We therefore have to check whether we need to
         # expand our model output.
         self._resize()
         model = self.model.predict(docs)
+        batch = self.moves.init_batch(docs)
+        states = self._predict_states(model, batch)
+        model.clear_memory()
+        del model
+        return states
+
+    def _predict_states(self, model, batch):
+        cdef vector[StateC*] states
+        cdef StateClass state
         weights = get_c_weights(model)
         for state in batch:
             if not state.is_final():
@@ -220,8 +226,6 @@ cdef class Parser(TrainablePipe):
         with nogil:
             self._parseC(&states[0],
                 weights, sizes)
-        model.clear_memory()
-        del model
         return batch
 
     def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
@@ -306,6 +310,7 @@ cdef class Parser(TrainablePipe):
             else:
                 action = self.moves.c[guess]
                 action.do(states[i], action.label)
+                states[i].history.push_back(guess)
         free(is_valid)
 
     def update(self, examples, *, drop=0., sgd=None, losses=None):
@@ -319,7 +324,7 @@ cdef class Parser(TrainablePipe):
         # We need to take care to act on the whole batch, because we might be
         # getting vectors via a listener.
         n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
-        if len(examples) == 0:
+        if n_examples == 0:
             return losses
         set_dropout_rate(self.model, drop)
         # The probability we use beam update, instead of falling back to
@@ -333,23 +338,25 @@ cdef class Parser(TrainablePipe):
                 losses=losses,
                 beam_density=self.cfg["beam_density"]
             )
-        oracle_histories = [self.moves.get_oracle_sequence(eg) for eg in examples]
+        model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
+        final_states = self.moves.init_batch([eg.x for eg in examples])
+        self._predict_states(model, final_states)
+        histories = [list(state.history) for state in final_states]
+        #oracle_histories = [self.moves.get_oracle_sequence(eg) for eg in examples]
         max_moves = self.cfg["update_with_oracle_cut_size"]
         if max_moves >= 1:
             # Chop sequences into lengths of this many words, to make the
             # batch uniform length.
             max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
             states, golds, _ = self._init_gold_batch(
                 examples,
-                oracle_histories,
+                histories,
                 max_length=max_moves
             )
         else:
             states, golds, _ = self.moves.init_gold_batch(examples)
         if not states:
             return losses
-        docs = [eg.predicted for eg in examples]
-        model, backprop_tok2vec = self.model.begin_update(docs)
 
         all_states = list(states)
         states_golds = list(zip(states, golds))
@@ -373,15 +380,7 @@ cdef class Parser(TrainablePipe):
             backprop_tok2vec(golds)
         if sgd not in (None, False):
             self.finish_update(sgd)
-        # If we want to set the annotations based on predictions, it's really
-        # hard to avoid parsing the data twice :(.
-        # The issue is that we cut up the gold batch into sub-states, and that
-        # means there's no one predicted sequence during the update.
-        gold_states = [
-            self.moves.follow_history(doc, history)
-            for doc, history in zip(docs, oracle_histories)
-        ]
-        self.set_annotations(docs, gold_states)
+        self.set_annotations([eg.x for eg in examples], final_states)
         # Ugh, this is annoying. If we're working on GPU, we want to free the
         # memory ASAP. It seems that Python doesn't necessarily get around to
         # removing these in time if we don't explicitly delete? It's confusing.
@@ -599,6 +598,7 @@ cdef class Parser(TrainablePipe):
             StateClass state
             Transition action
         all_states = self.moves.init_batch([eg.predicted for eg in examples])
+        assert len(all_states) == len(examples) == len(oracle_histories)
         states = []
         golds = []
         for state, eg, history in zip(all_states, examples, oracle_histories):
@@ -616,6 +616,7 @@ cdef class Parser(TrainablePipe):
                 for clas in history[i:i+max_length]:
                     action = self.moves.c[clas]
                     action.do(state.c, action.label)
+                    state.c.history.push_back(clas)
                 if state.is_final():
                     break
                 if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
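
Net effect in `Parser.update`: the parser now runs a greedy prediction pass first (`_predict_states` fills each state's `history`), trains on those predicted histories instead of oracle ones, and reuses the same `final_states` for `set_annotations`, avoiding the second parse the deleted comment block lamented. When `update_with_oracle_cut_size` is set, histories are chopped into roughly uniform sub-sequences before building training states. A self-contained sketch of just that chopping logic (the real `_init_gold_batch` also replays the transitions on states and checks `has_gold`, both omitted here):

import random

def chop_histories(histories, max_moves):
    # Pick the chunk length the same way the diff does, then split each
    # predicted history into sub-sequences of at most that many moves.
    max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
    return [
        history[i:i + max_moves]
        for history in histories
        for i in range(0, len(history), max_moves)
    ]

print(chop_histories([[0, 1, 0, 2, 1, 1, 0]], max_moves=4))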
