Skip to content

Commit a95f1d0

Browse files
committed
transposing vae decoded output
1 parent 34199d3 commit a95f1d0

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1388,8 +1388,8 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
1388 1388
latents = latents.astype(self.vae.dtype)
1389 1389
video = self.vae.decode(latents, return_dict=False)[0]
1390 1390
# Post-process video (converts to numpy/PIL)
1391-
# We need to pass numpy to postprocess_video usually, checking if it handles JAX
1392-
video_np = np.array(video)
1391+
# VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W)
1392+
video_np = np.array(video).transpose(0, 4, 1, 2, 3)
1393 1393
video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type)
1394 1394

1395 1395
# Decode Audio

0 commit comments

Comments
 (0)