Skip to content

Commit 89be34c

Browse files
NielsRogge and Niels Rogge authored
Fix SegformerForImageClassification (huggingface#15895)
* Fix reshape
* Apply suggestion from code review

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
1 parent 130b987 commit 89be34c

File tree

1 file changed

+14 -2 lines changed

src/transformers/models/segformer/modeling_segformer.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -579,8 +579,11 @@ def forward(
579579

580580
sequence_output = outputs[0]
581581

582-
# reshape last hidden states to (batch_size, height*width, hidden_size)
582+
# convert last hidden states to (batch_size, height*width, hidden_size)
583583
batch_size = sequence_output.shape[0]
584+
if self.config.reshape_last_stage:
585+
# (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
586+
sequence_output = sequence_output.permute(0, 2, 3, 1)
584587
sequence_output = sequence_output.reshape(batch_size, -1, self.config.hidden_sizes[-1])
585588

586589
# global average pooling
@@ -660,10 +663,19 @@ def __init__(self, config):
660663
self.dropout = nn.Dropout(config.classifier_dropout_prob)
661664
self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)
662665

666+
self.config = config
667+
663668
def forward(self, encoder_hidden_states):
664-
batch_size, _, _, _ = encoder_hidden_states[-1].shape
669+
batch_size = encoder_hidden_states[-1].shape[0]
670+
665671
all_hidden_states = ()
666672
for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c):
673+
if self.config.reshape_last_stage is False and encoder_hidden_state.ndim == 3:
674+
height = width = int(math.sqrt(encoder_hidden_state.shape[-1]))
675+
encoder_hidden_state = (
676+
encoder_hidden_state.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
677+
)
678+
667679
# unify channel dimension
668680
height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
669681
encoder_hidden_state = mlp(encoder_hidden_state)

0 commit comments

Comments
 (0)