import torch

from transformers import (
+    AutoTokenizer,
    GPT2TokenizerFast,
    GPTBigCodeForCausalLM,
    GPTBigCodeForSequenceClassification,
@@ -510,7 +511,7 @@ def test_generate_simple(self):
        output_sequence = model.generate(input_ids)
        output_sentence = tokenizer.decode(output_sequence[0], skip_special_tokens=True)

-        expected_output = """def print_hello_world():\n    print("Hello World!")\n\n\ndef print_hello_"""
+        expected_output = 'def print_hello_world():\n    print("Hello World!")\n\n\ndef print_hello_world_with_args(name'  # fmt: skip
        self.assertEqual(output_sentence, expected_output)

    def test_generate_batched(self):
@@ -527,11 +528,27 @@ def test_generate_batched(self):
        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        expected_output = [
-            'def print_hello_world():\n    print("Hello World!")\n\n\ndef print_hello_',
-            'def say_hello():\n    print("Hello, World!")\n\n\nsay_hello()',
+            'def print_hello_world():\n    print("Hello World!")\n\n\ndef print_hello_world_with_args(name',
+            'def say_hello():\n    print("Hello, World!")\n\n\nsay_hello()\n',
        ]
        self.assertListEqual(outputs, expected_output)

+    def test_newline_regression(self):
+        """Added to prevent regressions in attention (scaling) that showed up as excessive newlines in the output."""
+        tokenizer = AutoTokenizer.from_pretrained("bigcode/tiny_starcoder_py")
+        model = GPTBigCodeForCausalLM.from_pretrained("bigcode/tiny_starcoder_py").to(torch_device)
+
+        input_ids = tokenizer(
+            "Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.\n",
+            return_tensors="pt",
+        ).input_ids.to(torch_device)
+
+        # Greedy decoding keeps the check deterministic; the regression manifested as runs of "\n" here.
+        output_sequence = model.generate(input_ids, max_new_tokens=20, do_sample=False)
+        output_sentence = tokenizer.decode(output_sequence[0], skip_special_tokens=True)
+
+        expected_output = 'Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.\n\nThe impact of the COVID-19 pandemic on global economic structures and future business'  # fmt: skip
+        self.assertEqual(output_sentence, expected_output)
+

@require_torch
class GPTBigCodeMQATest(unittest.TestCase):