@@ -1016,6 +1016,11 @@ def prepare_latents(
10161016 scaling_factor = self .vae .config .scaling_factor if hasattr (self .vae .config , "scaling_factor" ) else 1.0
10171017
10181018 latents = self ._normalize_latents (latents , latents_mean , latents_std , scaling_factor )
1019+
1020+ # If latents came from VAE directly, they are (B, T, H, W, C).
1021+ # The packing and unpacking mechanisms expect (B, C, T, H, W).
1022+ latents = latents .transpose (0 , 4 , 1 , 2 , 3 )
1023+
10191024 latents = self ._pack_latents (
10201025 latents , self .transformer_spatial_patch_size , self .transformer_temporal_patch_size
10211026 )
@@ -1308,6 +1313,9 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13081313 self .vae .config .scaling_factor
13091314 )
13101315
1316+ # VAE expects channels last (B, T, H, W, C) but unpack returns (B, C, T, H, W)
1317+ latents = latents .transpose (0 , 2 , 3 , 4 , 1 )
1318+
13111319 # Denormalize and Unpack Audio (Order important: Denorm THEN Unpack)
13121320 audio_latents = self ._denormalize_audio_latents (
13131321 audio_latents_jax ,
@@ -1324,6 +1332,10 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13241332 num_mel_bins = latent_mel_bins
13251333 )
13261334
1335+ # Audio VAE expects channels last (B, T, F, C) but unpack returns (B, C, T, F)
1336+ if audio_latents .ndim == 4 :
1337+ audio_latents = audio_latents .transpose (0 , 2 , 3 , 1 )
1338+
13271339 if output_type == "latent" :
13281340 return LTX2PipelineOutput (frames = latents , audio = audio_latents )
13291341
0 commit comments