Skip to content

Commit 4a9014e

Browse files
Hunyuan Custom initial untested implementation. (comfyanonymous#8101)
1 parent 8a7c894 commit 4a9014e

File tree

3 files changed

+26
-5
lines changed

3 files changed

+26
-5
lines changed

comfy/ldm/hunyuan_video/model.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def forward_orig(
228228
y: Tensor,
229229
guidance: Tensor = None,
230230
guiding_frame_index=None,
231+
ref_latent=None,
231232
control=None,
232233
transformer_options={},
233234
) -> Tensor:
@@ -238,6 +239,14 @@ def forward_orig(
238239
img = self.img_in(img)
239240
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
240241

242+
if ref_latent is not None:
243+
ref_latent_ids = self.img_ids(ref_latent)
244+
ref_latent = self.img_in(ref_latent)
245+
img = torch.cat([ref_latent, img], dim=-2)
246+
ref_latent_ids[..., 0] = -1
247+
ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
248+
img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
249+
241250
if guiding_frame_index is not None:
242251
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
243252
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
@@ -313,6 +322,8 @@ def block_wrap(args):
313322
img[:, : img_len] += add
314323

315324
img = img[:, : img_len]
325+
if ref_latent is not None:
326+
img = img[:, ref_latent.shape[1]:]
316327

317328
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
318329

@@ -324,7 +335,7 @@ def block_wrap(args):
324335
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
325336
return img
326337

327-
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
338+
def img_ids(self, x):
328339
bs, c, t, h, w = x.shape
329340
patch_size = self.patch_size
330341
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -334,7 +345,11 @@ def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, g
334345
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
335346
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
336347
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
337-
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
348+
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
349+
350+
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
351+
bs, c, t, h, w = x.shape
352+
img_ids = self.img_ids(x)
338353
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
339-
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
354+
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
340355
return out

comfy/model_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,10 @@ def extra_conds(self, **kwargs):
924924
if guiding_frame_index is not None:
925925
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
926926

927+
ref_latent = kwargs.get("ref_latent", None)
928+
if ref_latent is not None:
929+
out['ref_latent'] = comfy.conds.CONDRegular(self.process_latent_in(ref_latent))
930+
927931
return out
928932

929933
def scale_latent_inpaint(self, latent_image, **kwargs):

comfy_extras/nodes_hunyuan.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def INPUT_TYPES(s):
7777
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
7878
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
7979
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
80-
"guidance_type": (["v1 (concat)", "v2 (replace)"], )
80+
"guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
8181
},
8282
"optional": {"start_image": ("IMAGE", ),
8383
}}
@@ -101,10 +101,12 @@ def encode(self, positive, vae, width, height, length, batch_size, guidance_type
101101

102102
if guidance_type == "v1 (concat)":
103103
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
104-
else:
104+
elif guidance_type == "v2 (replace)":
105105
cond = {'guiding_frame_index': 0}
106106
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
107107
out_latent["noise_mask"] = mask
108+
elif guidance_type == "custom":
109+
cond = {"ref_latent": concat_latent_image}
108110

109111
positive = node_helpers.conditioning_set_values(positive, cond)
110112

0 commit comments

Comments (0)