@@ -355,7 +355,7 @@ def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
355
355
keep_ents .append (eidx )
356
356
357
357
eidx += 1
358
- entity_encodings = self .model .ops .asarray (entity_encodings , dtype = "float32" )
358
+ entity_encodings = self .model .ops .asarray2f (entity_encodings , dtype = "float32" )
359
359
selected_encodings = sentence_encodings [keep_ents ]
360
360
361
361
# if there are no matches, short circuit
@@ -368,13 +368,12 @@ def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
368
368
method = "get_loss" , msg = "gold entities do not match up"
369
369
)
370
370
raise RuntimeError (err )
371
- # TODO: fix typing issue here
372
- gradients = self .distance .get_grad (selected_encodings , entity_encodings ) # type: ignore
371
+ gradients = self .distance .get_grad (selected_encodings , entity_encodings )
373
372
# to match the input size, we need to give a zero gradient for items not in the kb
374
373
out = self .model .ops .alloc2f (* sentence_encodings .shape )
375
374
out [keep_ents ] = gradients
376
375
377
- loss = self .distance .get_loss (selected_encodings , entity_encodings ) # type: ignore
376
+ loss = self .distance .get_loss (selected_encodings , entity_encodings )
378
377
loss = loss / len (entity_encodings )
379
378
return float (loss ), out
380
379
@@ -391,74 +390,75 @@ def predict(self, docs: Iterable[Doc]) -> List[str]:
391
390
self .validate_kb ()
392
391
entity_count = 0
393
392
final_kb_ids : List [str ] = []
393
+ xp = self .model .ops .xp
394
394
if not docs :
395
395
return final_kb_ids
396
396
if isinstance (docs , Doc ):
397
397
docs = [docs ]
398
398
for i , doc in enumerate (docs ):
399
+ if len (doc ) == 0 :
400
+ continue
399
401
sentences = [s for s in doc .sents ]
400
- if len ( doc ) > 0 :
401
- # Looping through each entity (TODO: rewrite)
402
- for ent in doc . ents :
403
- sent = ent . sent
404
- sent_index = sentences . index ( sent )
405
- assert sent_index >= 0
402
+ # Looping through each entity (TODO: rewrite)
403
+ for ent in doc . ents :
404
+ sent_index = sentences . index ( ent . sent )
405
+ assert sent_index >= 0
406
+
407
+ if self . incl_context :
406
408
# get n_neighbour sentences, clipped to the length of the document
407
409
start_sentence = max (0 , sent_index - self .n_sents )
408
410
end_sentence = min (len (sentences ) - 1 , sent_index + self .n_sents )
409
411
start_token = sentences [start_sentence ].start
410
412
end_token = sentences [end_sentence ].end
411
413
sent_doc = doc [start_token :end_token ].as_doc ()
412
414
# currently, the context is the same for each entity in a sentence (should be refined)
413
- xp = self .model .ops .xp
414
- if self .incl_context :
415
- sentence_encoding = self .model .predict ([sent_doc ])[0 ]
416
- sentence_encoding_t = sentence_encoding .T
417
- sentence_norm = xp .linalg .norm (sentence_encoding_t )
418
- entity_count += 1
419
- if ent .label_ in self .labels_discard :
420
- # ignoring this entity - setting to NIL
415
+ sentence_encoding = self .model .predict ([sent_doc ])[0 ]
416
+ sentence_encoding_t = sentence_encoding .T
417
+ sentence_norm = xp .linalg .norm (sentence_encoding_t )
418
+ entity_count += 1
419
+ if ent .label_ in self .labels_discard :
420
+ # ignoring this entity - setting to NIL
421
+ final_kb_ids .append (self .NIL )
422
+ else :
423
+ candidates = list (self .get_candidates (self .kb , ent ))
424
+ if not candidates :
425
+ # no prediction possible for this entity - setting to NIL
421
426
final_kb_ids .append (self .NIL )
427
+ elif len (candidates ) == 1 :
428
+ # shortcut for efficiency reasons: take the 1 candidate
429
+ # TODO: thresholding
430
+ final_kb_ids .append (candidates [0 ].entity_ )
422
431
else :
423
- candidates = list (self .get_candidates (self .kb , ent ))
424
- if not candidates :
425
- # no prediction possible for this entity - setting to NIL
426
- final_kb_ids .append (self .NIL )
427
- elif len (candidates ) == 1 :
428
- # shortcut for efficiency reasons: take the 1 candidate
429
- # TODO: thresholding
430
- final_kb_ids .append (candidates [0 ].entity_ )
431
- else :
432
- random .shuffle (candidates )
433
- # set all prior probabilities to 0 if incl_prior=False
434
- prior_probs = xp .asarray ([c .prior_prob for c in candidates ])
435
- if not self .incl_prior :
436
- prior_probs = xp .asarray ([0.0 for _ in candidates ])
437
- scores = prior_probs
438
- # add in similarity from the context
439
- if self .incl_context :
440
- entity_encodings = xp .asarray (
441
- [c .entity_vector for c in candidates ]
442
- )
443
- entity_norm = xp .linalg .norm (entity_encodings , axis = 1 )
444
- if len (entity_encodings ) != len (prior_probs ):
445
- raise RuntimeError (
446
- Errors .E147 .format (
447
- method = "predict" ,
448
- msg = "vectors not of equal length" ,
449
- )
432
+ random .shuffle (candidates )
433
+ # set all prior probabilities to 0 if incl_prior=False
434
+ prior_probs = xp .asarray ([c .prior_prob for c in candidates ])
435
+ if not self .incl_prior :
436
+ prior_probs = xp .asarray ([0.0 for _ in candidates ])
437
+ scores = prior_probs
438
+ # add in similarity from the context
439
+ if self .incl_context :
440
+ entity_encodings = xp .asarray (
441
+ [c .entity_vector for c in candidates ]
442
+ )
443
+ entity_norm = xp .linalg .norm (entity_encodings , axis = 1 )
444
+ if len (entity_encodings ) != len (prior_probs ):
445
+ raise RuntimeError (
446
+ Errors .E147 .format (
447
+ method = "predict" ,
448
+ msg = "vectors not of equal length" ,
450
449
)
451
- # cosine similarity
452
- sims = xp .dot (entity_encodings , sentence_encoding_t ) / (
453
- sentence_norm * entity_norm
454
450
)
455
- if sims .shape != prior_probs .shape :
456
- raise ValueError (Errors .E161 )
457
- scores = prior_probs + sims - (prior_probs * sims )
458
- # TODO: thresholding
459
- best_index = scores .argmax ().item ()
460
- best_candidate = candidates [best_index ]
461
- final_kb_ids .append (best_candidate .entity_ )
451
+ # cosine similarity
452
+ sims = xp .dot (entity_encodings , sentence_encoding_t ) / (
453
+ sentence_norm * entity_norm
454
+ )
455
+ if sims .shape != prior_probs .shape :
456
+ raise ValueError (Errors .E161 )
457
+ scores = prior_probs + sims - (prior_probs * sims )
458
+ # TODO: thresholding
459
+ best_index = scores .argmax ().item ()
460
+ best_candidate = candidates [best_index ]
461
+ final_kb_ids .append (best_candidate .entity_ )
462
462
if not (len (final_kb_ids ) == entity_count ):
463
463
err = Errors .E147 .format (
464
464
method = "predict" , msg = "result variables not of equal length"
0 commit comments