Commit 08b7cc7

use fused multiply-add pointwise ops in chroma (comfyanonymous#8279)
1 parent 6c319cb commit 08b7cc7

File tree

1 file changed: +9 / -11 lines


comfy/ldm/chroma/layers.py

Lines changed: 9 additions & 11 deletions
@@ -80,15 +80,13 @@ def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=N
         (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

         # prepare image for attention
-        img_modulated = self.img_norm1(img)
-        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
         img_qkv = self.img_attn.qkv(img_modulated)
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
-        txt_modulated = self.txt_norm1(txt)
-        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
         txt_qkv = self.txt_attn.qkv(txt_modulated)
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -102,12 +100,12 @@ def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=N
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

         # calculate the img bloks
-        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+        img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
+        img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))

         # calculate the txt bloks
-        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
+        txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))

         if txt.dtype == torch.float16:
             txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -152,7 +150,7 @@ def __init__(

     def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
         mod = vec
-        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+        x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -162,7 +160,7 @@ def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
         attn = attention(q, k, v, pe=pe, mask=attn_mask)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x += mod.gate * output
+        x.addcmul_(mod.gate, output)
         if x.dtype == torch.float16:
             x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
         return x
@@ -178,6 +176,6 @@ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
         shift, scale = vec
         shift = shift.squeeze(1)
         scale = scale.squeeze(1)
-        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+        x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
         x = self.linear(x)
         return x
