Skip to content

Commit 806d033

Browse files
committed
Checking for pair of inputs too.
1 parent c0fe313 commit 806d033

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

tests/test_tokenization_common.py

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -787,11 +787,28 @@ def test_maximum_encoding_length_pair_input(self):
787 787
self.assertEqual(len(output["input_ids"][0]), model_max_length)
788 788

789 789
# Simple with no truncation
790-
output = tokenizer(seq_1, seq_2, padding=padding_state, truncation=False)
791-
self.assertNotEqual(len(output["input_ids"]), model_max_length)
790+
# Reset warnings
791+
tokenizer.deprecation_warnings = {}
792+
with self.assertLogs("transformers", level="WARNING") as cm:
793+
output = tokenizer(seq_1, seq_2, padding=padding_state, truncation=False)
794+
self.assertNotEqual(len(output["input_ids"]), model_max_length)
795+
self.assertEqual(len(cm.records), 1)
796+
self.assertTrue(
797+
cm.records[0].message.startswith(
798+
"Token indices sequence length is longer than the specified maximum sequence length for this model"
799+
)
800+
)
792 801

793-
output = tokenizer([seq_1], [seq_2], padding=padding_state, truncation=False)
794-
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
802+
tokenizer.deprecation_warnings = {}
803+
with self.assertLogs("transformers", level="WARNING") as cm:
804+
output = tokenizer([seq_1], [seq_2], padding=padding_state, truncation=False)
805+
self.assertNotEqual(len(output["input_ids"][0]), model_max_length)
806+
self.assertEqual(len(cm.records), 1)
807+
self.assertTrue(
808+
cm.records[0].message.startswith(
809+
"Token indices sequence length is longer than the specified maximum sequence length for this model"
810+
)
811+
)
795 812

796 813
truncated_first_sequence = tokenizer.encode(seq_0, add_special_tokens=False)[:-2] + tokenizer.encode(
797 814
seq_1, add_special_tokens=False

0 commit comments

Comments
 (0)