Skip to content

Commit 00c1609

Browse files
committed
removing debug
1 parent 19a2dc6 commit 00c1609

1 file changed

Lines changed: 0 additions & 14 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import torch
2222
import jax
2323
import jax.numpy as jnp
24-
import time
2524
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2625
import flax
2726
import flax.linen as nn
@@ -766,11 +765,9 @@ def _get_gemma_prompt_embeds(
766765
prompt_attention_mask = prompt_attention_mask.to(self.text_encoder.device)
767766

768767
with torch.no_grad():
769-
t0 = time.time()
770768
text_encoder_outputs = self.text_encoder(
771769
input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
772770
)
773-
print(f"[Timing] Text Encoder time: {time.time() - t0:.2f}s")
774771

775772
text_encoder_hidden_states = text_encoder_outputs.hidden_states
776773
del text_encoder_outputs # Free memory
@@ -1349,10 +1346,8 @@ def __call__(
13491346
audio_embeds_sharded = jax.device_put(audio_embeds, spec)
13501347

13511348
timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
1352-
transformer_start = time.time()
13531349
for i in range(len(timesteps_jax)):
13541350
t = timesteps_jax[i]
1355-
step_start = time.time()
13561351

13571352
# Isolate input sharding to scan_layers=False to avoid affecting the standard path
13581353
latents_jax_sharded = latents_jax
@@ -1407,10 +1402,6 @@ def __call__(
14071402
else:
14081403
latents_jax = latents_step
14091404
audio_latents_jax = audio_latents_step
1410-
1411-
print(f"[Timing] Step {i} time: {time.time() - step_start:.2f}s")
1412-
1413-
print(f"[Timing] Transformer loop time: {time.time() - transformer_start:.2f}s")
14141405

14151406
# 8. Decode Latents
14161407
if guidance_scale > 1.0:
@@ -1508,13 +1499,10 @@ def __call__(
15081499
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
15091500

15101501
latents = latents.astype(self.vae.dtype)
1511-
vae_start = time.time()
15121502
video = self.vae.decode(latents, temb=timestep, return_dict=False)[0]
15131503
else:
15141504
latents = latents.astype(self.vae.dtype)
1515-
vae_start = time.time()
15161505
video = self.vae.decode(latents, return_dict=False)[0]
1517-
print(f"[Timing] VAE Decode time: {time.time() - vae_start:.2f}s")
15181506
# Post-process video (converts to numpy/PIL)
15191507
# VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W)
15201508
video_np = np.array(video).transpose(0, 4, 1, 2, 3)
@@ -1526,9 +1514,7 @@ def __call__(
15261514

15271515
# Audio VAE outputs (B, T, F, C), Vocoder expects (B, Channels, Time, MelBins)
15281516
generated_mel_spectrograms = generated_mel_spectrograms.transpose(0, 3, 1, 2)
1529-
vocoder_start = time.time()
15301517
audio = self.vocoder(generated_mel_spectrograms)
1531-
print(f"[Timing] Vocoder time: {time.time() - vocoder_start:.2f}s")
15321518

15331519
# Convert audio to numpy
15341520
audio = np.array(audio)

0 commit comments

Comments (0)