Skip to content

Commit d1f1e7c

Browse files
committed
Debug output for tiled decode + audio
1 parent 83972c8 commit d1f1e7c

2 files changed

Lines changed: 6 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1277,6 +1277,7 @@ def enable_tiling(
12771277

12781278
def blend_v(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array:
12791279
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
1280+
print(f"DEBUG: blend_v called with a.shape={a.shape}, b.shape={b.shape}, blend_extent={blend_extent}")
12801281
if blend_extent <= 0:
12811282
return b
12821283

@@ -1289,6 +1290,7 @@ def blend_v(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array:
12891290

12901291
def blend_h(self, a: jax.Array, b: jax.Array, blend_extent: int) -> jax.Array:
12911292
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
1293+
print(f"DEBUG: blend_h called with a.shape={a.shape}, b.shape={b.shape}, blend_extent={blend_extent}")
12921294
if blend_extent <= 0:
12931295
return b
12941296

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -469,7 +469,6 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
469469
mesh=mesh,
470470
**vae_kwargs,
471471
)
472-
vae.tile_sample_min_width = 1024
473472
return vae
474473

475474
p_model_factory = partial(create_model, config=config)
@@ -1717,6 +1716,7 @@ def convert_to_vel(lat, x0):
17171716
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
17181717

17191718
latents = latents.astype(self.vae.dtype)
1719+
print(f"DEBUG: latents shape before VAE decode: {latents.shape}")
17201720
video = self.vae.decode(latents, temb=timestep, return_dict=False)[0]
17211721
else:
17221722
latents = latents.astype(self.vae.dtype)
@@ -1742,6 +1742,9 @@ def convert_to_vel(lat, x0):
17421742
# Convert audio to numpy
17431743
audio = np.array(audio)
17441744
print(f"DEBUG: final audio shape: {audio.shape}")
1745+
print(f"DEBUG: audio min: {audio.min()}")
1746+
print(f"DEBUG: audio max: {audio.max()}")
1747+
print(f"DEBUG: audio mean: {audio.mean()}")
17451748

17461749
return LTX2PipelineOutput(frames=video, audio=audio)
17471750

0 commit comments

Comments
 (0)