@@ -300,6 +300,9 @@ def check_encoder_decoder_model_output_attentions(
         labels,
         **kwargs
     ):
+        # make the decoder inputs a different shape from the encoder inputs to harden the test
+        decoder_input_ids = decoder_input_ids[:, :-1]
+        decoder_attention_mask = decoder_attention_mask[:, :-1]
         encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
         enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
         enc_dec_model.to(torch_device)
@@ -314,9 +317,8 @@ def check_encoder_decoder_model_output_attentions(
         encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
         self.assertEqual(len(encoder_attentions), config.num_hidden_layers)

-        self.assertListEqual(
-            list(encoder_attentions[0].shape[-3:]),
-            [config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]],
+        self.assertEqual(
+            encoder_attentions[0].shape[-3:], (config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1])
         )

         decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
@@ -327,20 +329,20 @@ def check_encoder_decoder_model_output_attentions(
         )
         self.assertEqual(len(decoder_attentions), num_decoder_layers)

-        self.assertListEqual(
-            list(decoder_attentions[0].shape[-3:]),
-            [decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]],
+        self.assertEqual(
+            decoder_attentions[0].shape[-3:],
+            (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
         )

         cross_attentions = outputs_encoder_decoder["cross_attentions"]
         self.assertEqual(len(cross_attentions), num_decoder_layers)

-        cross_attention_input_seq_len = input_ids.shape[-1] * (
+        cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
             1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
         )
-        self.assertListEqual(
-            list(cross_attentions[0].shape[-3:]),
-            [decoder_config.num_attention_heads, cross_attention_input_seq_len, decoder_input_ids.shape[-1]],
+        self.assertEqual(
+            cross_attentions[0].shape[-3:],
+            (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
         )

     def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
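For context (a minimal sketch outside the diff, not part of the change itself): cross-attention weights have shape (batch, num_heads, decoder_seq_len, encoder_seq_len), so the two length axes are only distinguishable once the encoder and decoder inputs have different lengths, which is what the truncation at the top of the diff guarantees. The tensor names and the head dimension of 8 below are made up for illustration.

import torch

batch, num_heads, head_dim = 2, 4, 8
enc_len, dec_len = 10, 10 - 1  # decoder inputs truncated by one token, as in the diff

# Decoder-side queries attend over encoder-side keys: the resulting
# cross-attention matrix is (batch, num_heads, dec_len, enc_len).
queries = torch.randn(batch, num_heads, dec_len, head_dim)
keys = torch.randn(batch, num_heads, enc_len, head_dim)
cross_attn = torch.softmax(queries @ keys.transpose(-2, -1), dim=-1)

# With enc_len == dec_len this check could not catch swapped axes;
# after truncation it can.
assert cross_attn.shape[-2:] == (dec_len, enc_len)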