
Commit 18c8cf0

Author: Yossi Synett
Fix bug in x-attentions output for roberta and harden test to catch it (huggingface#8660)
1 parent 48cc224 commit 18c8cf0

File tree

2 files changed: +13 -11 lines changed


src/transformers/models/roberta/modeling_roberta.py

Lines changed: 1 addition & 1 deletion
@@ -814,7 +814,7 @@ def forward(
             logits=prediction_scores,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-            cross_attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
         )

     def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
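
The one-line fix above points the `cross_attentions` field of the RoBERTa decoder output at the decoder's actual cross-attention weights instead of a second copy of its self-attention weights. Below is a minimal sketch in plain PyTorch (not the Hugging Face modules themselves; all names are illustrative) of why the two outputs are only distinguishable by shape once the decoder sequence is a different length from the encoder sequence, which is exactly the situation the hardened test sets up.

import torch

batch, heads, head_dim = 2, 4, 16
src_len, tgt_len = 8, 7  # encoder and decoder lengths differ on purpose

decoder_states = torch.randn(batch, heads, tgt_len, head_dim)
encoder_states = torch.randn(batch, heads, src_len, head_dim)

# Decoder self-attention: queries and keys both come from the decoder sequence.
self_attn = torch.softmax(decoder_states @ decoder_states.transpose(-1, -2) / head_dim**0.5, dim=-1)
# Cross-attention: queries come from the decoder, keys from the encoder output.
cross_attn = torch.softmax(decoder_states @ encoder_states.transpose(-1, -2) / head_dim**0.5, dim=-1)

print(self_attn.shape)   # torch.Size([2, 4, 7, 7]) -- what `outputs.attentions` holds
print(cross_attn.shape)  # torch.Size([2, 4, 7, 8]) -- what `cross_attentions` should hold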

tests/test_modeling_encoder_decoder.py

Lines changed: 12 additions & 10 deletions
@@ -300,6 +300,9 @@ def check_encoder_decoder_model_output_attentions(
         labels,
         **kwargs
     ):
+        # make the decoder inputs a different shape from the encoder inputs to harden the test
+        decoder_input_ids = decoder_input_ids[:, :-1]
+        decoder_attention_mask = decoder_attention_mask[:, :-1]
         encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
         enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
         enc_dec_model.to(torch_device)
@@ -314,9 +317,8 @@ def check_encoder_decoder_model_output_attentions(
         encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
         self.assertEqual(len(encoder_attentions), config.num_hidden_layers)

-        self.assertListEqual(
-            list(encoder_attentions[0].shape[-3:]),
-            [config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]],
+        self.assertEqual(
+            encoder_attentions[0].shape[-3:], (config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1])
         )

         decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
@@ -327,20 +329,20 @@ def check_encoder_decoder_model_output_attentions(
         )
         self.assertEqual(len(decoder_attentions), num_decoder_layers)

-        self.assertListEqual(
-            list(decoder_attentions[0].shape[-3:]),
-            [decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]],
+        self.assertEqual(
+            decoder_attentions[0].shape[-3:],
+            (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
         )

         cross_attentions = outputs_encoder_decoder["cross_attentions"]
         self.assertEqual(len(cross_attentions), num_decoder_layers)

-        cross_attention_input_seq_len = input_ids.shape[-1] * (
+        cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
             1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
         )
-        self.assertListEqual(
-            list(cross_attentions[0].shape[-3:]),
-            [decoder_config.num_attention_heads, cross_attention_input_seq_len, decoder_input_ids.shape[-1]],
+        self.assertEqual(
+            cross_attentions[0].shape[-3:],
+            (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
         )

     def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
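
The test change does three things: it slices the decoder inputs one token shorter than the encoder inputs so the two sequence lengths can no longer coincide, it compares shape tuples directly with assertEqual instead of converting them to lists, and it asserts that the last axis of the cross-attention weights matches the encoder length rather than the decoder length. A hedged sketch of the shape argument (the variable names below are illustrative, not the test's own fixtures):

num_heads, src_len = 4, 8
tgt_len = src_len - 1      # decoder_input_ids[:, :-1] makes the decoder one token shorter
ngram = 0                  # decoder_config.ngram, only present for ProphetNet-style decoders

expected_cross_shape = (num_heads, tgt_len * (1 + ngram), src_len)  # what the test now asserts
buggy_cross_shape = (num_heads, tgt_len, tgt_len)                   # self-attention mislabeled as cross-attention

assert expected_cross_shape != buggy_cross_shape  # before the slicing, these two were identical

Before this change the encoder and decoder inputs shared a sequence length, so the buggy `cross_attentions=outputs.attentions` output happened to have the expected shape and the assertion passed silently.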
