Skip to content

Commit 83972c8

Browse files
committed
vae tiled decode disable + debug statements
1 parent 5f2affa commit 83972c8

2 files changed

Lines changed: 6 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,9 +1379,12 @@ def tiled_decode(
13791379
B, T, H, W, C = z.shape
13801380
sample_height = H * self.spatial_compression_ratio
13811381
sample_width = W * self.spatial_compression_ratio
1382+
print(f"DEBUG: VAE tiled_decode called with hidden shape H={H}, W={W}")
1383+
print(f"DEBUG: target sample_height={sample_height}, sample_width={sample_width}")
13821384

13831385
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
13841386
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1387+
print(f"DEBUG: tile_latent_min_height={tile_latent_min_height}, tile_latent_min_width={tile_latent_min_width}")
13851388
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
13861389
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
13871390

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
469469
mesh=mesh,
470470
**vae_kwargs,
471471
)
472+
vae.tile_sample_min_width = 1024
472473
return vae
473474

474475
p_model_factory = partial(create_model, config=config)
@@ -1724,6 +1725,7 @@ def convert_to_vel(lat, x0):
17241725
# VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W)
17251726
video_np = np.array(video).transpose(0, 4, 1, 2, 3)
17261727
video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type)
1728+
print(f"DEBUG: final video shape: {np.array(video).shape}")
17271729

17281730
# Decode Audio
17291731
audio_latents = audio_latents.astype(self.audio_vae.dtype)
@@ -1739,6 +1741,7 @@ def convert_to_vel(lat, x0):
17391741

17401742
# Convert audio to numpy
17411743
audio = np.array(audio)
1744+
print(f"DEBUG: final audio shape: {audio.shape}")
17421745

17431746
return LTX2PipelineOutput(frames=video, audio=audio)
17441747

0 commit comments

Comments
 (0)