@@ -1382,16 +1382,16 @@ def __call__(
13821382 negative_prompt_attention_mask_jax = negative_prompt_attention_mask
13831383
13841384 if isinstance (prompt_embeds_jax , list ):
1385- prompt_embeds_jax = [jnp .concatenate ([n , p , p ], axis = 0 ) for n , p in zip (negative_prompt_embeds_jax , prompt_embeds_jax )]
1385+ prompt_embeds_jax = [jnp .concatenate ([n , p , p , p ], axis = 0 ) for n , p in zip (negative_prompt_embeds_jax , prompt_embeds_jax )]
13861386 else :
1387- prompt_embeds_jax = jnp .concatenate ([negative_prompt_embeds_jax , prompt_embeds_jax , prompt_embeds_jax ], axis = 0 )
1387+ prompt_embeds_jax = jnp .concatenate ([negative_prompt_embeds_jax , prompt_embeds_jax , prompt_embeds_jax , prompt_embeds_jax ], axis = 0 )
13881388
1389- prompt_attention_mask_jax = jnp .concatenate ([negative_prompt_attention_mask_jax , prompt_attention_mask_jax , prompt_attention_mask_jax ], axis = 0 )
1390- latents_jax = jnp .concatenate ([latents_jax ] * 3 , axis = 0 )
1391- audio_latents_jax = jnp .concatenate ([audio_latents_jax ] * 3 , axis = 0 )
1389+ prompt_attention_mask_jax = jnp .concatenate ([negative_prompt_attention_mask_jax , prompt_attention_mask_jax , prompt_attention_mask_jax , prompt_attention_mask_jax ], axis = 0 )
1390+ latents_jax = jnp .concatenate ([latents_jax ] * 4 , axis = 0 )
1391+ audio_latents_jax = jnp .concatenate ([audio_latents_jax ] * 4 , axis = 0 )
13921392
13931393 N = latents .shape [0 ]
1394- perturbation_mask = jnp .concatenate ([jnp .ones ((2 * N , 1 , 1 ), dtype = dtype ), jnp .zeros ((N , 1 , 1 ), dtype = dtype )], axis = 0 )
1394+ perturbation_mask = jnp .concatenate ([jnp .ones ((2 * N , 1 , 1 ), dtype = dtype ), jnp .zeros ((N , 1 , 1 ), dtype = dtype ), jnp . ones (( N , 1 , 1 ), dtype = dtype ) ], axis = 0 )
13951395
13961396 elif do_cfg :
13971397 negative_prompt_embeds_jax = negative_prompt_embeds
@@ -1528,18 +1528,20 @@ def convert_to_vel(lat, x0):
15281528 return (lat - x0 ) / sigma_t
15291529
15301530 if do_cfg and do_stg :
1531- noise_pred_uncond , noise_pred_text , noise_pred_perturb = jnp .split (noise_pred , 3 , axis = 0 )
1531+ noise_pred_uncond , noise_pred_text , noise_pred_perturb , noise_pred_isolated = jnp .split (noise_pred , 4 , axis = 0 )
15321532
15331533 # Convert to x0
15341534 x0_uncond = convert_to_x0 (latents_step , noise_pred_uncond )
15351535 x0_text = convert_to_x0 (latents_step , noise_pred_text )
15361536 x0_perturb = convert_to_x0 (latents_step , noise_pred_perturb )
1537+ x0_isolated = convert_to_x0 (latents_step , noise_pred_isolated )
15371538
15381539 # Delta formulation
15391540 cfg_delta = (guidance_scale - 1 ) * (x0_text - x0_uncond )
15401541 stg_delta = stg_scale * (x0_text - x0_perturb )
1542+ video_modality_delta = (modality_scale - 1 ) * (x0_text - x0_isolated )
15411543
1542- x0_combined = x0_text + cfg_delta + stg_delta
1544+ x0_combined = x0_text + cfg_delta + stg_delta + video_modality_delta
15431545
15441546 # Apply guidance rescale if needed
15451547 if guidance_rescale > 0 :
@@ -1549,16 +1551,18 @@ def convert_to_vel(lat, x0):
15491551 noise_pred = convert_to_vel (latents_step , x0_combined )
15501552
15511553 # Audio guidance
1552- noise_pred_audio_uncond , noise_pred_audio_text , noise_pred_audio_perturb = jnp .split (noise_pred_audio , 3 , axis = 0 )
1554+ noise_pred_audio_uncond , noise_pred_audio_text , noise_pred_audio_perturb , noise_pred_audio_isolated = jnp .split (noise_pred_audio , 4 , axis = 0 )
15531555
15541556 x0_audio_uncond = convert_to_x0 (audio_latents_step , noise_pred_audio_uncond )
15551557 x0_audio_text = convert_to_x0 (audio_latents_step , noise_pred_audio_text )
15561558 x0_audio_perturb = convert_to_x0 (audio_latents_step , noise_pred_audio_perturb )
1559+ x0_audio_isolated = convert_to_x0 (audio_latents_step , noise_pred_audio_isolated )
15571560
15581561 cfg_audio_delta = (audio_guidance_scale - 1 if audio_guidance_scale is not None else guidance_scale - 1 ) * (x0_audio_text - x0_audio_uncond )
15591562 stg_audio_delta = (audio_stg_scale if audio_stg_scale is not None else stg_scale ) * (x0_audio_text - x0_audio_perturb )
1563+ audio_modality_delta = (audio_modality_scale - 1 if audio_modality_scale is not None else modality_scale - 1 ) * (x0_audio_text - x0_audio_isolated )
15601564
1561- x0_audio_combined = x0_audio_text + cfg_audio_delta + stg_audio_delta
1565+ x0_audio_combined = x0_audio_text + cfg_audio_delta + stg_audio_delta + audio_modality_delta
15621566
15631567 noise_pred_audio = convert_to_vel (audio_latents_step , x0_audio_combined )
15641568
@@ -1789,13 +1793,17 @@ def transformer_forward_pass(
17891793 else :
17901794 audio_sigma = jnp .expand_dims (audio_sigma , 0 ).repeat (latents .shape [0 ])
17911795
1796+ N = latents .shape [0 ] // 4
1797+ modality_mask = jnp .concatenate ([jnp .ones ((3 * N , 1 , 1 , 1 ), dtype = latents .dtype ), jnp .zeros ((N , 1 , 1 , 1 ), dtype = latents .dtype )], axis = 0 )
1798+
17921799 noise_pred , noise_pred_audio = transformer (
17931800 hidden_states = latents ,
17941801 encoder_hidden_states = encoder_hidden_states ,
17951802 timestep = timestep ,
17961803 sigma = sigma ,
17971804 audio_sigma = audio_sigma ,
17981805 encoder_attention_mask = encoder_attention_mask ,
1806+ modality_mask = modality_mask ,
17991807 num_frames = latent_num_frames ,
18001808 height = latent_height ,
18011809 width = latent_width ,
0 commit comments