@@ -1392,17 +1392,21 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13921392 max_logging .log (f"[Tuning] VAE decoding took: { t_vae :.4f} seconds" )
13931393 # Post-process video (converts to numpy/PIL)
13941394 # VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W)
1395- video_np = np .array (video ).transpose (0 , 4 , 1 , 2 , 3 )
1396- video = self .video_processor .postprocess_video (torch .from_numpy (video_np ), output_type = output_type )
1395+ with jax .profiler .TraceMe ("Video Post-processing" ):
1396+ video_np = np .array (video ).transpose (0 , 4 , 1 , 2 , 3 )
1397+ video = self .video_processor .postprocess_video (torch .from_numpy (video_np ), output_type = output_type )
13971398
13981399 # Decode Audio
13991400 audio_latents = audio_latents .astype (self .audio_vae .dtype )
1400- generated_mel_spectrograms = self .audio_vae .decode (audio_latents , return_dict = False )[0 ]
1401+ with jax .profiler .TraceMe ("Audio VAE Decode" ):
1402+ generated_mel_spectrograms = self .audio_vae .decode (audio_latents , return_dict = False )[0 ]
14011403
14021404 # Audio VAE outputs (B, T, F, C), Vocoder expects (B, Channels, Time, MelBins)
14031405 generated_mel_spectrograms = generated_mel_spectrograms .transpose (0 , 3 , 1 , 2 )
1406+
14041407 s_vocoder = time .perf_counter ()
1405- audio = self .vocoder (generated_mel_spectrograms )
1408+ with jax .profiler .TraceMe ("Vocoder Audio Generation" ):
1409+ audio = self .vocoder (generated_mel_spectrograms )
14061410 t_vocoder = time .perf_counter () - s_vocoder
14071411 max_logging .log (f"[Tuning] Vocoder took: { t_vocoder :.4f} seconds" )
14081412
0 commit comments