@@ -1140,6 +1140,10 @@ def __call__(
11401140 latents = latents ,
11411141 )
11421142
1143+ latent_height = height // self .vae_spatial_compression_ratio
1144+ latent_width = width // self .vae_spatial_compression_ratio
1145+ latent_num_frames = (num_frames - 1 ) // self .vae_temporal_compression_ratio + 1
1146+
11431147 # 4. Prepare Audio Latents
11441148 audio_channels = (
11451149 self .audio_vae .config .latent_channels
@@ -1256,9 +1260,9 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12561260 prompt_attention_mask_jax ,
12571261 guidance_scale > 1.0 ,
12581262 guidance_scale ,
1259- num_frames ,
1260- height ,
1261- width ,
1263+ latent_num_frames ,
1264+ latent_height ,
1265+ latent_width ,
12621266 audio_num_frames ,
12631267 frame_rate ,
12641268 )
@@ -1358,7 +1362,7 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13581362
13591363 return LTX2PipelineOutput (frames = video , audio = audio )
13601364
1361- @partial (jax .jit , static_argnames = ("do_classifier_free_guidance" , "guidance_scale" , "num_frames " , "height " , "width " , "audio_num_frames" , "fps" ))
1365+ @partial (jax .jit , static_argnames = ("do_classifier_free_guidance" , "guidance_scale" , "latent_num_frames " , "latent_height " , "latent_width " , "audio_num_frames" , "fps" ))
13621366def transformer_forward_pass (
13631367 graphdef ,
13641368 state ,
@@ -1371,9 +1375,9 @@ def transformer_forward_pass(
13711375 audio_encoder_attention_mask ,
13721376 do_classifier_free_guidance ,
13731377 guidance_scale ,
1374- num_frames ,
1375- height ,
1376- width ,
1378+ latent_num_frames ,
1379+ latent_height ,
1380+ latent_width ,
13771381 audio_num_frames ,
13781382 fps ,
13791383):
@@ -1387,9 +1391,9 @@ def transformer_forward_pass(
13871391 encoder_hidden_states = encoder_hidden_states ,
13881392 timestep = timestep ,
13891393 encoder_attention_mask = encoder_attention_mask ,
1390- num_frames = num_frames ,
1391- height = height ,
1392- width = width ,
1394+ num_frames = latent_num_frames ,
1395+ height = latent_height ,
1396+ width = latent_width ,
13931397 audio_hidden_states = audio_latents ,
13941398 audio_encoder_hidden_states = audio_encoder_hidden_states ,
13951399 audio_encoder_attention_mask = audio_encoder_attention_mask ,
0 commit comments