Commit d6b0b9d

GPT2 tokenizer should not output token type IDs (huggingface#5546)

* GPT2 tokenizer should not output token type IDs
* Same for OpenAIGPT
1 parent 7833b21 commit d6b0b9d

File tree

2 files changed: +4 -0 lines changed


src/transformers/tokenization_gpt2.py

Lines changed: 2 additions & 0 deletions
@@ -137,6 +137,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = ["attention_mask"]

     def __init__(
         self,
@@ -330,6 +331,7 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = ["attention_mask"]

     def __init__(
         self,

src/transformers/tokenization_openai.py

Lines changed: 2 additions & 0 deletions
@@ -96,6 +96,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = ["attention_mask"]

     def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
         super().__init__(unk_token=unk_token, **kwargs)
@@ -261,6 +262,7 @@ class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast):
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = ["attention_mask"]

     def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
         kwargs.setdefault("unk_token", unk_token)
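
For context, a minimal usage sketch of the effect of this change (assuming a transformers version that includes this commit): because `model_input_names` now lists only `attention_mask`, encoding text with the GPT-2 tokenizer returns `input_ids` and `attention_mask` but no `token_type_ids` by default.

from transformers import GPT2Tokenizer

# Load the pretrained GPT-2 tokenizer (downloads vocab/merges files on first use).
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# With model_input_names = ["attention_mask"], token_type_ids are no longer
# returned by default; the encoding holds input_ids and attention_mask only.
encoded = tokenizer("Hello world")
print(list(encoded.keys()))   # expected: ['input_ids', 'attention_mask']
assert "token_type_ids" not in encoded

The same applies to OpenAIGPTTokenizer and the fast variants, since all four classes receive the same `model_input_names` attribute in this commit.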
