Skip to content

Commit 4362ee2

Browse files
1 parent 4046e66 commit 4362ee2

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

tests/test_modeling_rag.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,9 @@ def test_rag_token_generate_batch(self):
988988
torch_device
989989
)
990990

991+
if torch_device == "cuda":
992+
rag_token.half()
993+
991994
input_dict = tokenizer(
992995
self.test_data_questions,
993996
return_tensors="pt",

0 commit comments

Comments
 (0)