@@ -228,6 +228,7 @@ def forward_orig(
         y: Tensor,
         guidance: Tensor = None,
         guiding_frame_index=None,
+        ref_latent=None,
         control=None,
         transformer_options={},
     ) -> Tensor:
@@ -238,6 +239,14 @@ def forward_orig(
         img = self.img_in(img)
         vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
 
+        if ref_latent is not None:
+            ref_latent_ids = self.img_ids(ref_latent)
+            ref_latent = self.img_in(ref_latent)
+            img = torch.cat([ref_latent, img], dim=-2)
+            ref_latent_ids[..., 0] = -1
+            ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
+            img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
+
         if guiding_frame_index is not None:
             token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
             vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
@@ -313,6 +322,8 @@ def block_wrap(args):
                         img[:, : img_len] += add
 
         img = img[:, : img_len]
+        if ref_latent is not None:
+            img = img[:, ref_latent.shape[1]:]
 
         img = self.final_layer(img, vec, modulation_dims=modulation_dims)  # (N, T, patch_size ** 2 * out_channels)
 
@@ -324,7 +335,7 @@ def block_wrap(args):
         img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
         return img
 
-    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
+    def img_ids(self, x):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -334,7 +345,11 @@ def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, g
         img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
         img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
         img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
-        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+        return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+
+    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+        bs, c, t, h, w = x.shape
+        img_ids = self.img_ids(x)
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
         return out
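For readers tracing the change, here is a hedged, standalone sketch (not part of the patch) of the positional-id bookkeeping the new ref_latent path performs: the reference tokens are given a temporal id of -1 and a width offset equal to the main latent's token width before being prepended to img_ids, so their RoPE coordinates never overlap the video tokens. The toy grid sizes and the make_img_ids helper below are assumptions for illustration only.

# Hedged sketch, not the patch itself: toy reproduction of the ref_latent id handling.
import torch
from einops import repeat

def make_img_ids(t_len, h_len, w_len, bs):
    # Assumed helper mirroring img_ids(): one (t, h, w) index triple per latent token.
    ids = torch.zeros((t_len, h_len, w_len, 3))
    ids[:, :, :, 0] += torch.linspace(0, t_len - 1, steps=t_len).reshape(-1, 1, 1)
    ids[:, :, :, 1] += torch.linspace(0, h_len - 1, steps=h_len).reshape(1, -1, 1)
    ids[:, :, :, 2] += torch.linspace(0, w_len - 1, steps=w_len).reshape(1, 1, -1)
    return repeat(ids, "t h w c -> b (t h w) c", b=bs)

bs, t_len, h_len, w_len = 1, 4, 6, 8              # toy token-grid sizes (assumed)
img_ids = make_img_ids(t_len, h_len, w_len, bs)   # main video latent ids
ref_ids = make_img_ids(1, h_len, w_len, bs)       # single reference frame ids

# Same two adjustments as in forward_orig: temporal index -1, width shifted past
# the main latent so the reference tokens get distinct RoPE coordinates.
ref_ids[..., 0] = -1
ref_ids[..., 2] += w_len
ids = torch.cat([ref_ids, img_ids], dim=-2)       # reference tokens prepended
print(ids.shape)                                  # torch.Size([1, 240, 3]) -> 48 ref + 192 video tokens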