Skip to content

Commit 9964243

Browse files
authored
Make the Tagger neg_prefix configurable (explosion#9802)
1 parent b56b9e7 commit 9964243

File tree

1 file changed: 6 additions (+6) and 4 deletions (-4).

spacy/pipeline/tagger.pyx

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
4545
@Language.factory(
4646
"tagger",
4747
assigns=["token.tag"],
48-
default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}},
48+
default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, "neg_prefix": "!"},
4949
default_score_weights={"tag_acc": 1.0},
5050
)
5151
def make_tagger(
@@ -54,6 +54,7 @@ def make_tagger(
5454
model: Model,
5555
overwrite: bool,
5656
scorer: Optional[Callable],
57+
neg_prefix: str,
5758
):
5859
"""Construct a part-of-speech tagger component.
5960
@@ -62,7 +63,7 @@ def make_tagger(
6263
in size, and be normalized as probabilities (all scores between 0 and 1,
6364
with the rows summing to 1).
6465
"""
65-
return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer)
66+
return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix)
6667

6768

6869
def tagger_score(examples, **kwargs):
@@ -87,6 +88,7 @@ class Tagger(TrainablePipe):
8788
*,
8889
overwrite=BACKWARD_OVERWRITE,
8990
scorer=tagger_score,
91+
neg_prefix="!",
9092
):
9193
"""Initialize a part-of-speech tagger.
9294
@@ -103,7 +105,7 @@ class Tagger(TrainablePipe):
103105
self.model = model
104106
self.name = name
105107
self._rehearsal_model = None
106-
cfg = {"labels": [], "overwrite": overwrite}
108+
cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
107109
self.cfg = dict(sorted(cfg.items()))
108110
self.scorer = scorer
109111

@@ -253,7 +255,7 @@ class Tagger(TrainablePipe):
253255
DOCS: https://spacy.io/api/tagger#get_loss
254256
"""
255257
validate_examples(examples, "Tagger.get_loss")
256-
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix="!")
258+
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix=self.cfg["neg_prefix"])
257259
# Convert empty tag "" to missing value None so that both misaligned
258260
# tokens and tokens with missing annotation have the default missing
259261
# value None.

0 commit comments

Comments
 (0)