Commit dca2e8c

Minor NEL type fixes (explosion#10860)

* Fix TODO about typing

  The fix was simple: just request an array2f.

* Add type ignore

  Maxout has a more restrictive type than the residual layer expects (only Floats2d vs any Floats).

* Various cleanup

  This moves a lot of lines around but doesn't change any functionality. Details:

  1. use `continue` to reduce indentation (see the sketch below)
  2. move sentence doc building inside the conditional, since it's otherwise unused
  3. reduce some temporary assignments
1 parent 56d4055 commit dca2e8c
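
Cleanup (1) is the guard-clause pattern visible in the `predict()` hunk further down. A minimal standalone sketch, with plain strings standing in for the real `Doc` objects:

```python
docs = ["", "some text"]  # toy stand-ins for spaCy Doc objects

for doc in docs:
    if len(doc) == 0:
        continue  # bail out early instead of wrapping the body in `if len(doc) > 0:`
    print(doc)  # stand-in for the real per-doc work
```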

File tree

2 files changed: +56 −56 lines changed


spacy/ml/models/entity_linker.py

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ def build_nel_encoder(
             ((tok2vec >> list2ragged()) & build_span_maker())
             >> extract_spans()
             >> reduce_mean()
-            >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0))
+            >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0))  # type: ignore
             >> output_layer
         )
         model.set_ref("output_layer", output_layer)
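
The ignore works around a checker-level mismatch, not a runtime bug: per the commit message, Maxout is typed to produce only Floats2d, while the residual layer is annotated for any Floats array. A minimal sketch of the combination in isolation (`token_width` is a made-up value here, not the real config setting):

```python
from thinc.api import Maxout, residual

token_width = 96  # hypothetical width, for illustration only

# Runs fine, but mypy flags Maxout's narrower Floats2d type against
# residual()'s broader annotation, hence the targeted ignore.
layer = residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0))  # type: ignore
```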

spacy/pipeline/entity_linker.py

Lines changed: 55 additions & 55 deletions

@@ -355,7 +355,7 @@ def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
                     keep_ents.append(eidx)

                 eidx += 1
-        entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
+        entity_encodings = self.model.ops.asarray2f(entity_encodings, dtype="float32")
         selected_encodings = sentence_encodings[keep_ents]

         # if there are no matches, short circuit
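
`asarray2f` is the shape-typed variant of `asarray`: the result is statically known to be a Floats2d rather than a general array, which is what resolves the typing TODO in the next hunk. A small sketch with made-up vectors:

```python
from thinc.api import get_current_ops

ops = get_current_ops()
vectors = [[0.1, 0.2], [0.3, 0.4]]  # hypothetical entity vectors

# asarray2f returns a Floats2d, so downstream code that requires a 2d
# float array now type-checks without any "# type: ignore".
entity_encodings = ops.asarray2f(vectors, dtype="float32")
print(entity_encodings.shape)  # (2, 2)
```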
@@ -368,13 +368,12 @@ def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
                 method="get_loss", msg="gold entities do not match up"
             )
             raise RuntimeError(err)
-        # TODO: fix typing issue here
-        gradients = self.distance.get_grad(selected_encodings, entity_encodings)  # type: ignore
+        gradients = self.distance.get_grad(selected_encodings, entity_encodings)
         # to match the input size, we need to give a zero gradient for items not in the kb
         out = self.model.ops.alloc2f(*sentence_encodings.shape)
         out[keep_ents] = gradients

-        loss = self.distance.get_loss(selected_encodings, entity_encodings)  # type: ignore
+        loss = self.distance.get_loss(selected_encodings, entity_encodings)
         loss = loss / len(entity_encodings)
         return float(loss), out

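Both calls go through `self.distance`, which in spaCy's entity linker is a Thinc `CosineDistance` loss; once `entity_encodings` is statically a Floats2d, both arguments match its expected types and the ignores can be dropped. A standalone sketch with toy arrays (the values and the `normalize=False` setting here mirror the pipeline but are assumptions for illustration):

```python
import numpy
from thinc.api import CosineDistance

distance = CosineDistance(normalize=False)  # assumed config, mirroring the pipeline
guesses = numpy.asarray([[1.0, 0.0], [0.0, 1.0]], dtype="float32")  # toy sentence encodings
truths = numpy.asarray([[1.0, 0.0], [1.0, 0.0]], dtype="float32")   # toy entity encodings

gradients = distance.get_grad(guesses, truths)  # same shape as guesses
loss = distance.get_loss(guesses, truths)       # scalar distance-based loss
```
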
@@ -391,74 +390,75 @@ def predict(self, docs: Iterable[Doc]) -> List[str]:
         self.validate_kb()
         entity_count = 0
         final_kb_ids: List[str] = []
+        xp = self.model.ops.xp
         if not docs:
             return final_kb_ids
         if isinstance(docs, Doc):
             docs = [docs]
         for i, doc in enumerate(docs):
+            if len(doc) == 0:
+                continue
             sentences = [s for s in doc.sents]
-            if len(doc) > 0:
-                # Looping through each entity (TODO: rewrite)
-                for ent in doc.ents:
-                    sent = ent.sent
-                    sent_index = sentences.index(sent)
-                    assert sent_index >= 0
+            # Looping through each entity (TODO: rewrite)
+            for ent in doc.ents:
+                sent_index = sentences.index(ent.sent)
+                assert sent_index >= 0
+
+                if self.incl_context:
                     # get n_neighbour sentences, clipped to the length of the document
                     start_sentence = max(0, sent_index - self.n_sents)
                     end_sentence = min(len(sentences) - 1, sent_index + self.n_sents)
                     start_token = sentences[start_sentence].start
                     end_token = sentences[end_sentence].end
                     sent_doc = doc[start_token:end_token].as_doc()
                     # currently, the context is the same for each entity in a sentence (should be refined)
-                    xp = self.model.ops.xp
-                    if self.incl_context:
-                        sentence_encoding = self.model.predict([sent_doc])[0]
-                        sentence_encoding_t = sentence_encoding.T
-                        sentence_norm = xp.linalg.norm(sentence_encoding_t)
-                    entity_count += 1
-                    if ent.label_ in self.labels_discard:
-                        # ignoring this entity - setting to NIL
+                    sentence_encoding = self.model.predict([sent_doc])[0]
+                    sentence_encoding_t = sentence_encoding.T
+                    sentence_norm = xp.linalg.norm(sentence_encoding_t)
+                entity_count += 1
+                if ent.label_ in self.labels_discard:
+                    # ignoring this entity - setting to NIL
+                    final_kb_ids.append(self.NIL)
+                else:
+                    candidates = list(self.get_candidates(self.kb, ent))
+                    if not candidates:
+                        # no prediction possible for this entity - setting to NIL
                         final_kb_ids.append(self.NIL)
+                    elif len(candidates) == 1:
+                        # shortcut for efficiency reasons: take the 1 candidate
+                        # TODO: thresholding
+                        final_kb_ids.append(candidates[0].entity_)
                     else:
-                        candidates = list(self.get_candidates(self.kb, ent))
-                        if not candidates:
-                            # no prediction possible for this entity - setting to NIL
-                            final_kb_ids.append(self.NIL)
-                        elif len(candidates) == 1:
-                            # shortcut for efficiency reasons: take the 1 candidate
-                            # TODO: thresholding
-                            final_kb_ids.append(candidates[0].entity_)
-                        else:
-                            random.shuffle(candidates)
-                            # set all prior probabilities to 0 if incl_prior=False
-                            prior_probs = xp.asarray([c.prior_prob for c in candidates])
-                            if not self.incl_prior:
-                                prior_probs = xp.asarray([0.0 for _ in candidates])
-                            scores = prior_probs
-                            # add in similarity from the context
-                            if self.incl_context:
-                                entity_encodings = xp.asarray(
-                                    [c.entity_vector for c in candidates]
-                                )
-                                entity_norm = xp.linalg.norm(entity_encodings, axis=1)
-                                if len(entity_encodings) != len(prior_probs):
-                                    raise RuntimeError(
-                                        Errors.E147.format(
-                                            method="predict",
-                                            msg="vectors not of equal length",
-                                        )
+                        random.shuffle(candidates)
+                        # set all prior probabilities to 0 if incl_prior=False
+                        prior_probs = xp.asarray([c.prior_prob for c in candidates])
+                        if not self.incl_prior:
+                            prior_probs = xp.asarray([0.0 for _ in candidates])
+                        scores = prior_probs
+                        # add in similarity from the context
+                        if self.incl_context:
+                            entity_encodings = xp.asarray(
+                                [c.entity_vector for c in candidates]
+                            )
+                            entity_norm = xp.linalg.norm(entity_encodings, axis=1)
+                            if len(entity_encodings) != len(prior_probs):
+                                raise RuntimeError(
+                                    Errors.E147.format(
+                                        method="predict",
+                                        msg="vectors not of equal length",
+                                    )
                                     )
-                                # cosine similarity
-                                sims = xp.dot(entity_encodings, sentence_encoding_t) / (
-                                    sentence_norm * entity_norm
                                 )
-                                if sims.shape != prior_probs.shape:
-                                    raise ValueError(Errors.E161)
-                                scores = prior_probs + sims - (prior_probs * sims)
-                                # TODO: thresholding
-                                best_index = scores.argmax().item()
-                                best_candidate = candidates[best_index]
-                                final_kb_ids.append(best_candidate.entity_)
+                            # cosine similarity
+                            sims = xp.dot(entity_encodings, sentence_encoding_t) / (
+                                sentence_norm * entity_norm
+                            )
+                            if sims.shape != prior_probs.shape:
+                                raise ValueError(Errors.E161)
+                            scores = prior_probs + sims - (prior_probs * sims)
+                            # TODO: thresholding
+                            best_index = scores.argmax().item()
+                            best_candidate = candidates[best_index]
+                            final_kb_ids.append(best_candidate.entity_)
         if not (len(final_kb_ids) == entity_count):
             err = Errors.E147.format(
                 method="predict", msg="result variables not of equal length"
