
Commit 0cc5ab1

julien-c and sgugger authored

Improve bert-japanese tokenizer handling (huggingface#8659)

* Make CI fail
* Try to make tests actually run?
* CI finally failing?
* Fix CI
* Revert "Fix CI" (this reverts commit ca7923b)
* Oops, wrong one
* One more try
* OK, let's move this elsewhere
* Alternative to globals() (huggingface#8667)
  * Alternative to globals()
  * Error is raised later, so return None
  * Sentencepiece not installed makes some tokenizers None
  * Apply Lysandre wisdom
  * Slightly clearer comment? cc @sgugger

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

1 parent eec7661

3 files changed, +36 -9 lines

.circleci/config.yml

Lines changed: 1 addition & 1 deletion

@@ -221,7 +221,7 @@ jobs:
     run_tests_custom_tokenizers:
         working_directory: ~/transformers
         docker:
-            - image: circleci/python:3.6
+            - image: circleci/python:3.7
         environment:
             RUN_CUSTOM_TOKENIZERS: yes
         steps:

src/transformers/models/auto/tokenization_auto.py

Lines changed: 26 additions & 8 deletions

@@ -185,8 +185,6 @@
         (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
         (BartConfig, (BartTokenizer, BartTokenizerFast)),
         (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
-        (RobertaConfig, (BertweetTokenizer, None)),
-        (RobertaConfig, (PhobertTokenizer, None)),
         (RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
         (ReformerConfig, (ReformerTokenizer, ReformerTokenizerFast)),
         (ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
@@ -195,7 +193,6 @@
         (LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)),
         (DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)),
         (SqueezeBertConfig, (SqueezeBertTokenizer, SqueezeBertTokenizerFast)),
-        (BertConfig, (HerbertTokenizer, HerbertTokenizerFast)),
         (BertConfig, (BertTokenizer, BertTokenizerFast)),
         (OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
         (GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
@@ -213,13 +210,34 @@
     ]
 )

+# For tokenizers which are not directly mapped from a config
+NO_CONFIG_TOKENIZER = [
+    BertJapaneseTokenizer,
+    BertweetTokenizer,
+    HerbertTokenizer,
+    HerbertTokenizerFast,
+    PhobertTokenizer,
+]
+
+
 SLOW_TOKENIZER_MAPPING = {
     k: (v[0] if v[0] is not None else v[1])
     for k, v in TOKENIZER_MAPPING.items()
     if (v[0] is not None or v[1] is not None)
 }


+def tokenizer_class_from_name(class_name: str):
+    all_tokenizer_classes = (
+        [v[0] for v in TOKENIZER_MAPPING.values() if v[0] is not None]
+        + [v[1] for v in TOKENIZER_MAPPING.values() if v[1] is not None]
+        + NO_CONFIG_TOKENIZER
+    )
+    for c in all_tokenizer_classes:
+        if c.__name__ == class_name:
+            return c
+
+
 class AutoTokenizer:
     r"""
     This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
@@ -307,17 +325,17 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

-        if "bert-base-japanese" in str(pretrained_model_name_or_path):
-            return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-
         use_fast = kwargs.pop("use_fast", True)

         if config.tokenizer_class is not None:
+            tokenizer_class = None
             if use_fast and not config.tokenizer_class.endswith("Fast"):
                 tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
-            else:
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+            if tokenizer_class is None:
                 tokenizer_class_candidate = config.tokenizer_class
-                tokenizer_class = globals().get(tokenizer_class_candidate)
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+
             if tokenizer_class is None:
                 raise ValueError(
                     "Tokenizer class {} does not exist or is not currently imported.".format(tokenizer_class_candidate)

tests/test_tokenization_bert_japanese.py

Lines changed: 9 additions & 0 deletions

@@ -18,6 +18,7 @@
 import pickle
 import unittest

+from transformers import AutoTokenizer
 from transformers.models.bert_japanese.tokenization_bert_japanese import (
     VOCAB_FILES_NAMES,
     BertJapaneseTokenizer,
@@ -267,3 +268,11 @@ def test_sequence_builders(self):
         # 2 is for "[CLS]", 3 is for "[SEP]"
         assert encoded_sentence == [2] + text + [3]
         assert encoded_pair == [2] + text + [3] + text_2 + [3]
+
+
+@custom_tokenizers
+class AutoTokenizerCustomTest(unittest.TestCase):
+    def test_tokenizer_bert_japanese(self):
+        EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
+        tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
+        self.assertIsInstance(tokenizer, BertJapaneseTokenizer)
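
Outside the test harness, the new test boils down to the snippet below. This is an illustrative sketch, assuming the extra Japanese tokenization dependencies are installed; the custom_tokenizers decorator keeps the test out of the default CI run, which is why the CircleCI job above sets RUN_CUSTOM_TOKENIZERS: yes.

from transformers import AutoTokenizer
from transformers.models.bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer

# The checkpoint's config names its tokenizer class, so AutoTokenizer
# now resolves BertJapaneseTokenizer through tokenizer_class_from_name
# rather than the removed "bert-base-japanese" substring check.
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
assert isinstance(tokenizer, BertJapaneseTokenizer)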
