Skip to content

Commit e680efc

Browse files
authored
Set annotations in update (explosion#6767)
* bump to 3.0.0rc4
* do set_annotations in component update calls
* update docs and remove set_annotations flag
* fix EL test
1 parent 57640aa commit e680efc

21 files changed

+57
-77
lines changed

spacy/about.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# fmt: off
22
__title__ = "spacy-nightly"
3-
__version__ = "3.0.0rc3"
3+
__version__ = "3.0.0rc4"
44
__download_url__ = "https://github.com/explosion/spacy-models/releases/download"
55
__compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
66
__projects__ = "https://github.com/explosion/projects"

spacy/pipeline/entity_linker.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -193,18 +193,16 @@ def update(
193193
self,
194194
examples: Iterable[Example],
195195
*,
196-
set_annotations: bool = False,
197196
drop: float = 0.0,
198197
sgd: Optional[Optimizer] = None,
199198
losses: Optional[Dict[str, float]] = None,
200199
) -> Dict[str, float]:
201200
"""Learn from a batch of documents and gold-standard information,
202-
updating the pipe's model. Delegates to predict and get_loss.
201+
updating the pipe's model. Delegates to predict, get_loss and
202+
set_annotations.
203203
204204
examples (Iterable[Example]): A batch of Example objects.
205205
drop (float): The dropout rate.
206-
set_annotations (bool): Whether or not to update the Example objects
207-
with the predictions.
208206
sgd (thinc.api.Optimizer): The optimizer.
209207
losses (Dict[str, float]): Optional record of the loss during training.
210208
Updated using the component name as the key.
@@ -220,11 +218,13 @@ def update(
220218
return losses
221219
validate_examples(examples, "EntityLinker.update")
222220
sentence_docs = []
223-
docs = [eg.predicted for eg in examples]
224-
if set_annotations:
225-
# This seems simpler than other ways to get that exact output -- but
226-
# it does run the model twice :(
227-
predictions = self.model.predict(docs)
221+
docs = []
222+
for eg in examples:
223+
eg.predicted.ents = eg.reference.ents
224+
docs.append(eg.predicted)
225+
# This seems simpler than other ways to get that exact output -- but
226+
# it does run the model twice :(
227+
predictions = self.predict(docs)
228228
for eg in examples:
229229
sentences = [s for s in eg.reference.sents]
230230
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
@@ -260,8 +260,7 @@ def update(
260260
if sgd is not None:
261261
self.finish_update(sgd)
262262
losses[self.name] += loss
263-
if set_annotations:
264-
self.set_annotations(docs, predictions)
263+
self.set_annotations(docs, predictions)
265264
return losses
266265

267266
def get_loss(self, examples: Iterable[Example], sentence_encodings):

spacy/pipeline/multitask.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ class ClozeMultitask(TrainablePipe):
199199
loss = self.distance.get_loss(prediction, target)
200200
return loss, gradient
201201

202-
def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
202+
def update(self, examples, *, drop=0., sgd=None, losses=None):
203203
pass
204204

205205
def rehearse(self, examples, drop=0., sgd=None, losses=None):

spacy/pipeline/tagger.pyx

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -173,14 +173,13 @@ class Tagger(TrainablePipe):
173173
if doc.c[j].tag == 0:
174174
doc.c[j].tag = self.vocab.strings[self.labels[tag_id]]
175175

176-
def update(self, examples, *, drop=0., sgd=None, losses=None, set_annotations=False):
176+
def update(self, examples, *, drop=0., sgd=None, losses=None):
177177
"""Learn from a batch of documents and gold-standard information,
178-
updating the pipe's model. Delegates to predict and get_loss.
178+
updating the pipe's model. Delegates to predict, get_loss and
179+
set_annotations.
179180
180181
examples (Iterable[Example]): A batch of Example objects.
181182
drop (float): The dropout rate.
182-
set_annotations (bool): Whether or not to update the Example objects
183-
with the predictions.
184183
sgd (thinc.api.Optimizer): The optimizer.
185184
losses (Dict[str, float]): Optional record of the loss during training.
186185
Updated using the component name as the key.
@@ -206,9 +205,8 @@ class Tagger(TrainablePipe):
206205
self.finish_update(sgd)
207206

208207
losses[self.name] += loss
209-
if set_annotations:
210-
docs = [eg.predicted for eg in examples]
211-
self.set_annotations(docs, self._scores2guesses(tag_scores))
208+
docs = [eg.predicted for eg in examples]
209+
self.set_annotations(docs, self._scores2guesses(tag_scores))
212210
return losses
213211

214212
def rehearse(self, examples, *, drop=0., sgd=None, losses=None):

spacy/pipeline/textcat.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,17 +195,15 @@ def update(
195195
examples: Iterable[Example],
196196
*,
197197
drop: float = 0.0,
198-
set_annotations: bool = False,
199198
sgd: Optional[Optimizer] = None,
200199
losses: Optional[Dict[str, float]] = None,
201200
) -> Dict[str, float]:
202201
"""Learn from a batch of documents and gold-standard information,
203-
updating the pipe's model. Delegates to predict and get_loss.
202+
updating the pipe's model. Delegates to predict, get_loss and
203+
set_annotations.
204204
205205
examples (Iterable[Example]): A batch of Example objects.
206206
drop (float): The dropout rate.
207-
set_annotations (bool): Whether or not to update the Example objects
208-
with the predictions.
209207
sgd (thinc.api.Optimizer): The optimizer.
210208
losses (Dict[str, float]): Optional record of the loss during training.
211209
Updated using the component name as the key.
@@ -228,9 +226,8 @@ def update(
228226
if sgd is not None:
229227
self.finish_update(sgd)
230228
losses[self.name] += loss
231-
if set_annotations:
232-
docs = [eg.predicted for eg in examples]
233-
self.set_annotations(docs, scores=scores)
229+
docs = [eg.predicted for eg in examples]
230+
self.set_annotations(docs, scores=scores)
234231
return losses
235232

236233
def rehearse(

spacy/pipeline/tok2vec.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,15 +163,12 @@ def update(
163163
drop: float = 0.0,
164164
sgd: Optional[Optimizer] = None,
165165
losses: Optional[Dict[str, float]] = None,
166-
set_annotations: bool = False,
167166
):
168167
"""Learn from a batch of documents and gold-standard information,
169168
updating the pipe's model.
170169
171170
examples (Iterable[Example]): A batch of Example objects.
172171
drop (float): The dropout rate.
173-
set_annotations (bool): Whether or not to update the Example objects
174-
with the predictions.
175172
sgd (thinc.api.Optimizer): The optimizer.
176173
losses (Dict[str, float]): Optional record of the loss during training.
177174
Updated using the component name as the key.
@@ -210,8 +207,7 @@ def backprop(one_d_tokvecs):
210207
listener.receive(batch_id, tokvecs, accumulate_gradient)
211208
if self.listeners:
212209
self.listeners[-1].receive(batch_id, tokvecs, backprop)
213-
if set_annotations:
214-
self.set_annotations(docs, tokvecs)
210+
self.set_annotations(docs, tokvecs)
215211
return losses
216212

217213
def get_loss(self, examples, scores) -> None:

spacy/pipeline/trainable_pipe.pyx

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,14 @@ cdef class TrainablePipe(Pipe):
9191
def update(self,
9292
examples: Iterable["Example"],
9393
*, drop: float=0.0,
94-
set_annotations: bool=False,
9594
sgd: Optimizer=None,
9695
losses: Optional[Dict[str, float]]=None) -> Dict[str, float]:
9796
"""Learn from a batch of documents and gold-standard information,
98-
updating the pipe's model. Delegates to predict and get_loss.
97+
updating the pipe's model. Delegates to predict, get_loss and
98+
set_annotations.
9999

100100
examples (Iterable[Example]): A batch of Example objects.
101101
drop (float): The dropout rate.
102-
set_annotations (bool): Whether or not to update the Example objects
103-
with the predictions.
104102
sgd (thinc.api.Optimizer): The optimizer.
105103
losses (Dict[str, float]): Optional record of the loss during training.
106104
Updated using the component name as the key.
@@ -124,9 +122,8 @@ cdef class TrainablePipe(Pipe):
124122
if sgd not in (None, False):
125123
self.finish_update(sgd)
126124
losses[self.name] += loss
127-
if set_annotations:
128-
docs = [eg.predicted for eg in examples]
129-
self.set_annotations(docs, scores=scores)
125+
docs = [eg.predicted for eg in examples]
126+
self.set_annotations(docs, scores=scores)
130127
return losses
131128
132129
def rehearse(self,

spacy/pipeline/transition_parser.pyx

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ cdef class Parser(TrainablePipe):
308308
action.do(states[i], action.label)
309309
free(is_valid)
310310

311-
def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
311+
def update(self, examples, *, drop=0., sgd=None, losses=None):
312312
cdef StateClass state
313313
if losses is None:
314314
losses = {}
@@ -328,7 +328,6 @@ cdef class Parser(TrainablePipe):
328328
return self.update_beam(
329329
examples,
330330
beam_width=self.cfg["beam_width"],
331-
set_annotations=set_annotations,
332331
sgd=sgd,
333332
losses=losses,
334333
beam_density=self.cfg["beam_density"]
@@ -370,9 +369,8 @@ cdef class Parser(TrainablePipe):
370369
backprop_tok2vec(golds)
371370
if sgd not in (None, False):
372371
self.finish_update(sgd)
373-
if set_annotations:
374-
docs = [eg.predicted for eg in examples]
375-
self.set_annotations(docs, all_states)
372+
docs = [eg.predicted for eg in examples]
373+
self.set_annotations(docs, all_states)
376374
# Ugh, this is annoying. If we're working on GPU, we want to free the
377375
# memory ASAP. It seems that Python doesn't necessarily get around to
378376
# removing these in time if we don't explicitly delete? It's confusing.
@@ -432,7 +430,7 @@ cdef class Parser(TrainablePipe):
432430
return losses
433431

434432
def update_beam(self, examples, *, beam_width,
435-
drop=0., sgd=None, losses=None, set_annotations=False, beam_density=0.0):
433+
drop=0., sgd=None, losses=None, beam_density=0.0):
436434
states, golds, _ = self.moves.init_gold_batch(examples)
437435
if not states:
438436
return losses

spacy/tests/pipeline/test_entity_linker.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ def test_preserving_links_ents_2(nlp):
425425
def test_overfitting_IO():
426426
# Simple test to try and quickly overfit the NEL component - ensuring the ML models work correctly
427427
nlp = English()
428+
nlp.add_pipe("sentencizer", first=True)
428429
vector_length = 3
429430
assert "Q2146908" not in nlp.vocab.strings
430431

@@ -464,9 +465,6 @@ def create_kb(vocab):
464465
nlp.update(train_examples, sgd=optimizer, losses=losses)
465466
assert losses["entity_linker"] < 0.001
466467

467-
# adding additional components that are required for the entity_linker
468-
nlp.add_pipe("sentencizer", first=True)
469-
470468
# Add a custom component to recognize "Russ Cochran" as an entity for the example training data
471469
patterns = [
472470
{"label": "PERSON", "pattern": [{"LOWER": "russ"}, {"LOWER": "cochran"}]}

website/docs/api/dependencyparser.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,9 @@ Modify a batch of [`Doc`](/api/doc) objects, using pre-computed scores.
220220
## DependencyParser.update {#update tag="method"}
221221
222222
Learn from a batch of [`Example`](/api/example) objects, updating the pipe's
223-
model. Delegates to [`predict`](/api/dependencyparser#predict) and
224-
[`get_loss`](/api/dependencyparser#get_loss).
223+
model. Delegates to [`predict`](/api/dependencyparser#predict),
224+
[`get_loss`](/api/dependencyparser#get_loss) and
225+
[`set_annotations`](/api/dependencyparser#set_annotations).
225226
226227
> #### Example
227228
>
@@ -236,7 +237,6 @@ model. Delegates to [`predict`](/api/dependencyparser#predict) and
236237
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
237238
| _keyword-only_ | |
238239
| `drop` | The dropout rate. ~~float~~ |
239-
| `set_annotations` | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](#set_annotations). ~~bool~~ |
240240
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
241241
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
242242
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |

website/docs/api/entitylinker.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,8 @@ entities.
237237
238238
Learn from a batch of [`Example`](/api/example) objects, updating both the
239239
pipe's entity linking model and context encoder. Delegates to
240-
[`predict`](/api/entitylinker#predict).
240+
[`predict`](/api/entitylinker#predict) and
241+
[`set_annotations`](/api/entitylinker#set_annotations).
241242
242243
> #### Example
243244
>
@@ -252,7 +253,6 @@ pipe's entity linking model and context encoder. Delegates to
252253
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
253254
| _keyword-only_ | |
254255
| `drop` | The dropout rate. ~~float~~ |
255-
| `set_annotations` | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](#set_annotations). ~~bool~~ |
256256
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
257257
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
258258
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |

website/docs/api/entityrecognizer.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,9 @@ Modify a batch of [`Doc`](/api/doc) objects, using pre-computed scores.
209209
## EntityRecognizer.update {#update tag="method"}
210210
211211
Learn from a batch of [`Example`](/api/example) objects, updating the pipe's
212-
model. Delegates to [`predict`](/api/entityrecognizer#predict) and
213-
[`get_loss`](/api/entityrecognizer#get_loss).
212+
model. Delegates to [`predict`](/api/entityrecognizer#predict),
213+
[`get_loss`](/api/entityrecognizer#get_loss) and
214+
[`set_annotations`](/api/entityrecognizer#set_annotations).
214215
215216
> #### Example
216217
>
@@ -225,7 +226,6 @@ model. Delegates to [`predict`](/api/entityrecognizer#predict) and
225226
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
226227
| _keyword-only_ | |
227228
| `drop` | The dropout rate. ~~float~~ |
228-
| `set_annotations` | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](#set_annotations). ~~bool~~ |
229229
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
230230
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
231231
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |

website/docs/api/morphologizer.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,9 @@ Modify a batch of [`Doc`](/api/doc) objects, using pre-computed scores.
189189
190190
Learn from a batch of [`Example`](/api/example) objects containing the
191191
predictions and gold-standard annotations, and update the component's model.
192-
Delegates to [`predict`](/api/morphologizer#predict) and
193-
[`get_loss`](/api/morphologizer#get_loss).
192+
Delegates to [`predict`](/api/morphologizer#predict),
193+
[`get_loss`](/api/morphologizer#get_loss) and
194+
[`set_annotations`](/api/morphologizer#set_annotations).
194195
195196
> #### Example
196197
>
@@ -205,7 +206,6 @@ Delegates to [`predict`](/api/morphologizer#predict) and
205206
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
206207
| _keyword-only_ | |
207208
| `drop` | The dropout rate. ~~float~~ |
208-
| `set_annotations` | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](#set_annotations). ~~bool~~ |
209209
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
210210
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
211211
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |

website/docs/api/multilabel_textcategorizer.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,9 @@ Modify a batch of [`Doc`](/api/doc) objects using pre-computed scores.
199199
200200
Learn from a batch of [`Example`](/api/example) objects containing the
201201
predictions and gold-standard annotations, and update the component's model.
202-
Delegates to [`predict`](/api/multilabel_textcategorizer#predict) and
203-
[`get_loss`](/api/multilabel_textcategorizer#get_loss).
202+
Delegates to [`predict`](/api/multilabel_textcategorizer#predict),
203+
[`get_loss`](/api/multilabel_textcategorizer#get_loss) and
204+
[`set_annotations`](/api/multilabel_textcategorizer#set_annotations).
204205
205206
> #### Example
206207
>
@@ -215,7 +216,6 @@ Delegates to [`predict`](/api/multilabel_textcategorizer#predict) and
215216
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
216217
| _keyword-only_ | |
217218
| `drop` | The dropout rate. ~~float~~ |
218-
| `set_annotations` | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](#set_annotations). ~~bool~~ |
219219
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
220220
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
221221
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |

website/docs/api/pipe.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ predictions and gold-standard annotations, and update the component's model.
195195
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
196196
| _keyword-only_ | |
197197
| `drop` | The dropout rate. ~~float~~ |
198-
| `set_annotations` | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](#set_annotations). ~~bool~~ |
199198
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
200199
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
201200
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |

0 commit comments

Comments (0)