
Commit 7375758

Fix tests (huggingface#14703)
1 parent 68e53e6 commit 7375758

2 files changed: +6 −5 lines changed

tests/test_modeling_perceiver.py

Lines changed: 3 additions & 2 deletions
@@ -860,7 +860,8 @@ def test_inference_masked_lm(self):
         self.assertEqual(logits.shape, expected_shape)
 
         expected_slice = torch.tensor(
-            [[-10.8609, -10.7651, -10.9187], [-12.1689, -11.9389, -12.1479], [-12.1518, -11.9707, -12.2073]]
+            [[-10.8609, -10.7651, -10.9187], [-12.1689, -11.9389, -12.1479], [-12.1518, -11.9707, -12.2073]],
+            device=torch_device,
         )
 
         self.assertTrue(torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-4))
@@ -970,7 +971,7 @@ def test_inference_optical_flow(self):
 
         # forward pass
         with torch.no_grad():
-            outputs = model(inputs=patches)
+            outputs = model(inputs=patches.to(torch_device))
         logits = outputs.logits
 
         # verify logits
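Why the change is needed, as a minimal standalone sketch (not taken from the test file): torch.allclose cannot compare a CUDA tensor against a CPU tensor, so when the model output lives on torch_device the expected values must be created on that device too. The tensor shapes and values below are placeholders.

import torch

# Stand-in for the test suite's torch_device (normally imported from
# transformers.testing_utils): "cuda" when a GPU is available, else "cpu".
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder for model output that already sits on torch_device.
logits = torch.zeros(1, 3, 3, device=torch_device)

# Without device=torch_device this tensor would stay on the CPU, and under
# a GPU run torch.allclose raises a device-mismatch RuntimeError.
expected_slice = torch.tensor(
    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
    device=torch_device,
)

assert torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-4)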

tests/test_tokenization_perceiver.py

Lines changed: 3 additions & 3 deletions
@@ -99,17 +99,17 @@ def test_multibytes_char(self):
 
         # decoding
         decoded = tokenizer.decode(encoded_ids)
-        self.assertEqual(decoded, "<cls>Unicode €.<sep>")
+        self.assertEqual(decoded, "[CLS]Unicode €.[SEP]")
 
         encoded = tokenizer("e è é ê ë")
         encoded_ids = [4, 107, 38, 201, 174, 38, 201, 175, 38, 201, 176, 38, 201, 177, 5]
         self.assertEqual(encoded["input_ids"], encoded_ids)
         # decoding
         decoded = tokenizer.decode(encoded_ids)
-        self.assertEqual(decoded, "<cls>e è é ê ë<sep>")
+        self.assertEqual(decoded, "[CLS]e è é ê ë[SEP]")
 
         # encode/decode, but with `encode` instead of `__call__`
-        self.assertEqual(tokenizer.decode(tokenizer.encode("e è é ê ë")), "<cls>e è é ê ë<sep>")
+        self.assertEqual(tokenizer.decode(tokenizer.encode("e è é ê ë")), "[CLS]e è é ê ë[SEP]")
 
     def test_prepare_batch_integration(self):
         tokenizer = self.perceiver_tokenizer
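The tokenization change only updates the expected strings: PerceiverTokenizer exposes its special tokens as the literal strings "[CLS]" and "[SEP]", so they appear in the decoded text. A rough usage sketch follows; the checkpoint name "deepmind/language-perceiver" is an assumption for illustration and is not part of this diff (the test constructs its own tokenizer).

from transformers import PerceiverTokenizer

# Checkpoint name is illustrative; any Perceiver language checkpoint with
# the default special tokens behaves the same way.
tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver")

ids = tokenizer.encode("e è é ê ë")  # encode() adds the special tokens
# The special tokens decode to the literal strings "[CLS]" and "[SEP]",
# hence the expected value "[CLS]e è é ê ë[SEP]" in the test above.
print(tokenizer.decode(ids))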
