Skip to content

Commit e8f964e

Browse files
committed
annotations for cpu side ops
1 parent b148d5a commit e8f964e

2 files changed

Lines changed: 18 additions & 12 deletions

File tree

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,17 +1595,19 @@ def decode(
15951595
keys_slice = jax.random.split(key, latents.shape[0])
15961596
decoded_slices = []
15971597
for i in range(latents.shape[0]):
1598-
with jax.named_scope(f"Decode Slice {i}"):
1599-
z_slice = latents[i : i + 1]
1600-
t_slice = temb[i : i + 1] if temb is not None else None
1601-
subkey = keys_slice[i] if keys_slice is not None else None
1602-
res = self._decode(z_slice, t_slice, key=subkey, causal=causal, return_dict=True)
1603-
decoded_slices.append(res.sample)
1598+
with jax.profiler.TraceMe(f"VAE Decode Slice {i}"):
1599+
with jax.named_scope(f"Decode Slice {i}"):
1600+
z_slice = latents[i : i + 1]
1601+
t_slice = temb[i : i + 1] if temb is not None else None
1602+
subkey = keys_slice[i] if keys_slice is not None else None
1603+
res = self._decode(z_slice, t_slice, key=subkey, causal=causal, return_dict=True)
1604+
decoded_slices.append(res.sample)
16041605

16051606
dec = jnp.concatenate(decoded_slices, axis=0)
16061607
else:
1607-
with jax.named_scope("Decode Full Batch"):
1608-
dec = self._decode(latents, temb, key=key, causal=causal, return_dict=True).sample
1608+
with jax.profiler.TraceMe("VAE Decode Full Batch"):
1609+
with jax.named_scope("Decode Full Batch"):
1610+
dec = self._decode(latents, temb, key=key, causal=causal, return_dict=True).sample
16091611

16101612
if not return_dict:
16111613
return (dec,)

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,17 +1392,21 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13921392
max_logging.log(f"[Tuning] VAE decoding took: {t_vae:.4f} seconds")
13931393
# Post-process video (converts to numpy/PIL)
13941394
# VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W)
1395-
video_np = np.array(video).transpose(0, 4, 1, 2, 3)
1396-
video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type)
1395+
with jax.profiler.TraceMe("Video Post-processing"):
1396+
video_np = np.array(video).transpose(0, 4, 1, 2, 3)
1397+
video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type)
13971398

13981399
# Decode Audio
13991400
audio_latents = audio_latents.astype(self.audio_vae.dtype)
1400-
generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
1401+
with jax.profiler.TraceMe("Audio VAE Decode"):
1402+
generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
14011403

14021404
# Audio VAE outputs (B, T, F, C), Vocoder expects (B, Channels, Time, MelBins)
14031405
generated_mel_spectrograms = generated_mel_spectrograms.transpose(0, 3, 1, 2)
1406+
14041407
s_vocoder = time.perf_counter()
1405-
audio = self.vocoder(generated_mel_spectrograms)
1408+
with jax.profiler.TraceMe("Vocoder Audio Generation"):
1409+
audio = self.vocoder(generated_mel_spectrograms)
14061410
t_vocoder = time.perf_counter() - s_vocoder
14071411
max_logging.log(f"[Tuning] Vocoder took: {t_vocoder:.4f} seconds")
14081412

0 commit comments

Comments
 (0)