@@ -1096,19 +1096,18 @@ def __call__(
10961096 # 2. Encode inputs (Text)
10971097 import time
10981098 s_text = time .perf_counter ()
1099- with jax .profiler .TraceAnnotation ("Encode Inputs" ):
1100- prompt_embeds , prompt_attention_mask , negative_prompt_embeds , negative_prompt_attention_mask = self .encode_prompt (
1101- prompt ,
1102- negative_prompt ,
1103- do_classifier_free_guidance = guidance_scale > 1.0 ,
1104- num_videos_per_prompt = num_videos_per_prompt ,
1105- prompt_embeds = prompt_embeds ,
1106- negative_prompt_embeds = negative_prompt_embeds ,
1107- prompt_attention_mask = prompt_attention_mask ,
1108- negative_prompt_attention_mask = negative_prompt_attention_mask ,
1109- max_sequence_length = max_sequence_length ,
1110- dtype = dtype ,
1111- )
1099+ prompt_embeds , prompt_attention_mask , negative_prompt_embeds , negative_prompt_attention_mask = self .encode_prompt (
1100+ prompt ,
1101+ negative_prompt ,
1102+ do_classifier_free_guidance = guidance_scale > 1.0 ,
1103+ num_videos_per_prompt = num_videos_per_prompt ,
1104+ prompt_embeds = prompt_embeds ,
1105+ negative_prompt_embeds = negative_prompt_embeds ,
1106+ prompt_attention_mask = prompt_attention_mask ,
1107+ negative_prompt_attention_mask = negative_prompt_attention_mask ,
1108+ max_sequence_length = max_sequence_length ,
1109+ dtype = dtype ,
1110+ )
11121111 t_text = time .perf_counter () - s_text
11131112 max_logging .log (f"[Tuning] Prompt encoding took: { t_text :.4f} seconds" )
11141113
@@ -1121,17 +1120,16 @@ def __call__(
11211120
11221121 key_latents , key_audio = jax .random .split (generator )
11231122
1124- with jax .profiler .TraceAnnotation ("Prepare Video Latents" ):
1125- latents = self .prepare_latents (
1126- batch_size = batch_size ,
1127- height = height ,
1128- width = width ,
1129- num_frames = num_frames ,
1130- noise_scale = noise_scale ,
1131- dtype = dtype ,
1132- generator = key_latents ,
1133- latents = latents ,
1134- )
1123+ latents = self .prepare_latents (
1124+ batch_size = batch_size ,
1125+ height = height ,
1126+ width = width ,
1127+ num_frames = num_frames ,
1128+ noise_scale = noise_scale ,
1129+ dtype = dtype ,
1130+ generator = key_latents ,
1131+ latents = latents ,
1132+ )
11351133
11361134 latent_height = height // self .vae_spatial_compression_ratio
11371135 latent_width = width // self .vae_spatial_compression_ratio
@@ -1150,16 +1148,15 @@ def __call__(
11501148 )
11511149 audio_num_frames = round (duration_s * audio_latents_per_second )
11521150
1153- with jax .profiler .TraceAnnotation ("Prepare Audio Latents" ):
1154- audio_latents = self .prepare_audio_latents (
1155- batch_size = batch_size ,
1156- num_channels_latents = audio_channels ,
1157- audio_latent_length = audio_num_frames ,
1158- noise_scale = noise_scale ,
1159- dtype = dtype ,
1160- generator = key_audio ,
1161- latents = audio_latents ,
1162- )
1151+ audio_latents = self .prepare_audio_latents (
1152+ batch_size = batch_size ,
1153+ num_channels_latents = audio_channels ,
1154+ audio_latent_length = audio_num_frames ,
1155+ noise_scale = noise_scale ,
1156+ dtype = dtype ,
1157+ generator = key_audio ,
1158+ latents = audio_latents ,
1159+ )
11631160
11641161 # 5. Prepare Timesteps
11651162 sigmas = jnp .linspace (1.0 , 1 / num_inference_steps , num_inference_steps ) if sigmas is None else sigmas
@@ -1242,57 +1239,56 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12421239 import time
12431240 total_diffusion_time = 0.0
12441241 for i , t in enumerate (timesteps ):
1245- with jax .profiler .TraceAnnotation (f"Diffusion Step { i } " ):
1246- step_start_time = time .perf_counter ()
1247- noise_pred , noise_pred_audio = transformer_forward_pass (
1248- graphdef ,
1249- state ,
1250- latents_jax ,
1251- audio_latents_jax ,
1252- t ,
1253- video_embeds ,
1254- audio_embeds ,
1255- new_attention_mask ,
1256- new_attention_mask ,
1257- guidance_scale > 1.0 ,
1258- guidance_scale ,
1259- latent_num_frames ,
1260- latent_height ,
1261- latent_width ,
1262- audio_num_frames ,
1263- frame_rate ,
1264- )
1265-
1266- if guidance_scale > 1.0 :
1267- noise_pred_uncond , noise_pred_text = jnp .split (noise_pred , 2 , axis = 0 )
1268- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
1269- # Audio guidance
1270- noise_pred_audio_uncond , noise_pred_audio_text = jnp .split (noise_pred_audio , 2 , axis = 0 )
1271- noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond )
1272-
1273- latents_step = latents_jax [batch_size :]
1274- audio_latents_step = audio_latents_jax [batch_size :]
1275- else :
1276- latents_step = latents_jax
1277- audio_latents_step = audio_latents_jax
1278-
1279- # Step
1280- latents_step , _ = self .scheduler .step (scheduler_state , noise_pred , t , latents_step , return_dict = False )
1281- audio_latents_step , _ = self .scheduler .step (
1282- scheduler_state , noise_pred_audio , t , audio_latents_step , return_dict = False
1283- )
1284-
1285- if guidance_scale > 1.0 :
1286- latents_jax = jnp .concatenate ([latents_step ] * 2 , axis = 0 )
1287- audio_latents_jax = jnp .concatenate ([audio_latents_step ] * 2 , axis = 0 )
1288- else :
1289- latents_jax = latents_step
1290- audio_latents_jax = audio_latents_step
1291-
1292- latents_jax .block_until_ready ()
1293- step_duration = time .perf_counter () - step_start_time
1294- total_diffusion_time += step_duration
1295- max_logging .log (f"[Tuning] Diffusion Step { i } e2e time: { step_duration :.4f} seconds" )
1242+ step_start_time = time .perf_counter ()
1243+ noise_pred , noise_pred_audio = transformer_forward_pass (
1244+ graphdef ,
1245+ state ,
1246+ latents_jax ,
1247+ audio_latents_jax ,
1248+ t ,
1249+ video_embeds ,
1250+ audio_embeds ,
1251+ new_attention_mask ,
1252+ new_attention_mask ,
1253+ guidance_scale > 1.0 ,
1254+ guidance_scale ,
1255+ latent_num_frames ,
1256+ latent_height ,
1257+ latent_width ,
1258+ audio_num_frames ,
1259+ frame_rate ,
1260+ )
1261+
1262+ if guidance_scale > 1.0 :
1263+ noise_pred_uncond , noise_pred_text = jnp .split (noise_pred , 2 , axis = 0 )
1264+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
1265+ # Audio guidance
1266+ noise_pred_audio_uncond , noise_pred_audio_text = jnp .split (noise_pred_audio , 2 , axis = 0 )
1267+ noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond )
1268+
1269+ latents_step = latents_jax [batch_size :]
1270+ audio_latents_step = audio_latents_jax [batch_size :]
1271+ else :
1272+ latents_step = latents_jax
1273+ audio_latents_step = audio_latents_jax
1274+
1275+ # Step
1276+ latents_step , _ = self .scheduler .step (scheduler_state , noise_pred , t , latents_step , return_dict = False )
1277+ audio_latents_step , _ = self .scheduler .step (
1278+ scheduler_state , noise_pred_audio , t , audio_latents_step , return_dict = False
1279+ )
1280+
1281+ if guidance_scale > 1.0 :
1282+ latents_jax = jnp .concatenate ([latents_step ] * 2 , axis = 0 )
1283+ audio_latents_jax = jnp .concatenate ([audio_latents_step ] * 2 , axis = 0 )
1284+ else :
1285+ latents_jax = latents_step
1286+ audio_latents_jax = audio_latents_step
1287+
1288+ latents_jax .block_until_ready ()
1289+ step_duration = time .perf_counter () - step_start_time
1290+ total_diffusion_time += step_duration
1291+ max_logging .log (f"[Tuning] Diffusion Step { i } e2e time: { step_duration :.4f} seconds" )
12961292 max_logging .log (f"[Tuning] Total pure diffusion time (all steps): { total_diffusion_time :.4f} seconds" )
12971293
12981294 # 8. Decode Latents
@@ -1354,27 +1350,26 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13541350 max_logging .log (f"[Tuning] Failed to apply sharding constraint: { e } " )
13551351
13561352 s_vae = time .perf_counter ()
1357- with jax .profiler .TraceAnnotation ("VAE Decode Video" ):
1358- if getattr (self .vae .config , "timestep_conditioning" , False ):
1359- noise = jax .random .normal (generator , latents .shape , dtype = latents .dtype )
1353+ if getattr (self .vae .config , "timestep_conditioning" , False ):
1354+ noise = jax .random .normal (generator , latents .shape , dtype = latents .dtype )
13601355
1361- if not isinstance (decode_timestep , list ):
1362- decode_timestep = [decode_timestep ] * batch_size
1363- if decode_noise_scale is None :
1364- decode_noise_scale = decode_timestep
1365- elif not isinstance (decode_noise_scale , list ):
1366- decode_noise_scale = [decode_noise_scale ] * batch_size
1356+ if not isinstance (decode_timestep , list ):
1357+ decode_timestep = [decode_timestep ] * batch_size
1358+ if decode_noise_scale is None :
1359+ decode_noise_scale = decode_timestep
1360+ elif not isinstance (decode_noise_scale , list ):
1361+ decode_noise_scale = [decode_noise_scale ] * batch_size
13671362
1368- timestep = jnp .array (decode_timestep , dtype = latents .dtype )
1369- decode_noise_scale = jnp .array (decode_noise_scale , dtype = latents .dtype )[:, None , None , None , None ]
1363+ timestep = jnp .array (decode_timestep , dtype = latents .dtype )
1364+ decode_noise_scale = jnp .array (decode_noise_scale , dtype = latents .dtype )[:, None , None , None , None ]
13701365
1371- latents = (1 - decode_noise_scale ) * latents + decode_noise_scale * noise
1366+ latents = (1 - decode_noise_scale ) * latents + decode_noise_scale * noise
13721367
1373- latents = latents .astype (self .vae .dtype )
1374- video = self .vae .decode (latents , temb = timestep , return_dict = False )[0 ]
1375- else :
1376- latents = latents .astype (self .vae .dtype )
1377- video = self .vae .decode (latents , return_dict = False )[0 ]
1368+ latents = latents .astype (self .vae .dtype )
1369+ video = self .vae .decode (latents , temb = timestep , return_dict = False )[0 ]
1370+ else :
1371+ latents = latents .astype (self .vae .dtype )
1372+ video = self .vae .decode (latents , return_dict = False )[0 ]
13781373 t_vae = time .perf_counter () - s_vae
13791374 max_logging .log (f"[Tuning] VAE decoding took: { t_vae :.4f} seconds" )
13801375 # Post-process video (converts to numpy/PIL)
0 commit comments