@@ -1346,62 +1346,27 @@ def __call__(
13461346 audio_embeds_sharded = jax .device_put (audio_embeds , spec )
13471347
13481348 timesteps_jax = jnp .array (timesteps , dtype = jnp .float32 )
1349- for i in range (len (timesteps_jax )):
1350- t = timesteps_jax [i ]
1351-
1352- # Isolate input sharding to scan_layers=False to avoid affecting the standard path
1353- latents_jax_sharded = latents_jax
1354- audio_latents_jax_sharded = audio_latents_jax
1355-
1356- if not self .transformer .scan_layers :
1357- activation_axis_names = nn .logical_to_mesh_axes (("activation_batch" , "activation_length" , "activation_embed" ))
1358- latents_jax_sharded = jax .lax .with_sharding_constraint (latents_jax , activation_axis_names )
1359- audio_latents_jax_sharded = jax .lax .with_sharding_constraint (audio_latents_jax , activation_axis_names )
1360-
1361- noise_pred , noise_pred_audio = transformer_forward_pass (
1362- graphdef ,
1363- state ,
1364- latents_jax_sharded ,
1365- audio_latents_jax_sharded ,
1366- t ,
1367- video_embeds_sharded ,
1368- audio_embeds_sharded ,
1369- new_attention_mask ,
1370- new_attention_mask ,
1371- guidance_scale > 1.0 ,
1372- guidance_scale ,
1373- latent_num_frames ,
1374- latent_height ,
1375- latent_width ,
1376- audio_num_frames ,
1377- frame_rate ,
1378- )
1379-
1380- if guidance_scale > 1.0 :
1381- noise_pred_uncond , noise_pred_text = jnp .split (noise_pred , 2 , axis = 0 )
1382- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
1383- # Audio guidance
1384- noise_pred_audio_uncond , noise_pred_audio_text = jnp .split (noise_pred_audio , 2 , axis = 0 )
1385- noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond )
1386-
1387- latents_step = latents_jax [batch_size :]
1388- audio_latents_step = audio_latents_jax [batch_size :]
1389- else :
1390- latents_step = latents_jax
1391- audio_latents_step = audio_latents_jax
1392-
1393- # Step
1394- latents_step , _ = self .scheduler .step (scheduler_state , noise_pred , t , latents_step , return_dict = False )
1395- audio_latents_step , _ = self .scheduler .step (
1396- scheduler_state , noise_pred_audio , t , audio_latents_step , return_dict = False
1397- )
1398-
1399- if guidance_scale > 1.0 :
1400- latents_jax = jnp .concatenate ([latents_step ] * 2 , axis = 0 )
1401- audio_latents_jax = jnp .concatenate ([audio_latents_step ] * 2 , axis = 0 )
1402- else :
1403- latents_jax = latents_step
1404- audio_latents_jax = audio_latents_step
1349+ latents_jax , audio_latents_jax = run_diffusion_loop (
1350+ graphdef ,
1351+ state ,
1352+ scheduler_state ,
1353+ timesteps_jax ,
1354+ latents_jax ,
1355+ audio_latents_jax ,
1356+ video_embeds_sharded ,
1357+ audio_embeds_sharded ,
1358+ new_attention_mask ,
1359+ guidance_scale ,
1360+ latent_num_frames ,
1361+ latent_height ,
1362+ latent_width ,
1363+ audio_num_frames ,
1364+ frame_rate ,
1365+ batch_size ,
1366+ self .transformer .scan_layers ,
1367+ self .scheduler .step ,
1368+ tuple (self .config .logical_axis_rules ),
1369+ )
14051370
14061371 # 8. Decode Latents
14071372 if guidance_scale > 1.0 :
@@ -1574,3 +1539,126 @@ def transformer_forward_pass(
15741539 )
15751540
15761541 return noise_pred , noise_pred_audio
1542+
1543+
@partial(
    jax.jit,
    static_argnames=(
        "guidance_scale",
        "latent_num_frames",
        "latent_height",
        "latent_width",
        "audio_num_frames",
        "fps",
        "batch_size",
        "scan_layers",
        "scheduler_step",
        "logical_axis_rules",
    ),
)
def run_diffusion_loop(
    graphdef,
    state,
    scheduler_state,
    timesteps_jax,
    latents_jax,
    audio_latents_jax,
    video_embeds_sharded,
    audio_embeds_sharded,
    new_attention_mask,
    guidance_scale,
    latent_num_frames,
    latent_height,
    latent_width,
    audio_num_frames,
    fps,
    batch_size,
    scan_layers,
    scheduler_step,
    logical_axis_rules,
):
    """Run all denoising steps inside one jitted ``jax.lax.scan``.

    The transformer is rebuilt once from ``graphdef``/``state`` with
    ``nnx.merge`` and then invoked once per entry of ``timesteps_jax`` on the
    video and audio latents. When ``guidance_scale > 1.0`` each prediction is
    split into two halves along axis 0 and recombined as
    ``first + guidance_scale * (second - first)``; only ``latents[batch_size:]``
    is advanced by the scheduler and the result is duplicated along axis 0
    before the next iteration.

    Args:
      graphdef: Graph definition passed to ``nnx.merge``.
      state: Module state passed to ``nnx.merge``.
      scheduler_state: First argument of every ``scheduler_step`` call. The
        state returned by ``scheduler_step`` is discarded, so this value is
        carried through the scan unchanged.
      timesteps_jax: Sequence scanned over; one transformer call per element.
      latents_jax: Initial video latents (scan carry).
      audio_latents_jax: Initial audio latents (scan carry).
      video_embeds_sharded: Passed as ``encoder_hidden_states``.
      audio_embeds_sharded: Passed as ``audio_encoder_hidden_states``.
      new_attention_mask: Used for both the video and the audio encoder
        attention mask.
      guidance_scale: Static. Guidance is applied only when it exceeds 1.0.
      latent_num_frames: Static. Forwarded as ``num_frames``.
      latent_height: Static. Forwarded as ``height``.
      latent_width: Static. Forwarded as ``width``.
      audio_num_frames: Static. Forwarded to the transformer.
      fps: Static. Forwarded to the transformer.
      batch_size: Static. Offset used to slice the latents when guidance is
        active.
      scan_layers: Static. When falsy, both latents receive an explicit
        sharding constraint before the transformer call.
      scheduler_step: Static callable invoked as
        ``scheduler_step(state, model_output, t, sample, return_dict=False)``;
        its first return value is used as the updated sample.
      logical_axis_rules: Static. Installed through
        ``nn_partitioning.axis_rules`` while tracing the step.

    Returns:
      ``(latents, audio_latents)`` taken from the final scan carry.
    """
    transformer = nnx.merge(graphdef, state)

    # guidance_scale is a static argument, so this is a plain Python bool that
    # is resolved at trace time rather than inside the compiled loop.
    use_cfg = guidance_scale > 1.0

    def _guide(pred):
        # Combine the two halves of the batch axis into one guided prediction.
        uncond, text = jnp.split(pred, 2, axis=0)
        return uncond + guidance_scale * (text - uncond)

    def scan_body(carry, t):
        latents, audio_latents, s_state = carry

        # NOTE(review): the extent of this context in the original was inferred
        # (sharding constraints + transformer call) — confirm against the file.
        with nn_partitioning.axis_rules(logical_axis_rules):
            if scan_layers:
                video_in = latents
                audio_in = audio_latents
            else:
                mesh_axes = nn.logical_to_mesh_axes(
                    ("activation_batch", "activation_length", "activation_embed")
                )
                video_in = jax.lax.with_sharding_constraint(latents, mesh_axes)
                audio_in = jax.lax.with_sharding_constraint(
                    audio_latents, mesh_axes
                )

            # One copy of the current timestep per row of the latent batch.
            timestep_per_sample = jnp.repeat(
                jnp.expand_dims(t, 0), latents.shape[0]
            )

            noise_pred, noise_pred_audio = transformer(
                hidden_states=video_in,
                encoder_hidden_states=video_embeds_sharded,
                timestep=timestep_per_sample,
                encoder_attention_mask=new_attention_mask,
                num_frames=latent_num_frames,
                height=latent_height,
                width=latent_width,
                audio_hidden_states=audio_in,
                audio_encoder_hidden_states=audio_embeds_sharded,
                audio_encoder_attention_mask=new_attention_mask,
                fps=fps,
                audio_num_frames=audio_num_frames,
                return_dict=False,
            )

        if use_cfg:
            noise_pred = _guide(noise_pred)
            noise_pred_audio = _guide(noise_pred_audio)
            # Only the rows from batch_size onward are stepped.
            video_sample = latents[batch_size:]
            audio_sample = audio_latents[batch_size:]
        else:
            video_sample = latents
            audio_sample = audio_latents

        # Both modalities use the same scheduler state; the state returned by
        # the scheduler is intentionally ignored, as in the original code.
        video_sample, _ = scheduler_step(
            s_state, noise_pred, t, video_sample, return_dict=False
        )
        audio_sample, _ = scheduler_step(
            s_state, noise_pred_audio, t, audio_sample, return_dict=False
        )

        if use_cfg:
            # Restore the doubled batch expected by the next transformer call.
            video_sample = jnp.concatenate([video_sample, video_sample], axis=0)
            audio_sample = jnp.concatenate([audio_sample, audio_sample], axis=0)

        return (video_sample, audio_sample, s_state), None

    (final_latents, final_audio_latents, _), _ = jax.lax.scan(
        scan_body,
        (latents_jax, audio_latents_jax, scheduler_state),
        timesteps_jax,
    )
    return final_latents, final_audio_latents
0 commit comments