@@ -185,8 +185,6 @@
         (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
         (BartConfig, (BartTokenizer, BartTokenizerFast)),
         (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
-        (RobertaConfig, (BertweetTokenizer, None)),
-        (RobertaConfig, (PhobertTokenizer, None)),
         (RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
         (ReformerConfig, (ReformerTokenizer, ReformerTokenizerFast)),
         (ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
@@ -195,7 +193,6 @@
         (LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)),
         (DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)),
         (SqueezeBertConfig, (SqueezeBertTokenizer, SqueezeBertTokenizerFast)),
-        (BertConfig, (HerbertTokenizer, HerbertTokenizerFast)),
         (BertConfig, (BertTokenizer, BertTokenizerFast)),
         (OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
         (GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
@@ -213,13 +210,34 @@
     ]
 )
 
+# For tokenizers which are not directly mapped from a config
+NO_CONFIG_TOKENIZER = [
+    BertJapaneseTokenizer,
+    BertweetTokenizer,
+    HerbertTokenizer,
+    HerbertTokenizerFast,
+    PhobertTokenizer,
+]
+
+
 SLOW_TOKENIZER_MAPPING = {
     k: (v[0] if v[0] is not None else v[1])
     for k, v in TOKENIZER_MAPPING.items()
     if (v[0] is not None or v[1] is not None)
 }
 
 
+def tokenizer_class_from_name(class_name: str):
+    all_tokenizer_classes = (
+        [v[0] for v in TOKENIZER_MAPPING.values() if v[0] is not None]
+        + [v[1] for v in TOKENIZER_MAPPING.values() if v[1] is not None]
+        + NO_CONFIG_TOKENIZER
+    )
+    for c in all_tokenizer_classes:
+        if c.__name__ == class_name:
+            return c
+
+
 class AutoTokenizer:
     r"""
     This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
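For context: the new tokenizer_class_from_name helper does a linear scan over every slow and fast tokenizer class registered in TOKENIZER_MAPPING, plus the config-less classes in NO_CONFIG_TOKENIZER, and returns the first class whose __name__ matches; an unknown name falls off the end of the loop and implicitly returns None. A minimal usage sketch, assuming the pre-4.0 flat module layout for the import (the import path is an assumption, not part of this commit):

# Minimal sketch; the import path assumes the pre-4.0 module layout.
from transformers.tokenization_auto import tokenizer_class_from_name

fast_cls = tokenizer_class_from_name("BertTokenizerFast")    # found via TOKENIZER_MAPPING
phobert_cls = tokenizer_class_from_name("PhobertTokenizer")  # found via NO_CONFIG_TOKENIZER
missing = tokenizer_class_from_name("NoSuchTokenizer")       # falls through, returns None
print(fast_cls.__name__, phobert_cls.__name__, missing)      # BertTokenizerFast PhobertTokenizer None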
@@ -307,17 +325,17 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
-        if "bert-base-japanese" in str(pretrained_model_name_or_path):
-            return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-
         use_fast = kwargs.pop("use_fast", True)
 
         if config.tokenizer_class is not None:
+            tokenizer_class = None
             if use_fast and not config.tokenizer_class.endswith("Fast"):
                 tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
-            else:
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+            if tokenizer_class is None:
                 tokenizer_class_candidate = config.tokenizer_class
-            tokenizer_class = globals().get(tokenizer_class_candidate)
+                tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
+
             if tokenizer_class is None:
                 raise ValueError(
                     "Tokenizer class {} does not exist or is not currently imported.".format(tokenizer_class_candidate)
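With globals().get(...) replaced by the explicit lookup, from_pretrained now resolves config.tokenizer_class in two steps: when use_fast=True (the default) it first tries the Fast variant of the configured name, then falls back to the exact name, and raises only if both lookups return None. A minimal sketch of that resolution order, assuming the same import path as above; the "PhobertTokenizer" value stands in for what config.tokenizer_class would carry:

# Minimal sketch of the new resolution order, assuming the helper is
# importable from transformers.tokenization_auto (an assumption).
from transformers.tokenization_auto import tokenizer_class_from_name

config_tokenizer_class = "PhobertTokenizer"  # stand-in for config.tokenizer_class
use_fast = True                              # the kwarg's default

tokenizer_class = None
if use_fast and not config_tokenizer_class.endswith("Fast"):
    # Try the fast variant first; Phobert has none, so this stays None.
    tokenizer_class = tokenizer_class_from_name(f"{config_tokenizer_class}Fast")
if tokenizer_class is None:
    # Fall back to the exact configured name, resolved via NO_CONFIG_TOKENIZER.
    tokenizer_class = tokenizer_class_from_name(config_tokenizer_class)
print(tokenizer_class.__name__)  # PhobertTokenizer, so no ValueError is raised

This lookup is also what replaces the deleted bert-base-japanese substring check: BertJapaneseTokenizer, like the Bertweet, Herbert, and Phobert classes that share a config class with another tokenizer, is now reachable by name through NO_CONFIG_TOKENIZER instead of a hard-coded special case.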