@@ -80,15 +80,13 @@ def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=N
         (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

         # prepare image for attention
-        img_modulated = self.img_norm1(img)
-        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
         img_qkv = self.img_attn.qkv(img_modulated)
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
-        txt_modulated = self.txt_norm1(txt)
-        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
         txt_qkv = self.txt_attn.qkv(txt_modulated)
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -102,12 +100,12 @@ def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=N
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

         # calculate the img bloks
-        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+        img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
+        img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))

         # calculate the txt bloks
-        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
+        txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))

         if txt.dtype == torch.float16:
             txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -152,7 +150,7 @@ def __init__(

     def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
         mod = vec
-        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+        x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -162,7 +160,7 @@ def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
         attn = attention(q, k, v, pe=pe, mask=attn_mask)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x += mod.gate * output
+        x.addcmul_(mod.gate, output)
         if x.dtype == torch.float16:
             x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
         return x
@@ -178,6 +176,6 @@ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
         shift, scale = vec
         shift = shift.squeeze(1)
         scale = scale.squeeze(1)
-        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+        x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
         x = self.linear(x)
         return x
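
For reference, a minimal sketch (not part of the diff, with made-up shapes) of the identity this change relies on: `torch.addcmul(input, tensor1, tensor2)` computes `input + tensor1 * tensor2` in one fused call, and the in-place `Tensor.addcmul_(tensor1, tensor2)` accumulates `tensor1 * tensor2` into the tensor it is called on, so the rewritten modulation and gating expressions produce the same values while avoiding extra temporaries.

```python
import torch

# Hypothetical shapes, just to check the rewrite is value-preserving.
shift, scale, gate = torch.randn(3, 1, 1, 64).unbind(0)  # stand-ins for mod.shift / mod.scale / mod.gate
x = torch.randn(2, 16, 64)                               # stand-in for the img/txt token stream
y = torch.randn(2, 16, 64)                               # stand-in for an attention/MLP branch output

# out-of-place form: (1 + scale) * x + shift  ==  torch.addcmul(shift, 1 + scale, x)
assert torch.allclose((1 + scale) * x + shift,
                      torch.addcmul(shift, 1 + scale, x))

# in-place form: x = x + gate * y  ==  x.addcmul_(gate, y)
ref = x + gate * y
x.addcmul_(gate, y)
assert torch.allclose(x, ref)
```

Since `addcmul_` mutates its tensor in place, the substitution for `x = x + gate * y` is only safe when nothing else still needs the pre-update value of `x`.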