Skip to content

Commit 9d9b872

Browse files
authored
The add_space_before_punct_symbol is only for TransfoXL (huggingface#5549)
1 parent d6b0b9d commit 9d9b872

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

examples/text-generation/run_generation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,14 @@ def main():
214214
if requires_preprocessing:
215215
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
216216
preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
217+
218+
if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
219+
tokenizer_kwargs = {"add_space_before_punct_symbol": True}
220+
else:
221+
tokenizer_kwargs = {}
222+
217223
encoded_prompt = tokenizer.encode(
218-
preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", add_space_before_punct_symbol=True
224+
preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
219225
)
220226
else:
221227
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")

0 commit comments

Comments
 (0)