2121import torch
2222import jax
2323import jax .numpy as jnp
24- import time
2524from jax .sharding import Mesh , NamedSharding , PartitionSpec as P
2625import flax
2726import flax .linen as nn
@@ -766,11 +765,9 @@ def _get_gemma_prompt_embeds(
766765 prompt_attention_mask = prompt_attention_mask .to (self .text_encoder .device )
767766
768767 with torch .no_grad ():
769- t0 = time .time ()
770768 text_encoder_outputs = self .text_encoder (
771769 input_ids = text_input_ids , attention_mask = prompt_attention_mask , output_hidden_states = True
772770 )
773- print (f"[Timing] Text Encoder time: { time .time () - t0 :.2f} s" )
774771
775772 text_encoder_hidden_states = text_encoder_outputs .hidden_states
776773 del text_encoder_outputs # Free memory
@@ -1349,10 +1346,8 @@ def __call__(
13491346 audio_embeds_sharded = jax .device_put (audio_embeds , spec )
13501347
13511348 timesteps_jax = jnp .array (timesteps , dtype = jnp .float32 )
1352- transformer_start = time .time ()
13531349 for i in range (len (timesteps_jax )):
13541350 t = timesteps_jax [i ]
1355- step_start = time .time ()
13561351
13571352 # Isolate input sharding to scan_layers=False to avoid affecting the standard path
13581353 latents_jax_sharded = latents_jax
@@ -1407,10 +1402,6 @@ def __call__(
14071402 else :
14081403 latents_jax = latents_step
14091404 audio_latents_jax = audio_latents_step
1410-
1411- print (f"[Timing] Step { i } time: { time .time () - step_start :.2f} s" )
1412-
1413- print (f"[Timing] Transformer loop time: { time .time () - transformer_start :.2f} s" )
14141405
14151406 # 8. Decode Latents
14161407 if guidance_scale > 1.0 :
@@ -1508,13 +1499,10 @@ def __call__(
15081499 latents = (1 - decode_noise_scale ) * latents + decode_noise_scale * noise
15091500
15101501 latents = latents .astype (self .vae .dtype )
1511- vae_start = time .time ()
15121502 video = self .vae .decode (latents , temb = timestep , return_dict = False )[0 ]
15131503 else :
15141504 latents = latents .astype (self .vae .dtype )
1515- vae_start = time .time ()
15161505 video = self .vae .decode (latents , return_dict = False )[0 ]
1517- print (f"[Timing] VAE Decode time: { time .time () - vae_start :.2f} s" )
15181506 # Post-process video (converts to numpy/PIL)
15191507 # VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W)
15201508 video_np = np .array (video ).transpose (0 , 4 , 1 , 2 , 3 )
@@ -1526,9 +1514,7 @@ def __call__(
15261514
15271515 # Audio VAE outputs (B, T, F, C), Vocoder expects (B, Channels, Time, MelBins)
15281516 generated_mel_spectrograms = generated_mel_spectrograms .transpose (0 , 3 , 1 , 2 )
1529- vocoder_start = time .time ()
15301517 audio = self .vocoder (generated_mel_spectrograms )
1531- print (f"[Timing] Vocoder time: { time .time () - vocoder_start :.2f} s" )
15321518
15331519 # Convert audio to numpy
15341520 audio = np .array (audio )
0 commit comments