2121import torch
2222import jax
2323import jax .numpy as jnp
24+ import time
2425from jax .sharding import Mesh , NamedSharding , PartitionSpec as P
2526import flax
2627import flax .linen as nn
@@ -765,9 +766,11 @@ def _get_gemma_prompt_embeds(
765766 prompt_attention_mask = prompt_attention_mask .to (self .text_encoder .device )
766767
767768 with torch .no_grad ():
769+ t0 = time .time ()
768770 text_encoder_outputs = self .text_encoder (
769771 input_ids = text_input_ids , attention_mask = prompt_attention_mask , output_hidden_states = True
770772 )
773+ print (f"[Timing] Text Encoder time: { time .time () - t0 :.2f} s" )
771774
772775 text_encoder_hidden_states = text_encoder_outputs .hidden_states
773776 del text_encoder_outputs # Free memory
@@ -1317,8 +1320,10 @@ def __call__(
13171320 audio_embeds_sharded = jax .device_put (audio_embeds , spec )
13181321
13191322 timesteps_jax = jnp .array (timesteps , dtype = jnp .float32 )
1323+ transformer_start = time .time ()
13201324 for i in range (len (timesteps_jax )):
13211325 t = timesteps_jax [i ]
1326+ step_start = time .time ()
13221327
13231328 # Isolate input sharding to scan_layers=False to avoid affecting the standard path
13241329 latents_jax_sharded = latents_jax
@@ -1373,6 +1378,10 @@ def __call__(
13731378 else :
13741379 latents_jax = latents_step
13751380 audio_latents_jax = audio_latents_step
1381+
1382+ print (f"[Timing] Step { i } time: { time .time () - step_start :.2f} s" )
1383+
1384+ print (f"[Timing] Transformer loop time: { time .time () - transformer_start :.2f} s" )
13761385
13771386 # 8. Decode Latents
13781387 if guidance_scale > 1.0 :
@@ -1470,10 +1479,13 @@ def __call__(
14701479 latents = (1 - decode_noise_scale ) * latents + decode_noise_scale * noise
14711480
14721481 latents = latents .astype (self .vae .dtype )
1482+ vae_start = time .time ()
14731483 video = self .vae .decode (latents , temb = timestep , return_dict = False )[0 ]
14741484 else :
14751485 latents = latents .astype (self .vae .dtype )
1486+ vae_start = time .time ()
14761487 video = self .vae .decode (latents , return_dict = False )[0 ]
1488+ print (f"[Timing] VAE Decode time: { time .time () - vae_start :.2f} s" )
14771489 # Post-process video (converts to numpy/PIL)
14781490 # VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W)
14791491 video_np = np .array (video ).transpose (0 , 4 , 1 , 2 , 3 )
@@ -1485,7 +1497,9 @@ def __call__(
14851497
14861498 # Audio VAE outputs (B, T, F, C), Vocoder expects (B, Channels, Time, MelBins)
14871499 generated_mel_spectrograms = generated_mel_spectrograms .transpose (0 , 3 , 1 , 2 )
1500+ vocoder_start = time .time ()
14881501 audio = self .vocoder (generated_mel_spectrograms )
1502+ print (f"[Timing] Vocoder time: { time .time () - vocoder_start :.2f} s" )
14891503
14901504 # Convert audio to numpy
14911505 audio = np .array (audio )
0 commit comments