Skip to content

Commit 68d7979

Browse files
committed
add test for vocab after serializing KB
1 parent 539b0c1 commit 68d7979

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

spacy/pipeline/trainable_pipe.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ cdef class TrainablePipe(Pipe):
243243
def _validate_serialization_attrs(self):
244244
"""Check that the pipe implements the required attributes. If a subclass
245245
implements a custom __init__ method but doesn't set these attributes,
246-
the currently default to None, so we need to perform additonal checks.
246+
they currently default to None, so we need to perform additional checks.
247247
"""
248248
if not hasattr(self, "vocab") or self.vocab is None:
249249
raise ValueError(Errors.E899.format(name=util.get_object_name(self)))

spacy/tests/pipeline/test_entity_linker.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from spacy.vocab import Vocab
66

77
from spacy import util, registry
8+
from spacy.ml import load_kb
89
from spacy.scorer import Scorer
910
from spacy.training import Example
1011
from spacy.lang.en import English
@@ -215,7 +216,7 @@ def create_kb(vocab):
215216
return kb
216217

217218
# run an EL pipe without a trained context encoder, to check the candidate generation step only
218-
entity_linker = nlp.add_pipe("entity_linker", config={"incl_context": False},)
219+
entity_linker = nlp.add_pipe("entity_linker", config={"incl_context": False})
219220
entity_linker.set_kb(create_kb)
220221
# With the default get_candidates function, matching is case-sensitive
221222
text = "Douglas and douglas are not the same."
@@ -496,6 +497,31 @@ def create_kb(vocab):
496497
assert predictions == GOLD_entities
497498

498499

500+
def test_kb_serialization():
501+
# Test that the KB can be used in a pipeline with a different vocab
502+
vector_length = 3
503+
with make_tempdir() as tmp_dir:
504+
kb_dir = tmp_dir / "kb"
505+
nlp1 = English()
506+
assert "Q2146908" not in nlp1.vocab.strings
507+
mykb = KnowledgeBase(nlp1.vocab, entity_vector_length=vector_length)
508+
mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
509+
mykb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
510+
assert "Q2146908" in nlp1.vocab.strings
511+
mykb.to_disk(kb_dir)
512+
513+
nlp2 = English()
514+
nlp2.vocab.strings.add("RandomWord")
515+
assert "RandomWord" in nlp2.vocab.strings
516+
assert "Q2146908" not in nlp2.vocab.strings
517+
518+
# Create the Entity Linker component with the KB from file, and check the final vocab
519+
entity_linker = nlp2.add_pipe("entity_linker", last=True)
520+
entity_linker.set_kb(load_kb(kb_dir))
521+
assert "Q2146908" in nlp2.vocab.strings
522+
assert "RandomWord" in nlp2.vocab.strings
523+
524+
499525
def test_scorer_links():
500526
train_examples = []
501527
nlp = English()

0 commit comments

Comments
 (0)