Skip to content

Commit 5f95986

Browse files
committed
fix
1 parent 887f079 commit 5f95986

2 files changed

Lines changed: 16 additions & 20 deletions

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,19 +1595,17 @@ def decode(
15951595
keys_slice = jax.random.split(key, latents.shape[0])
15961596
decoded_slices = []
15971597
for i in range(latents.shape[0]):
1598-
with jax.profiler.TraceMe(f"VAE Decode Slice {i}"):
1599-
with jax.named_scope(f"Decode Slice {i}"):
1600-
z_slice = latents[i : i + 1]
1601-
t_slice = temb[i : i + 1] if temb is not None else None
1602-
subkey = keys_slice[i] if keys_slice is not None else None
1603-
res = self._decode(z_slice, t_slice, key=subkey, causal=causal, return_dict=True)
1604-
decoded_slices.append(res.sample)
1598+
with jax.named_scope(f"Decode Slice {i}"):
1599+
z_slice = latents[i : i + 1]
1600+
t_slice = temb[i : i + 1] if temb is not None else None
1601+
subkey = keys_slice[i] if keys_slice is not None else None
1602+
res = self._decode(z_slice, t_slice, key=subkey, causal=causal, return_dict=True)
1603+
decoded_slices.append(res.sample)
16051604

16061605
dec = jnp.concatenate(decoded_slices, axis=0)
16071606
else:
1608-
with jax.profiler.TraceMe("VAE Decode Full Batch"):
1609-
with jax.named_scope("Decode Full Batch"):
1610-
dec = self._decode(latents, temb, key=key, causal=causal, return_dict=True).sample
1607+
with jax.named_scope("Decode Full Batch"):
1608+
dec = self._decode(latents, temb, key=key, causal=causal, return_dict=True).sample
16111609

16121610
if not return_dict:
16131611
return (dec,)

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1374,32 +1374,30 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13741374
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
13751375

13761376
latents = latents.astype(self.vae.dtype)
1377-
max_logging.log(f"[Profiling] Latents device before VAE Decode: {latents.devices()}")
1377+
13781378
video = self.vae.decode(latents, temb=timestep, return_dict=False)[0]
1379-
max_logging.log(f"[Profiling] Video device after VAE Decode: {getattr(video, 'devices', lambda: 'Unknown')()}")
1379+
13801380
else:
13811381
latents = latents.astype(self.vae.dtype)
1382-
max_logging.log(f"[Profiling] Latents device before VAE Decode (else): {latents.devices()}")
1382+
13831383
video = self.vae.decode(latents, return_dict=False)[0]
1384-
max_logging.log(f"[Profiling] Video device after VAE Decode (else): {getattr(video, 'devices', lambda: 'Unknown')()}")
1384+
13851385
t_vae = time.perf_counter() - s_vae
13861386
max_logging.log(f"[Tuning] VAE decoding took: {t_vae:.4f} seconds")
13871387
# Post-process video (converts to numpy/PIL)
13881388
# VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W)
1389-
with jax.profiler.TraceMe("Video Post-processing"):
1390-
video_np = np.array(video).transpose(0, 4, 1, 2, 3)
1391-
video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type)
1389+
video_np = np.array(video).transpose(0, 4, 1, 2, 3)
1390+
video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type)
13921391

13931392
# Decode Audio
13941393
audio_latents = audio_latents.astype(self.audio_vae.dtype)
1395-
with jax.profiler.TraceMe("Audio VAE Decode"):
1396-
generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
1394+
generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
13971395

13981396
# Audio VAE outputs (B, T, F, C), Vocoder expects (B, Channels, Time, MelBins)
13991397
generated_mel_spectrograms = generated_mel_spectrograms.transpose(0, 3, 1, 2)
14001398

14011399
s_vocoder = time.perf_counter()
1402-
with jax.profiler.TraceMe("Vocoder Audio Generation"):
1400+
with jax.named_scope("Vocoder Audio Generation"):
14031401
audio = self.vocoder(generated_mel_spectrograms)
14041402
t_vocoder = time.perf_counter() - s_vocoder
14051403
max_logging.log(f"[Tuning] Vocoder took: {t_vocoder:.4f} seconds")

0 commit comments

Comments
 (0)