@@ -1472,8 +1472,16 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
14721472 noise_pred_audio_text , noise_pred_audio_perturb = jnp .split (noise_pred_audio , 2 , axis = 0 )
14731473 noise_pred_audio = noise_pred_audio_text + self .config .stg_scale * (noise_pred_audio_text - noise_pred_audio_perturb )
14741474
1475+ # Extract latents_step based on stacking strategy
1476+ if do_cfg and do_stg :
1477+ latents_step = latents_jax [batch_size :2 * batch_size ]
1478+ audio_latents_step = audio_latents_jax [batch_size :2 * batch_size ]
1479+ elif do_cfg :
14751480 latents_step = latents_jax [batch_size :]
14761481 audio_latents_step = audio_latents_jax [batch_size :]
1482+ elif do_stg :
1483+ latents_step = latents_jax [:batch_size ]
1484+ audio_latents_step = audio_latents_jax [:batch_size ]
14771485 else :
14781486 latents_step = latents_jax
14791487 audio_latents_step = audio_latents_jax
@@ -1484,17 +1492,28 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
14841492 scheduler_state , noise_pred_audio , t , audio_latents_step , return_dict = False
14851493 )
14861494
1487- if guidance_scale > 1.0 :
1495+ # Re-stack based on strategy for next iteration
1496+ if do_cfg and do_stg :
1497+ latents_jax = jnp .concatenate ([latents_step ] * 3 , axis = 0 )
1498+ audio_latents_jax = jnp .concatenate ([audio_latents_step ] * 3 , axis = 0 )
1499+ elif do_cfg or do_stg :
14881500 latents_jax = jnp .concatenate ([latents_step ] * 2 , axis = 0 )
14891501 audio_latents_jax = jnp .concatenate ([audio_latents_step ] * 2 , axis = 0 )
14901502 else :
14911503 latents_jax = latents_step
14921504 audio_latents_jax = audio_latents_step
14931505
14941506 # 8. Decode Latents
1495- if guidance_scale > 1.0 :
1507+ # 8. Decode Latents - Extract conditional branch
1508+ if do_cfg and do_stg :
1509+ latents_jax = latents_jax [batch_size :2 * batch_size ]
1510+ audio_latents_jax = audio_latents_jax [batch_size :2 * batch_size ]
1511+ elif do_cfg :
14961512 latents_jax = latents_jax [batch_size :]
14971513 audio_latents_jax = audio_latents_jax [batch_size :]
1514+ elif do_stg :
1515+ latents_jax = latents_jax [:batch_size ]
1516+ audio_latents_jax = audio_latents_jax [:batch_size ]
14981517
14991518 # Unpack and Denormalize Video
15001519 latents = self ._unpack_latents (
0 commit comments