Skip to content

Commit c3dd9fe

Browse files
committed
debug
1 parent 5ca96b5 commit c3dd9fe

2 files changed

Lines changed: 15 additions & 1 deletion

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#hardware
22
hardware: 'tpu'
33
skip_jax_distributed_system: False
4-
attention: 'ulysses'
4+
attention: 'flash'
55
a2v_attention_kernel: 'dot_product'
66
v2a_attention_kernel: 'dot_product'
77
attention_sharding_uniform: True

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch
2222
import jax
2323
import jax.numpy as jnp
24+
import time
2425
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2526
import flax
2627
import flax.linen as nn
@@ -765,9 +766,11 @@ def _get_gemma_prompt_embeds(
765766
prompt_attention_mask = prompt_attention_mask.to(self.text_encoder.device)
766767

767768
with torch.no_grad():
769+
t0 = time.time()
768770
text_encoder_outputs = self.text_encoder(
769771
input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
770772
)
773+
print(f"[Timing] Text Encoder time: {time.time() - t0:.2f}s")
771774

772775
text_encoder_hidden_states = text_encoder_outputs.hidden_states
773776
del text_encoder_outputs # Free memory
@@ -1317,8 +1320,10 @@ def __call__(
13171320
audio_embeds_sharded = jax.device_put(audio_embeds, spec)
13181321

13191322
timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
1323+
transformer_start = time.time()
13201324
for i in range(len(timesteps_jax)):
13211325
t = timesteps_jax[i]
1326+
step_start = time.time()
13221327

13231328
# Isolate input sharding to scan_layers=False to avoid affecting the standard path
13241329
latents_jax_sharded = latents_jax
@@ -1373,6 +1378,10 @@ def __call__(
13731378
else:
13741379
latents_jax = latents_step
13751380
audio_latents_jax = audio_latents_step
1381+
1382+
print(f"[Timing] Step {i} time: {time.time() - step_start:.2f}s")
1383+
1384+
print(f"[Timing] Transformer loop time: {time.time() - transformer_start:.2f}s")
13761385

13771386
# 8. Decode Latents
13781387
if guidance_scale > 1.0:
@@ -1470,10 +1479,13 @@ def __call__(
14701479
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
14711480

14721481
latents = latents.astype(self.vae.dtype)
1482+
vae_start = time.time()
14731483
video = self.vae.decode(latents, temb=timestep, return_dict=False)[0]
14741484
else:
14751485
latents = latents.astype(self.vae.dtype)
1486+
vae_start = time.time()
14761487
video = self.vae.decode(latents, return_dict=False)[0]
1488+
print(f"[Timing] VAE Decode time: {time.time() - vae_start:.2f}s")
14771489
# Post-process video (converts to numpy/PIL)
14781490
# VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W)
14791491
video_np = np.array(video).transpose(0, 4, 1, 2, 3)
@@ -1485,7 +1497,9 @@ def __call__(
14851497

14861498
# Audio VAE outputs (B, T, F, C), Vocoder expects (B, Channels, Time, MelBins)
14871499
generated_mel_spectrograms = generated_mel_spectrograms.transpose(0, 3, 1, 2)
1500+
vocoder_start = time.time()
14881501
audio = self.vocoder(generated_mel_spectrograms)
1502+
print(f"[Timing] Vocoder time: {time.time() - vocoder_start:.2f}s")
14891503

14901504
# Convert audio to numpy
14911505
audio = np.array(audio)

0 commit comments

Comments (0)