@@ -579,8 +579,11 @@ def forward(
 
         sequence_output = outputs[0]
 
-        # reshape last hidden states to (batch_size, height*width, hidden_size)
+        # convert last hidden states to (batch_size, height*width, hidden_size)
         batch_size = sequence_output.shape[0]
+        if self.config.reshape_last_stage:
+            # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
+            sequence_output = sequence_output.permute(0, 2, 3, 1)
         sequence_output = sequence_output.reshape(batch_size, -1, self.config.hidden_sizes[-1])
 
         # global average pooling
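For readers following along: the new guard exists because the encoder can now return its last hidden state either as a 4D feature map (when `reshape_last_stage` is true) or as an already-flattened sequence. A minimal sketch of the shape round trip, with illustrative sizes that are not taken from any particular checkpoint:

```python
import torch

# Illustrative sizes only: batch of 2, last-stage hidden size 256, 16x16 feature map.
batch_size, hidden_size, height, width = 2, 256, 16, 16

# With reshape_last_stage=True the encoder returns (batch_size, num_channels, height, width),
# so the head must move channels last before flattening the spatial dimensions.
sequence_output = torch.randn(batch_size, hidden_size, height, width)
sequence_output = sequence_output.permute(0, 2, 3, 1)  # -> (2, 16, 16, 256)
sequence_output = sequence_output.reshape(batch_size, -1, hidden_size)  # -> (2, 256, 256)

# With reshape_last_stage=False the tensor is already (batch_size, height*width, hidden_size),
# so the permute is skipped and the reshape leaves the shape unchanged.
```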
@@ -660,10 +663,19 @@ def __init__(self, config):
         self.dropout = nn.Dropout(config.classifier_dropout_prob)
         self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)
 
+        self.config = config
+
     def forward(self, encoder_hidden_states):
-        batch_size, _, _, _ = encoder_hidden_states[-1].shape
+        batch_size = encoder_hidden_states[-1].shape[0]
+
         all_hidden_states = ()
         for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c):
+            if self.config.reshape_last_stage is False and encoder_hidden_state.ndim == 3:
+                height = width = int(math.sqrt(encoder_hidden_state.shape[-1]))
+                encoder_hidden_state = (
+                    encoder_hidden_state.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
+                )
+
             # unify channel dimension
             height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
             encoder_hidden_state = mlp(encoder_hidden_state)
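The decode head handles the same concern from the other side: each `linear_c` MLP expects a 4D `(batch_size, num_channels, height, width)` input, so a last stage that was left flattened has to be unflattened first. A minimal sketch of that branch, under the assumption that the flattened sequence came from a square feature map (the sizes below are illustrative, chosen so that `seq_len == hidden_size`):

```python
import math

import torch

batch_size = 2
# Illustrative last-stage output left as (batch_size, seq_len, hidden_size);
# here seq_len == hidden_size == 256, corresponding to a 16x16 feature map.
encoder_hidden_state = torch.randn(batch_size, 256, 256)

if encoder_hidden_state.ndim == 3:
    height = width = int(math.sqrt(encoder_hidden_state.shape[-1]))
    # (batch_size, height*width, hidden_size) -> (batch_size, hidden_size, height, width);
    # .contiguous() keeps downstream reshapes valid after the permute.
    encoder_hidden_state = (
        encoder_hidden_state.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
    )

print(encoder_hidden_state.shape)  # torch.Size([2, 256, 16, 16])
```

After this normalization, the existing `shape[2]`/`shape[3]` indexing and the per-stage MLP apply uniformly to all four stages.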