Commit 5787e4c

Various tokenizers fixes (huggingface#5558)
* BertTokenizerFast - Do not specify strip_accents by default
* Bump tokenizers to new version
* Add test for AddedToken serialization
1 parent 21f28c3 commit 5787e4c

File tree

4 files changed: +42 -25 lines changed

setup.py

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@
     packages=find_packages("src"),
     install_requires=[
         "numpy",
-        "tokenizers == 0.8.0-rc4",
+        "tokenizers == 0.8.1.rc1",
         # dataclasses for Python versions that don't have it
         "dataclasses;python_version<'3.7'",
         # utilities from PyPA to e.g. compare versions
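
The pin moves the Rust tokenizers backend from 0.8.0-rc4 to 0.8.1.rc1. A quick sketch of checking which backend version an environment actually has installed (nothing here beyond the package's standard __version__ attribute):

import tokenizers

# With the setup.py pin above, this is expected to report 0.8.1.rc1.
print(tokenizers.__version__)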

src/transformers/tokenization_bert.py

Lines changed: 1 addition & 1 deletion
@@ -606,7 +606,7 @@ def __init__(
         mask_token="[MASK]",
         clean_text=True,
         tokenize_chinese_chars=True,
-        strip_accents=True,
+        strip_accents=None,
         wordpieces_prefix="##",
         **kwargs
     ):
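
With strip_accents now defaulting to None, the fast BERT tokenizer no longer forces accent stripping; the backend can follow the checkpoint's lowercasing behaviour instead. A minimal sketch of leaving the default versus requesting the old behaviour explicitly (the checkpoint name and example word are illustrative only):

from transformers import BertTokenizerFast

# Default (strip_accents=None): accent handling follows the checkpoint's
# lowercasing configuration instead of always being forced on.
tokenizer_default = BertTokenizerFast.from_pretrained("bert-base-cased")

# The old behaviour can still be requested explicitly.
tokenizer_stripped = BertTokenizerFast.from_pretrained("bert-base-cased", strip_accents=True)

print(tokenizer_default.tokenize("naïve"))   # expected to keep the accent for a cased model
print(tokenizer_stripped.tokenize("naïve"))  # accent stripped before wordpiece splitting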

tests/test_tokenization_common.py

Lines changed: 7 additions & 0 deletions
@@ -24,6 +24,7 @@
 
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
 from transformers.testing_utils import require_tf, require_torch, slow
+from transformers.tokenization_utils import AddedToken
 
 
 if TYPE_CHECKING:
@@ -233,6 +234,12 @@ def test_pickle_tokenizer(self):
 
         self.assertListEqual(subwords, subwords_loaded)
 
+    def test_pickle_added_tokens(self):
+        tok1 = AddedToken("<s>", rstrip=True, lstrip=True, normalized=False, single_word=True)
+        tok2 = pickle.loads(pickle.dumps(tok1))
+
+        self.assertEqual(tok1.__getstate__(), tok2.__getstate__())
+
     def test_added_tokens_do_lower_case(self):
         # TODO(thom) activate fast tokenizer tests once Rust tokenizers accepts white spaces in added tokens
         tokenizers = self.get_tokenizers(fast=False, do_lower_case=True)
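
The new test_pickle_added_tokens test checks that an AddedToken survives pickling with all of its flags intact. A standalone sketch of the same round trip, outside the test harness:

import pickle

from transformers.tokenization_utils import AddedToken

# Build a token with non-default flags and round-trip it through pickle.
token = AddedToken("<s>", rstrip=True, lstrip=True, normalized=False, single_word=True)
restored = pickle.loads(pickle.dumps(token))

# As in the test, the serialized state should be identical after the round trip.
assert token.__getstate__() == restored.__getstate__()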

tests/test_tokenization_fast.py

Lines changed: 33 additions & 23 deletions
@@ -91,8 +91,6 @@ def fast_align_python(self, tokenizer_r, tokenizer_p, tok_case, pretrained_name)
         self.assert_padding(tokenizer_r, tokenizer_p)
         self.assert_create_token_type_ids(tokenizer_r, tokenizer_p)
         self.assert_prepare_for_model(tokenizer_r, tokenizer_p)
-        # TODO: enable for v3.0.0
-        # self.assert_empty_output_no_special_tokens(tokenizer_r, tokenizer_p)
 
     def fast_only(self, tokenizer_r):
         # Ensure None raise an error
@@ -748,29 +746,41 @@ def assert_offsets_with_special_characters(self, tokenizer_r):
             add_special_tokens=True,
         )
 
-        expected_results = [
-            ((0, 1), "A"),
-            ((1, 2), ","),
-            ((3, 8), "naive"),  # BERT normalizes this away
-            # Append MASK here after lower-casing
-            ((16, 21), "Allen"),
-            ((22, 24), "##NL"),
-            ((24, 25), "##P"),
-            ((26, 34), "sentence"),
-            ((35, 36), "."),
-        ]
-
-        # Check if the tokenizer is uncased
-        if tokenizer_r.init_kwargs.get("do_lower_case"):
-            expected_results = [(offset, token.lower()) for (offset, token) in expected_results]
-
-        # Append the special tokens
-        expected_results.insert(3, ((9, 15), "[MASK]"))
-        expected_results.insert(0, (None, "[CLS]"))
-        expected_results.append((None, "[SEP]"))
+        do_lower_case = tokenizer_r.init_kwargs.get("do_lower_case")
+        expected_results = (
+            [
+                ((0, 0), "[CLS]"),
+                ((0, 1), "A"),
+                ((1, 2), ","),
+                ((3, 5), "na"),
+                ((5, 6), "##ï"),
+                ((6, 8), "##ve"),
+                ((9, 15), "[MASK]"),
+                ((16, 21), "Allen"),
+                ((21, 23), "##NL"),
+                ((23, 24), "##P"),
+                ((25, 33), "sentence"),
+                ((33, 34), "."),
+                ((0, 0), "[SEP]"),
+            ]
+            if not do_lower_case
+            else [
+                ((0, 0), "[CLS]"),
+                ((0, 1), "a"),
+                ((1, 2), ","),
+                ((3, 8), "naive"),
+                ((9, 15), "[MASK]"),
+                ((16, 21), "allen"),
+                ((21, 23), "##nl"),
+                ((23, 24), "##p"),
+                ((25, 33), "sentence"),
+                ((33, 34), "."),
+                ((0, 0), "[SEP]"),
+            ]
+        )
 
         self.assertEqual([e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"]))
-        # self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
+        self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
 
 
 class RobertaFastTokenizerTest(CommonFastTokenizerTest):
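
The last change re-enables the offset check: with the updated tokenizers release, the fast tokenizer's offset_mapping is expected to line up with the token list above, with special tokens mapped to (0, 0). A rough sketch of inspecting those offsets, using a sentence consistent with the offsets asserted in the test (the checkpoint name is illustrative):

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

text = "A, naïve [MASK] AllenNLP sentence."
encoding = tokenizer(text, return_offsets_mapping=True, add_special_tokens=True)

tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"])
for token, (start, end) in zip(tokens, encoding["offset_mapping"]):
    # Special tokens such as [CLS] and [SEP] come back with a (0, 0) offset.
    print(token, (start, end), repr(text[start:end]))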
