Skip to content

Commit 2eea97d

Browse files
authored
Update vision_transformer.py
1 parent ef72c3c commit 2eea97d

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

timm/models/vision_transformer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -682,9 +682,9 @@ def forward_head(self, x, pre_logits: bool = False):
682682
return x if pre_logits else self.head(x)
683683

684684
def forward(self, x):
685-
x = self.forward_features(x)
686-
x = self.forward_head(x)
687-
return x
685+
tokens = self.forward_features(x)
686+
cls_token = self.forward_head(tokens)
687+
return cls_token, tokens
688688

689689

690690
def init_weights_vit_timm(module: nn.Module, name: str = ''):

0 commit comments

Comments
 (0)