|
5 | 5 | from spacy.vocab import Vocab
|
6 | 6 |
|
7 | 7 | from spacy import util, registry
|
| 8 | +from spacy.ml import load_kb |
8 | 9 | from spacy.scorer import Scorer
|
9 | 10 | from spacy.training import Example
|
10 | 11 | from spacy.lang.en import English
|
@@ -215,7 +216,7 @@ def create_kb(vocab):
|
215 | 216 | return kb
|
216 | 217 |
|
217 | 218 | # run an EL pipe without a trained context encoder, to check the candidate generation step only
|
218 |
| - entity_linker = nlp.add_pipe("entity_linker", config={"incl_context": False},) |
| 219 | + entity_linker = nlp.add_pipe("entity_linker", config={"incl_context": False}) |
219 | 220 | entity_linker.set_kb(create_kb)
|
220 | 221 | # With the default get_candidates function, matching is case-sensitive
|
221 | 222 | text = "Douglas and douglas are not the same."
|
@@ -496,6 +497,31 @@ def create_kb(vocab):
|
496 | 497 | assert predictions == GOLD_entities
|
497 | 498 |
|
498 | 499 |
|
| 500 | +def test_kb_serialization(): |
| 501 | + # Test that the KB can be used in a pipeline with a different vocab |
| 502 | + vector_length = 3 |
| 503 | + with make_tempdir() as tmp_dir: |
| 504 | + kb_dir = tmp_dir / "kb" |
| 505 | + nlp1 = English() |
| 506 | + assert "Q2146908" not in nlp1.vocab.strings |
| 507 | + mykb = KnowledgeBase(nlp1.vocab, entity_vector_length=vector_length) |
| 508 | + mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3]) |
| 509 | + mykb.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8]) |
| 510 | + assert "Q2146908" in nlp1.vocab.strings |
| 511 | + mykb.to_disk(kb_dir) |
| 512 | + |
| 513 | + nlp2 = English() |
| 514 | + nlp2.vocab.strings.add("RandomWord") |
| 515 | + assert "RandomWord" in nlp2.vocab.strings |
| 516 | + assert "Q2146908" not in nlp2.vocab.strings |
| 517 | + |
| 518 | + # Create the Entity Linker component with the KB from file, and check the final vocab |
| 519 | + entity_linker = nlp2.add_pipe("entity_linker", last=True) |
| 520 | + entity_linker.set_kb(load_kb(kb_dir)) |
| 521 | + assert "Q2146908" in nlp2.vocab.strings |
| 522 | + assert "RandomWord" in nlp2.vocab.strings |
| 523 | + |
| 524 | + |
499 | 525 | def test_scorer_links():
|
500 | 526 | train_examples = []
|
501 | 527 | nlp = English()
|
|
0 commit comments