File tree Expand file tree Collapse file tree 1 file changed +2
-3
lines changed Expand file tree Collapse file tree 1 file changed +2
-3
lines changed Original file line number Diff line number Diff line change @@ -418,15 +418,14 @@ def test_lm_generate_xglm_with_gradient_checkpointing(self):
418
418
def test_xglm_sample (self ):
419
419
tokenizer = XGLMTokenizer .from_pretrained ("facebook/xglm-564M" )
420
420
model = XGLMForCausalLM .from_pretrained ("facebook/xglm-564M" )
421
- model .to (torch_device )
422
421
423
422
torch .manual_seed (0 )
424
423
tokenized = tokenizer ("Today is a nice day and" , return_tensors = "pt" )
425
- input_ids = tokenized .input_ids . to ( torch_device )
424
+ input_ids = tokenized .input_ids
426
425
output_ids = model .generate (input_ids , do_sample = True , num_beams = 1 )
427
426
output_str = tokenizer .decode (output_ids [0 ], skip_special_tokens = True )
428
427
429
- EXPECTED_OUTPUT_STR = "Today is a nice day and I am happy to show you all about a recent project for my "
428
+ EXPECTED_OUTPUT_STR = "Today is a nice day and the sun is shining. A nice day with warm rainy "
430
429
self .assertEqual (output_str , EXPECTED_OUTPUT_STR )
431
430
432
431
@slow
You can’t perform that action at this time.
0 commit comments