from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import use_fused_attn
from .mlp import Mlp
from .weight_init import trunc_normal_tf_


class AttentionPoolLatent(nn.Module):
    """ Attention pooling w/ latent query
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            in_features: int,
            out_features: Optional[int] = None,
            embed_dim: Optional[int] = None,
            num_heads: int = 8,
            feat_size: Optional[int] = None,  # input sequence length, required when pos_embed == 'abs'
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            latent_len: int = 1,
            latent_dim: Optional[int] = None,
            pos_embed: str = '',
            pool_type: str = 'token',
            norm_layer: Optional[nn.Module] = None,
            drop: float = 0.0,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pool = pool_type
        self.fused_attn = use_fused_attn()

        if pos_embed == 'abs':
            # feat_size gives the spatial length of the absolute position embedding
            assert feat_size is not None
            self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
        else:
            self.pos_embed = None

        self.latent_dim = latent_dim or embed_dim
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))

        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(drop)

        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))

        self.init_weights()

    def init_weights(self):
        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
        trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)

    def forward(self, x):
        B, N, C = x.shape

        if self.pos_embed is not None:
            # FIXME interpolate
            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        # learned latent query cross-attends over the input tokens
        q_latent = self.latent.expand(B, -1, -1)
        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))

        # optional pool if latent seq_len > 1 and pooled output is desired
        if self.pool == 'token':
            x = x[:, 0]
        elif self.pool == 'avg':
            x = x.mean(1)
        return x
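

# Minimal usage sketch (illustrative, not part of the module above): pools a ViT-style
# token sequence of shape (B, N, C) down to one feature vector per sample using the
# default pool_type='token'. The sizes below (768-dim tokens, 12 heads, 196 tokens) are
# assumptions for demonstration only; run this within its package so the relative
# imports at the top of the file resolve.
if __name__ == '__main__':
    pool = AttentionPoolLatent(in_features=768, num_heads=12)
    tokens = torch.randn(2, 196, 768)  # (batch, num_tokens, channels), e.g. ViT-B/16 patch tokens
    pooled = pool(tokens)  # latent query cross-attends over the 196 tokens
    print(pooled.shape)  # torch.Size([2, 768])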