Skip to content

Commit bf69bdf

Browse files
committed
Adding transpose for audio_latents before decoding
1 parent 34d4e0e commit bf69bdf

1 file changed

Lines changed: 12 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,11 @@ def prepare_latents(
10161016
scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, "scaling_factor") else 1.0
10171017

10181018
latents = self._normalize_latents(latents, latents_mean, latents_std, scaling_factor)
1019+
1020+
# If latents came from VAE directly, they are (B, T, H, W, C).
1021+
# The packing and unpacking mechanisms expect (B, C, T, H, W).
1022+
latents = latents.transpose(0, 4, 1, 2, 3)
1023+
10191024
latents = self._pack_latents(
10201025
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
10211026
)
@@ -1308,6 +1313,9 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13081313
self.vae.config.scaling_factor
13091314
)
13101315

1316+
# VAE expects channels last (B, T, H, W, C) but unpack returns (B, C, T, H, W)
1317+
latents = latents.transpose(0, 2, 3, 4, 1)
1318+
13111319
# Denormalize and Unpack Audio (Order important: Denorm THEN Unpack)
13121320
audio_latents = self._denormalize_audio_latents(
13131321
audio_latents_jax,
@@ -1324,6 +1332,10 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13241332
num_mel_bins=latent_mel_bins
13251333
)
13261334

1335+
# Audio VAE expects channels last (B, T, F, C) but unpack returns (B, C, T, F)
1336+
if audio_latents.ndim == 4:
1337+
audio_latents = audio_latents.transpose(0, 2, 3, 1)
1338+
13271339
if output_type == "latent":
13281340
return LTX2PipelineOutput(frames=latents, audio=audio_latents)
13291341

0 commit comments

Comments
 (0)