@@ -1374,32 +1374,30 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13741374 latents = (1 - decode_noise_scale ) * latents + decode_noise_scale * noise
13751375
13761376 latents = latents .astype (self .vae .dtype )
1377- max_logging . log ( f"[Profiling] Latents device before VAE Decode: { latents . devices () } " )
1377+
13781378 video = self .vae .decode (latents , temb = timestep , return_dict = False )[0 ]
1379- max_logging . log ( f"[Profiling] Video device after VAE Decode: { getattr ( video , 'devices' , lambda : 'Unknown' )() } " )
1379+
13801380 else :
13811381 latents = latents .astype (self .vae .dtype )
1382- max_logging . log ( f"[Profiling] Latents device before VAE Decode (else): { latents . devices () } " )
1382+
13831383 video = self .vae .decode (latents , return_dict = False )[0 ]
1384- max_logging . log ( f"[Profiling] Video device after VAE Decode (else): { getattr ( video , 'devices' , lambda : 'Unknown' )() } " )
1384+
13851385 t_vae = time .perf_counter () - s_vae
13861386 max_logging .log (f"[Tuning] VAE decoding took: { t_vae :.4f} seconds" )
13871387 # Post-process video (converts to numpy/PIL)
13881388 # VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W)
1389- with jax .profiler .TraceMe ("Video Post-processing" ):
1390- video_np = np .array (video ).transpose (0 , 4 , 1 , 2 , 3 )
1391- video = self .video_processor .postprocess_video (torch .from_numpy (video_np ), output_type = output_type )
1389+ video_np = np .array (video ).transpose (0 , 4 , 1 , 2 , 3 )
1390+ video = self .video_processor .postprocess_video (torch .from_numpy (video_np ), output_type = output_type )
13921391
13931392 # Decode Audio
13941393 audio_latents = audio_latents .astype (self .audio_vae .dtype )
1395- with jax .profiler .TraceMe ("Audio VAE Decode" ):
1396- generated_mel_spectrograms = self .audio_vae .decode (audio_latents , return_dict = False )[0 ]
1394+ generated_mel_spectrograms = self .audio_vae .decode (audio_latents , return_dict = False )[0 ]
13971395
13981396 # Audio VAE outputs (B, T, F, C), Vocoder expects (B, Channels, Time, MelBins)
13991397 generated_mel_spectrograms = generated_mel_spectrograms .transpose (0 , 3 , 1 , 2 )
14001398
14011399 s_vocoder = time .perf_counter ()
1402- with jax .profiler . TraceMe ("Vocoder Audio Generation" ):
1400+ with jax .named_scope ("Vocoder Audio Generation" ):
14031401 audio = self .vocoder (generated_mel_spectrograms )
14041402 t_vocoder = time .perf_counter () - s_vocoder
14051403 max_logging .log (f"[Tuning] Vocoder took: { t_vocoder :.4f} seconds" )
0 commit comments