Skip to content

Commit eaeca5e

Browse files
authored
account for NER labels with a hyphen in the name (explosion#10960)
* account for NER labels with a hyphen in the name * cleanup * fix docstring * add return type to helper method * shorter method and few more occurrences * use helper method across repo * fix circular import * partial revert to avoid circular import
1 parent 6313787 commit eaeca5e

File tree

11 files changed

+48
-21
lines changed

11 files changed

+48
-21
lines changed

spacy/cli/debug_data.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
1212
from ._util import import_code, debug_cli
13-
from ..training import Example
13+
from ..training import Example, remove_bilu_prefix
1414
from ..training.initialize import get_sourced_components
1515
from ..schemas import ConfigSchemaTraining
1616
from ..pipeline._parser_internals import nonproj
@@ -758,9 +758,9 @@ def _compile_gold(
758758
# "Illegal" whitespace entity
759759
data["ws_ents"] += 1
760760
if label.startswith(("B-", "U-")):
761-
combined_label = label.split("-")[1]
761+
combined_label = remove_bilu_prefix(label)
762762
data["ner"][combined_label] += 1
763-
if sent_starts[i] == True and label.startswith(("I-", "L-")):
763+
if sent_starts[i] and label.startswith(("I-", "L-")):
764764
data["boundary_cross_ents"] += 1
765765
elif label == "-":
766766
data["ner"]["-"] += 1
@@ -908,7 +908,7 @@ def _get_examples_without_label(
908908
for eg in data:
909909
if component == "ner":
910910
labels = [
911-
label.split("-")[1]
911+
remove_bilu_prefix(label)
912912
for label in eg.get_aligned_ner()
913913
if label not in ("O", "-", None)
914914
]

spacy/pipeline/_parser_internals/arc_eager.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ from ...strings cimport hash_string
1010
from ...structs cimport TokenC
1111
from ...tokens.doc cimport Doc, set_children_from_heads
1212
from ...tokens.token cimport MISSING_DEP
13+
from ...training import split_bilu_label
1314
from ...training.example cimport Example
1415
from .stateclass cimport StateClass
1516
from ._state cimport StateC, ArcC
@@ -687,7 +688,7 @@ cdef class ArcEager(TransitionSystem):
687688
return self.c[name_or_id]
688689
name = name_or_id
689690
if '-' in name:
690-
move_str, label_str = name.split('-', 1)
691+
move_str, label_str = split_bilu_label(name)
691692
label = self.strings[label_str]
692693
else:
693694
move_str = name

spacy/pipeline/_parser_internals/ner.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ from ...typedefs cimport weight_t, attr_t
1313
from ...lexeme cimport Lexeme
1414
from ...attrs cimport IS_SPACE
1515
from ...structs cimport TokenC, SpanC
16+
from ...training import split_bilu_label
1617
from ...training.example cimport Example
1718
from .stateclass cimport StateClass
1819
from ._state cimport StateC
@@ -182,7 +183,7 @@ cdef class BiluoPushDown(TransitionSystem):
182183
if name == '-' or name == '' or name is None:
183184
return Transition(clas=0, move=MISSING, label=0, score=0)
184185
elif '-' in name:
185-
move_str, label_str = name.split('-', 1)
186+
move_str, label_str = split_bilu_label(name)
186187
# Deprecated, hacky way to denote 'not this entity'
187188
if label_str.startswith('!'):
188189
raise ValueError(Errors.E869.format(label=name))

spacy/pipeline/dep_parser.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ from ..language import Language
1212
from ._parser_internals import nonproj
1313
from ._parser_internals.nonproj import DELIMITER
1414
from ..scorer import Scorer
15+
from ..training import remove_bilu_prefix
1516
from ..util import registry
1617

1718

@@ -314,7 +315,7 @@ cdef class DependencyParser(Parser):
314315
# Get the labels from the model by looking at the available moves
315316
for move in self.move_names:
316317
if "-" in move:
317-
label = move.split("-")[1]
318+
label = remove_bilu_prefix(move)
318319
if DELIMITER in label:
319320
label = label.split(DELIMITER)[1]
320321
labels.add(label)

spacy/pipeline/ner.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ from thinc.api import Model, Config
66
from ._parser_internals.transition_system import TransitionSystem
77
from .transition_parser cimport Parser
88
from ._parser_internals.ner cimport BiluoPushDown
9-
109
from ..language import Language
1110
from ..scorer import get_ner_prf, PRFScore
1211
from ..util import registry
12+
from ..training import remove_bilu_prefix
1313

1414

1515
default_model_config = """
@@ -242,7 +242,7 @@ cdef class EntityRecognizer(Parser):
242242
def labels(self):
243243
# Get the labels from the model by looking at the available moves, e.g.
244244
# B-PERSON, I-PERSON, L-PERSON, U-PERSON
245-
labels = set(move.split("-")[1] for move in self.move_names
245+
labels = set(remove_bilu_prefix(move) for move in self.move_names
246246
if move[0] in ("B", "I", "L", "U"))
247247
return tuple(sorted(labels))
248248

spacy/tests/parser/test_ner.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from spacy.language import Language
1111
from spacy.lookups import Lookups
1212
from spacy.pipeline._parser_internals.ner import BiluoPushDown
13-
from spacy.training import Example, iob_to_biluo
13+
from spacy.training import Example, iob_to_biluo, split_bilu_label
1414
from spacy.tokens import Doc, Span
1515
from spacy.vocab import Vocab
1616
import logging
@@ -110,6 +110,9 @@ def test_issue2385():
110110
# maintain support for iob2 format
111111
tags3 = ("B-PERSON", "I-PERSON", "B-PERSON")
112112
assert iob_to_biluo(tags3) == ["B-PERSON", "L-PERSON", "U-PERSON"]
113+
# ensure it works with hyphens in the name
114+
tags4 = ("B-MULTI-PERSON", "I-MULTI-PERSON", "B-MULTI-PERSON")
115+
assert iob_to_biluo(tags4) == ["B-MULTI-PERSON", "L-MULTI-PERSON", "U-MULTI-PERSON"]
113116

114117

115118
@pytest.mark.issue(2800)
@@ -154,6 +157,19 @@ def test_issue3209():
154157
assert ner2.move_names == move_names
155158

156159

160+
def test_labels_from_BILUO():
161+
"""Test that labels are inferred correctly when there's a - in label.
162+
"""
163+
nlp = English()
164+
ner = nlp.add_pipe("ner")
165+
ner.add_label("LARGE-ANIMAL")
166+
nlp.initialize()
167+
move_names = ["O", "B-LARGE-ANIMAL", "I-LARGE-ANIMAL", "L-LARGE-ANIMAL", "U-LARGE-ANIMAL"]
168+
labels = {"LARGE-ANIMAL"}
169+
assert ner.move_names == move_names
170+
assert set(ner.labels) == labels
171+
172+
157173
@pytest.mark.issue(4267)
158174
def test_issue4267():
159175
"""Test that running an entity_ruler after ner gives consistent results"""
@@ -298,7 +314,7 @@ def test_oracle_moves_missing_B(en_vocab):
298314
elif tag == "O":
299315
moves.add_action(move_types.index("O"), "")
300316
else:
301-
action, label = tag.split("-")
317+
action, label = split_bilu_label(tag)
302318
moves.add_action(move_types.index("B"), label)
303319
moves.add_action(move_types.index("I"), label)
304320
moves.add_action(move_types.index("L"), label)
@@ -324,7 +340,7 @@ def test_oracle_moves_whitespace(en_vocab):
324340
elif tag == "O":
325341
moves.add_action(move_types.index("O"), "")
326342
else:
327-
action, label = tag.split("-")
343+
action, label = split_bilu_label(tag)
328344
moves.add_action(move_types.index(action), label)
329345
moves.get_oracle_sequence(example)
330346

spacy/tests/util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from spacy.tokens import Doc
66
from spacy.vocab import Vocab
77
from spacy.util import make_tempdir # noqa: F401
8+
from spacy.training import split_bilu_label
89
from thinc.api import get_current_ops
910

1011

@@ -40,7 +41,7 @@ def apply_transition_sequence(parser, doc, sequence):
4041
desired state."""
4142
for action_name in sequence:
4243
if "-" in action_name:
43-
move, label = action_name.split("-")
44+
move, label = split_bilu_label(action_name)
4445
parser.add_label(label)
4546
with parser.step_through(doc) as stepwise:
4647
for transition in sequence:

spacy/training/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401
66
from .iob_utils import offsets_to_biluo_tags, biluo_tags_to_offsets # noqa: F401
77
from .iob_utils import biluo_tags_to_spans, tags_to_entities # noqa: F401
8+
from .iob_utils import split_bilu_label, remove_bilu_prefix # noqa: F401
89
from .gold_io import docs_to_json, read_json_file # noqa: F401
910
from .batchers import minibatch_by_padded_size, minibatch_by_words # noqa: F401
1011
from .loggers import console_logger # noqa: F401

spacy/training/augment.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import random
44
import itertools
55
from functools import partial
6-
from pydantic import BaseModel, StrictStr
76

87
from ..util import registry
98
from .example import Example
9+
from .iob_utils import split_bilu_label
1010

1111
if TYPE_CHECKING:
1212
from ..language import Language # noqa: F401
@@ -278,10 +278,8 @@ def make_whitespace_variant(
278278
ent_prev = doc_dict["entities"][position - 1]
279279
ent_next = doc_dict["entities"][position]
280280
if "-" in ent_prev and "-" in ent_next:
281-
ent_iob_prev = ent_prev.split("-")[0]
282-
ent_type_prev = ent_prev.split("-", 1)[1]
283-
ent_iob_next = ent_next.split("-")[0]
284-
ent_type_next = ent_next.split("-", 1)[1]
281+
ent_iob_prev, ent_type_prev = split_bilu_label(ent_prev)
282+
ent_iob_next, ent_type_next = split_bilu_label(ent_next)
285283
if (
286284
ent_iob_prev in ("B", "I")
287285
and ent_iob_next in ("I", "L")

spacy/training/example.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ from ..tokens.span import Span
99
from ..attrs import IDS
1010
from .alignment import Alignment
1111
from .iob_utils import biluo_to_iob, offsets_to_biluo_tags, doc_to_biluo_tags
12-
from .iob_utils import biluo_tags_to_spans
12+
from .iob_utils import biluo_tags_to_spans, remove_bilu_prefix
1313
from ..errors import Errors, Warnings
1414
from ..pipeline._parser_internals import nonproj
1515
from ..tokens.token cimport MISSING_DEP
@@ -519,7 +519,7 @@ def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces):
519519
else:
520520
ent_iobs.append(iob_tag.split("-")[0])
521521
if iob_tag.startswith("I") or iob_tag.startswith("B"):
522-
ent_types.append(iob_tag.split("-", 1)[1])
522+
ent_types.append(remove_bilu_prefix(iob_tag))
523523
else:
524524
ent_types.append("")
525525
return ent_iobs, ent_types

spacy/training/iob_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Dict, Tuple, Iterable, Union, Iterator
1+
from typing import List, Dict, Tuple, Iterable, Union, Iterator, cast
22
import warnings
33

44
from ..errors import Errors, Warnings
@@ -218,6 +218,14 @@ def tags_to_entities(tags: Iterable[str]) -> List[Tuple[str, int, int]]:
218218
return entities
219219

220220

221+
def split_bilu_label(label: str) -> Tuple[str, str]:
222+
return cast(Tuple[str, str], label.split("-", 1))
223+
224+
225+
def remove_bilu_prefix(label: str) -> str:
226+
return label.split("-", 1)[1]
227+
228+
221229
# Fallbacks to make backwards-compat easier
222230
offsets_from_biluo_tags = biluo_tags_to_offsets
223231
spans_from_biluo_tags = biluo_tags_to_spans

0 commit comments

Comments
 (0)