@@ -1346,27 +1346,83 @@ def __call__(
13461346 audio_embeds_sharded = jax .device_put (audio_embeds , spec )
13471347
13481348 timesteps_jax = jnp .array (timesteps , dtype = jnp .float32 )
1349- latents_jax , audio_latents_jax = run_diffusion_loop (
1350- graphdef ,
1351- state ,
1352- scheduler_state ,
1353- timesteps_jax ,
1354- latents_jax ,
1355- audio_latents_jax ,
1356- video_embeds_sharded ,
1357- audio_embeds_sharded ,
1358- new_attention_mask ,
1359- guidance_scale ,
1360- latent_num_frames ,
1361- latent_height ,
1362- latent_width ,
1363- audio_num_frames ,
1364- frame_rate ,
1365- batch_size ,
1366- self .transformer .scan_layers ,
1367- self .scheduler .step ,
1368- tuple (tuple (rule ) if isinstance (rule , list ) else rule for rule in self .config .logical_axis_rules ),
1369- )
1349+
1350+ scan_diffusion_loop = getattr (self .config , "scan_diffusion_loop" , True )
1351+
1352+ if scan_diffusion_loop :
1353+ latents_jax , audio_latents_jax = run_diffusion_loop (
1354+ graphdef ,
1355+ state ,
1356+ scheduler_state ,
1357+ timesteps_jax ,
1358+ latents_jax ,
1359+ audio_latents_jax ,
1360+ video_embeds_sharded ,
1361+ audio_embeds_sharded ,
1362+ new_attention_mask ,
1363+ guidance_scale ,
1364+ latent_num_frames ,
1365+ latent_height ,
1366+ latent_width ,
1367+ audio_num_frames ,
1368+ frame_rate ,
1369+ batch_size ,
1370+ self .transformer .scan_layers ,
1371+ self .scheduler .step ,
1372+ tuple (tuple (rule ) if isinstance (rule , list ) else rule for rule in self .config .logical_axis_rules ),
1373+ )
1374+ else :
1375+ # Old Python loop path
1376+ latents_jax = latents_jax .astype (jnp .float32 )
1377+ audio_latents_jax = audio_latents_jax .astype (jnp .float32 )
1378+
1379+ for t in timesteps_jax :
1380+ noise_pred , noise_pred_audio = transformer_forward_pass (
1381+ graphdef ,
1382+ state ,
1383+ latents_jax ,
1384+ audio_latents_jax ,
1385+ t ,
1386+ video_embeds_sharded ,
1387+ audio_embeds_sharded ,
1388+ new_attention_mask ,
1389+ new_attention_mask ,
1390+ guidance_scale > 1.0 ,
1391+ guidance_scale ,
1392+ latent_num_frames ,
1393+ latent_height ,
1394+ latent_width ,
1395+ audio_num_frames ,
1396+ frame_rate ,
1397+ )
1398+
1399+ if guidance_scale > 1.0 :
1400+ noise_pred_uncond , noise_pred_text = jnp .split (noise_pred , 2 , axis = 0 )
1401+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
1402+
1403+ noise_pred_audio_uncond , noise_pred_audio_text = jnp .split (noise_pred_audio , 2 , axis = 0 )
1404+ noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond )
1405+
1406+ latents_step = latents_jax [batch_size :]
1407+ audio_latents_step = audio_latents_jax [batch_size :]
1408+ else :
1409+ latents_step = latents_jax
1410+ audio_latents_step = audio_latents_jax
1411+
1412+ latents_step , _ = self .scheduler .step (scheduler_state , noise_pred , t , latents_step , return_dict = False )
1413+ latents_step = latents_step .astype (jnp .float32 )
1414+
1415+ audio_latents_step , _ = self .scheduler .step (
1416+ scheduler_state , noise_pred_audio , t , audio_latents_step , return_dict = False
1417+ )
1418+ audio_latents_step = audio_latents_step .astype (jnp .float32 )
1419+
1420+ if guidance_scale > 1.0 :
1421+ latents_jax = jnp .concatenate ([latents_step ] * 2 , axis = 0 )
1422+ audio_latents_jax = jnp .concatenate ([audio_latents_step ] * 2 , axis = 0 )
1423+ else :
1424+ latents_jax = latents_step
1425+ audio_latents_jax = audio_latents_step
13701426
13711427 # 8. Decode Latents
13721428 if guidance_scale > 1.0 :
0 commit comments