@@ -1346,9 +1346,9 @@ def __call__(
13461346 audio_embeds_sharded = jax .device_put (audio_embeds , spec )
13471347
13481348 timesteps_jax = jnp .array (timesteps , dtype = jnp .float32 )
1349-
1349+
13501350 scan_diffusion_loop = getattr (self .config , "scan_diffusion_loop" , True )
1351-
1351+
13521352 if scan_diffusion_loop :
13531353 latents_jax , audio_latents_jax = run_diffusion_loop (
13541354 graphdef ,
@@ -1375,7 +1375,7 @@ def __call__(
13751375 # Old Python loop path
13761376 latents_jax = latents_jax .astype (jnp .float32 )
13771377 audio_latents_jax = audio_latents_jax .astype (jnp .float32 )
1378-
1378+
13791379 for t in timesteps_jax :
13801380 noise_pred , noise_pred_audio = transformer_forward_pass (
13811381 graphdef ,
@@ -1395,26 +1395,28 @@ def __call__(
13951395 audio_num_frames ,
13961396 frame_rate ,
13971397 )
1398-
1398+
13991399 if guidance_scale > 1.0 :
14001400 noise_pred_uncond , noise_pred_text = jnp .split (noise_pred , 2 , axis = 0 )
14011401 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
1402-
1402+
14031403 noise_pred_audio_uncond , noise_pred_audio_text = jnp .split (noise_pred_audio , 2 , axis = 0 )
14041404 noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond )
1405-
1405+
14061406 latents_step = latents_jax [batch_size :]
14071407 audio_latents_step = audio_latents_jax [batch_size :]
14081408 else :
14091409 latents_step = latents_jax
14101410 audio_latents_step = audio_latents_jax
1411-
1411+
14121412 latents_step , _ = self .scheduler .step (scheduler_state , noise_pred , t , latents_step , return_dict = False )
14131413 latents_step = latents_step .astype (jnp .float32 )
1414-
1415- audio_latents_step , _ = self .scheduler .step (scheduler_state , noise_pred_audio , t , audio_latents_step , return_dict = False )
1414+
1415+ audio_latents_step , _ = self .scheduler .step (
1416+ scheduler_state , noise_pred_audio , t , audio_latents_step , return_dict = False
1417+ )
14161418 audio_latents_step = audio_latents_step .astype (jnp .float32 )
1417-
1419+
14181420 if guidance_scale > 1.0 :
14191421 latents_jax = jnp .concatenate ([latents_step ] * 2 , axis = 0 )
14201422 audio_latents_jax = jnp .concatenate ([audio_latents_step ] * 2 , axis = 0 )
0 commit comments