Skip to content

Commit 130b987

Browse files
authored
[XGLM] run sampling test on CPU to be deterministic (huggingface#15892)
* run sampling test on CPU to be deterministic * input_ids on CPU
1 parent baab5e7 commit 130b987

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

tests/xglm/test_modeling_xglm.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -418,15 +418,14 @@ def test_lm_generate_xglm_with_gradient_checkpointing(self):
418418
def test_xglm_sample(self):
419419
tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
420420
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
421-
model.to(torch_device)
422421

423422
torch.manual_seed(0)
424423
tokenized = tokenizer("Today is a nice day and", return_tensors="pt")
425-
input_ids = tokenized.input_ids.to(torch_device)
424+
input_ids = tokenized.input_ids
426425
output_ids = model.generate(input_ids, do_sample=True, num_beams=1)
427426
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
428427

429-
EXPECTED_OUTPUT_STR = "Today is a nice day and I am happy to show you all about a recent project for my"
428+
EXPECTED_OUTPUT_STR = "Today is a nice day and the sun is shining. A nice day with warm rainy"
430429
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
431430

432431
@slow

0 commit comments

Comments
 (0)