@@ -832,29 +832,29 @@ def encode_prompt(
832832 if do_classifier_free_guidance and negative_prompt_embeds is None :
833833 negative_prompt = negative_prompt or ""
834834 negative_prompt = [negative_prompt ] * batch_size if isinstance (negative_prompt , str ) else negative_prompt
835-
835+
836836 if isinstance (prompt , str ):
837837 prompt = [prompt ]
838-
838+
839839 combined_prompts = prompt + negative_prompt
840-
840+
841841 combined_embeds , combined_mask = self ._get_gemma_prompt_embeds (
842842 prompt = combined_prompts ,
843843 num_videos_per_prompt = num_videos_per_prompt ,
844844 max_sequence_length = max_sequence_length ,
845845 scale_factor = scale_factor ,
846846 dtype = dtype ,
847847 )
848-
848+
849849 split_idx = batch_size * num_videos_per_prompt
850-
850+
851851 if isinstance (combined_embeds , list ):
852852 prompt_embeds = [state [:split_idx ] for state in combined_embeds ]
853853 negative_prompt_embeds = [state [split_idx :] for state in combined_embeds ]
854854 else :
855855 prompt_embeds = combined_embeds [:split_idx ]
856856 negative_prompt_embeds = combined_embeds [split_idx :]
857-
857+
858858 prompt_attention_mask = combined_mask [:split_idx ]
859859 negative_prompt_attention_mask = combined_mask [split_idx :]
860860 else :
@@ -865,7 +865,7 @@ def encode_prompt(
865865 scale_factor = scale_factor ,
866866 dtype = dtype ,
867867 )
868-
868+
869869 if do_classifier_free_guidance and negative_prompt_embeds is None :
870870 negative_prompt = negative_prompt or ""
871871 negative_prompt = batch_size * [negative_prompt ] if isinstance (negative_prompt , str ) else negative_prompt
@@ -1577,95 +1577,80 @@ def run_diffusion_loop(
15771577 scheduler_step ,
15781578 logical_axis_rules ,
15791579):
1580- transformer = nnx .merge (graphdef , state )
1581-
1582- def scan_body (carry , t , model ):
1583- latents , audio_latents , s_state = carry
1584-
1585- with nn_partitioning .axis_rules (logical_axis_rules ):
1586- latents_sharded = latents
1587- audio_latents_sharded = audio_latents
1588-
1589- if not scan_layers :
1590- activation_axis_names = nn .logical_to_mesh_axes (
1591- ("activation_batch" , "activation_length" , "activation_embed" )
1592- )
1593- latents_sharded = jax .lax .with_sharding_constraint (
1594- latents , activation_axis_names
1595- )
1596- audio_latents_sharded = jax .lax .with_sharding_constraint (
1597- audio_latents , activation_axis_names
1598- )
1599-
1600- # Expand timestep to batch size
1601- t_expanded = jnp .expand_dims (t , 0 ).repeat (latents .shape [0 ])
1602-
1603- noise_pred , noise_pred_audio = model (
1604- hidden_states = latents_sharded ,
1605- encoder_hidden_states = video_embeds_sharded ,
1606- timestep = t_expanded ,
1607- encoder_attention_mask = new_attention_mask ,
1608- num_frames = latent_num_frames ,
1609- height = latent_height ,
1610- width = latent_width ,
1611- audio_hidden_states = audio_latents_sharded ,
1612- audio_encoder_hidden_states = audio_embeds_sharded ,
1613- audio_encoder_attention_mask = new_attention_mask ,
1614- fps = fps ,
1615- audio_num_frames = audio_num_frames ,
1616- return_dict = False ,
1617- )
1618-
1619- if guidance_scale > 1.0 :
1620- noise_pred_uncond , noise_pred_text = jnp .split (noise_pred , 2 , axis = 0 )
1621- noise_pred = noise_pred_uncond + guidance_scale * (
1622- noise_pred_text - noise_pred_uncond
1623- )
1624- # Audio guidance
1625- (
1626- noise_pred_audio_uncond ,
1627- noise_pred_audio_text ,
1628- ) = jnp .split (noise_pred_audio , 2 , axis = 0 )
1629- noise_pred_audio = (
1630- noise_pred_audio_uncond
1631- + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond )
1632- )
1633-
1634- latents_step = latents [batch_size :]
1635- audio_latents_step = audio_latents [batch_size :]
1636- else :
1637- latents_step = latents
1638- audio_latents_step = audio_latents
1639-
1640- # Step scheduler
1641- latents_step , _ = scheduler_step (
1642- s_state , noise_pred , t , latents_step , return_dict = False
1643- )
1644- latents_step = latents_step .astype (latents .dtype )
1645-
1646- audio_latents_step , _ = scheduler_step (
1647- s_state , noise_pred_audio , t , audio_latents_step , return_dict = False
1648- )
1649- audio_latents_step = audio_latents_step .astype (audio_latents .dtype )
1650-
1651- if guidance_scale > 1.0 :
1652- latents_next = jnp .concatenate ([latents_step ] * 2 , axis = 0 )
1653- audio_latents_next = jnp .concatenate ([audio_latents_step ] * 2 , axis = 0 )
1654- else :
1655- latents_next = latents_step
1656- audio_latents_next = audio_latents_step
1657-
1658- new_carry = (latents_next , audio_latents_next , s_state )
1659- return new_carry , None
1660-
1661- # Initial carry
1662- initial_carry = (latents_jax , audio_latents_jax , scheduler_state )
1663-
1664- # Run scan
1665- final_carry , _ = nnx .scan (
1666- scan_body ,
1667- in_axes = (nnx .Carry , 0 , None ),
1668- out_axes = (nnx .Carry , 0 ),
1669- )(initial_carry , timesteps_jax , transformer )
1670-
1671- return final_carry [0 ], final_carry [1 ]
1580+ transformer = nnx .merge (graphdef , state )
1581+
def scan_body(carry, t, model):
    """Run one denoising step for the joint video/audio latents.

    Carried state is ``(video_latents, audio_latents, scheduler_state)``;
    ``t`` is the scalar timestep for this iteration and ``model`` is the
    merged nnx transformer. Returns ``(new_carry, None)`` in the shape
    ``nnx.scan`` expects.
    """
    latents, audio_latents, s_state = carry

    with nn_partitioning.axis_rules(logical_axis_rules):
        latents_in = latents
        audio_in = audio_latents

        if not scan_layers:
            # Pin activations to the mesh only on the unscanned path.
            mesh_axes = nn.logical_to_mesh_axes(
                ("activation_batch", "activation_length", "activation_embed")
            )
            latents_in = jax.lax.with_sharding_constraint(latents, mesh_axes)
            audio_in = jax.lax.with_sharding_constraint(audio_latents, mesh_axes)

        # Broadcast the scalar timestep across the (possibly CFG-doubled) batch.
        t_batch = jnp.expand_dims(t, 0).repeat(latents.shape[0])

        noise_pred, noise_pred_audio = model(
            hidden_states=latents_in,
            encoder_hidden_states=video_embeds_sharded,
            timestep=t_batch,
            encoder_attention_mask=new_attention_mask,
            num_frames=latent_num_frames,
            height=latent_height,
            width=latent_width,
            audio_hidden_states=audio_in,
            audio_encoder_hidden_states=audio_embeds_sharded,
            audio_encoder_attention_mask=new_attention_mask,
            fps=fps,
            audio_num_frames=audio_num_frames,
            return_dict=False,
        )

        if guidance_scale > 1.0:
            # Classifier-free guidance: the batch is stacked [uncond; cond].
            uncond, cond = jnp.split(noise_pred, 2, axis=0)
            noise_pred = uncond + guidance_scale * (cond - uncond)

            audio_uncond, audio_cond = jnp.split(noise_pred_audio, 2, axis=0)
            noise_pred_audio = audio_uncond + guidance_scale * (audio_cond - audio_uncond)

            # Step the scheduler on a single (conditional) copy of the latents.
            latents_step = latents[batch_size:]
            audio_latents_step = audio_latents[batch_size:]
        else:
            latents_step = latents
            audio_latents_step = audio_latents

        # NOTE(review): the updated scheduler state returned by scheduler_step is
        # discarded here and the same `s_state` is reused for video, audio, and
        # the next iteration — this assumes the scheduler step is stateless given
        # `t`; confirm against the scheduler implementation.
        latents_step, _ = scheduler_step(s_state, noise_pred, t, latents_step, return_dict=False)
        latents_step = latents_step.astype(latents.dtype)

        audio_latents_step, _ = scheduler_step(s_state, noise_pred_audio, t, audio_latents_step, return_dict=False)
        audio_latents_step = audio_latents_step.astype(audio_latents.dtype)

        # Re-duplicate so the next step again sees the [uncond; cond] CFG batch.
        if guidance_scale > 1.0:
            latents_next = jnp.concatenate([latents_step, latents_step], axis=0)
            audio_latents_next = jnp.concatenate([audio_latents_step, audio_latents_step], axis=0)
        else:
            latents_next = latents_step
            audio_latents_next = audio_latents_step

        return (latents_next, audio_latents_next, s_state), None
1645+
1646+ # Initial carry
1647+ initial_carry = (latents_jax , audio_latents_jax , scheduler_state )
1648+
1649+ # Run scan
1650+ final_carry , _ = nnx .scan (
1651+ scan_body ,
1652+ in_axes = (nnx .Carry , 0 , None ),
1653+ out_axes = (nnx .Carry , 0 ),
1654+ )(initial_carry , timesteps_jax , transformer )
1655+
1656+ return final_carry [0 ], final_carry [1 ]
0 commit comments