Skip to content

Commit e29919f

Browse files
authored
[GPT Big Code] Fix attention scaling (#40041)
* fix
* update integration tests
* fmt
* add regression test
1 parent eca7030 commit e29919f

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
144144
)
145145

146146
self.scale_attn_weights = config.scale_attn_weights
147-
self.scaling = self.head_dim**0.5 if config.scale_attn_weights else 1.0
147+
self.scaling = self.head_dim**-0.5 if config.scale_attn_weights else 1.0
148148
self.is_cross_attention = is_cross_attention
149149

150150
self.layer_idx = layer_idx

tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import torch
3030

3131
from transformers import (
32+
AutoTokenizer,
3233
GPT2TokenizerFast,
3334
GPTBigCodeForCausalLM,
3435
GPTBigCodeForSequenceClassification,
@@ -510,7 +511,7 @@ def test_generate_simple(self):
510511
output_sequence = model.generate(input_ids)
511512
output_sentence = tokenizer.decode(output_sequence[0], skip_special_tokens=True)
512513

513-
expected_output = """def print_hello_world():\n print("Hello World!")\n\n\ndef print_hello_"""
514+
expected_output = 'def print_hello_world():\n print("Hello World!")\n\n\ndef print_hello_world_with_args(name' # fmt: skip
514515
self.assertEqual(output_sentence, expected_output)
515516

516517
def test_generate_batched(self):
@@ -527,11 +528,27 @@ def test_generate_batched(self):
527528
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
528529

529530
expected_output = [
530-
'def print_hello_world():\n print("Hello World!")\n\n\ndef print_hello_',
531-
'def say_hello():\n print("Hello, World!")\n\n\nsay_hello()',
531+
'def print_hello_world():\n print("Hello World!")\n\n\ndef print_hello_world_with_args(name',
532+
'def say_hello():\n print("Hello, World!")\n\n\nsay_hello()\n',
532533
]
533534
self.assertListEqual(outputs, expected_output)
534535

536+
def test_newline_regression(self):
537+
"""Added to prevent regressions regarding attention (scaling) indicated by excessive newlines"""
538+
tokenizer = AutoTokenizer.from_pretrained("bigcode/tiny_starcoder_py")
539+
model = GPTBigCodeForCausalLM.from_pretrained("bigcode/tiny_starcoder_py").to(torch_device)
540+
541+
input_ids = tokenizer(
542+
"Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.\n",
543+
return_tensors="pt",
544+
).input_ids.to(torch_device)
545+
546+
output_sequence = model.generate(input_ids, max_new_tokens=20, do_sample=False)
547+
output_sentence = tokenizer.decode(output_sequence[0], skip_special_tokens=True)
548+
549+
expected_output = 'Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.\n\nThe impact of the COVID-19 pandemic on global economic structures and future business' # fmt: skip
550+
self.assertEqual(output_sentence, expected_output)
551+
535552

536553
@require_torch
537554
class GPTBigCodeMQATest(unittest.TestCase):

0 commit comments

Comments (0)