@@ -1381,16 +1381,16 @@ def __call__(
13811381 negative_prompt_attention_mask_jax = negative_prompt_attention_mask
13821382
13831383 if isinstance (prompt_embeds_jax , list ):
1384- prompt_embeds_jax = [jnp .concatenate ([n , p , p , p ], axis = 0 ) for n , p in zip (negative_prompt_embeds_jax , prompt_embeds_jax )]
1384+ prompt_embeds_jax = [jnp .concatenate ([n , p , p ], axis = 0 ) for n , p in zip (negative_prompt_embeds_jax , prompt_embeds_jax )]
13851385 else :
1386- prompt_embeds_jax = jnp .concatenate ([negative_prompt_embeds_jax , prompt_embeds_jax , prompt_embeds_jax , prompt_embeds_jax ], axis = 0 )
1386+ prompt_embeds_jax = jnp .concatenate ([negative_prompt_embeds_jax , prompt_embeds_jax , prompt_embeds_jax ], axis = 0 )
13871387
1388- prompt_attention_mask_jax = jnp .concatenate ([negative_prompt_attention_mask_jax , prompt_attention_mask_jax , prompt_attention_mask_jax , prompt_attention_mask_jax ], axis = 0 )
1389- latents_jax = jnp .concatenate ([latents_jax ] * 4 , axis = 0 )
1390- audio_latents_jax = jnp .concatenate ([audio_latents_jax ] * 4 , axis = 0 )
1388+ prompt_attention_mask_jax = jnp .concatenate ([negative_prompt_attention_mask_jax , prompt_attention_mask_jax , prompt_attention_mask_jax ], axis = 0 )
1389+ latents_jax = jnp .concatenate ([latents_jax ] * 3 , axis = 0 )
1390+ audio_latents_jax = jnp .concatenate ([audio_latents_jax ] * 3 , axis = 0 )
13911391
13921392 N = latents .shape [0 ]
1393- perturbation_mask = jnp .concatenate ([jnp .ones ((2 * N , 1 , 1 ), dtype = dtype ), jnp .zeros ((N , 1 , 1 ), dtype = dtype ), jnp . ones (( N , 1 , 1 ), dtype = dtype ) ], axis = 0 )
1393+ perturbation_mask = jnp .concatenate ([jnp .ones ((2 * N , 1 , 1 ), dtype = dtype ), jnp .zeros ((N , 1 , 1 ), dtype = dtype )], axis = 0 )
13941394
13951395 elif do_cfg :
13961396 negative_prompt_embeds_jax = negative_prompt_embeds
@@ -1527,7 +1527,7 @@ def convert_to_vel(lat, x0):
15271527 return (lat - x0 ) / sigma_t
15281528
15291529 if do_cfg and do_stg :
1530- noise_pred_uncond , noise_pred_text , noise_pred_perturb , _ = jnp .split (noise_pred , 4 , axis = 0 )
1530+ noise_pred_uncond , noise_pred_text , noise_pred_perturb = jnp .split (noise_pred , 3 , axis = 0 )
15311531
15321532 # Convert to x0
15331533 x0_uncond = convert_to_x0 (latents_step , noise_pred_uncond )
@@ -1548,7 +1548,7 @@ def convert_to_vel(lat, x0):
15481548 noise_pred = convert_to_vel (latents_step , x0_combined )
15491549
15501550 # Audio guidance
1551- noise_pred_audio_uncond , noise_pred_audio_text , noise_pred_audio_perturb , _ = jnp .split (noise_pred_audio , 4 , axis = 0 )
1551+ noise_pred_audio_uncond , noise_pred_audio_text , noise_pred_audio_perturb = jnp .split (noise_pred_audio , 3 , axis = 0 )
15521552
15531553 x0_audio_uncond = convert_to_x0 (audio_latents_step , noise_pred_audio_uncond )
15541554 x0_audio_text = convert_to_x0 (audio_latents_step , noise_pred_audio_text )
0 commit comments