@@ -1508,39 +1508,14 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
15081508 use_cross_timestep = use_cross_timestep ,
15091509 )
15101510
1511+ do_cfg = guidance_scale > 1.0
15111512 do_stg = stg_scale > 0.0
1512-
1513- if guidance_scale > 1.0 and do_stg :
1514- noise_pred_uncond , noise_pred_text , noise_pred_perturb = jnp .split (noise_pred , 3 , axis = 0 )
1515- noise_pred = (
1516- noise_pred_uncond
1517- + guidance_scale * (noise_pred_text - noise_pred_uncond )
1518- + stg_scale * (noise_pred_text - noise_pred_perturb )
1519- )
1520- # Audio guidance
1521- noise_pred_audio_uncond , noise_pred_audio_text , noise_pred_audio_perturb = jnp .split (noise_pred_audio , 3 , axis = 0 )
1522- noise_pred_audio = (
1523- noise_pred_audio_uncond
1524- + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond )
1525- + stg_scale * (noise_pred_audio_text - noise_pred_audio_perturb )
1526- )
1527- elif guidance_scale > 1.0 :
1528- noise_pred_uncond , noise_pred_text = jnp .split (noise_pred , 2 , axis = 0 )
1529- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
1530- # Audio guidance
1531- noise_pred_audio_uncond , noise_pred_audio_text = jnp .split (noise_pred_audio , 2 , axis = 0 )
1532- noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond )
1533- elif do_stg :
1534- noise_pred_text , noise_pred_perturb = jnp .split (noise_pred , 2 , axis = 0 )
1535- noise_pred = noise_pred_text + stg_scale * (noise_pred_text - noise_pred_perturb )
1536-
1537- noise_pred_audio_text , noise_pred_audio_perturb = jnp .split (noise_pred_audio , 2 , axis = 0 )
1538- noise_pred_audio = noise_pred_audio_text + stg_scale * (noise_pred_audio_text - noise_pred_audio_perturb )
1513+ sigma_t = sigmas [i ]
15391514
15401515 # Extract latents_step based on stacking strategy
15411516 if do_cfg and do_stg :
1542- latents_step = latents_jax [batch_size : 2 * batch_size ]
1543- audio_latents_step = audio_latents_jax [batch_size : 2 * batch_size ]
1517+ latents_step = latents_jax [batch_size : 2 * batch_size ]
1518+ audio_latents_step = audio_latents_jax [batch_size : 2 * batch_size ]
15441519 elif do_cfg :
15451520 latents_step = latents_jax [batch_size :]
15461521 audio_latents_step = audio_latents_jax [batch_size :]
@@ -1551,6 +1526,96 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
15511526 latents_step = latents_jax
15521527 audio_latents_step = audio_latents_jax
15531528
1529+ # Helper to convert velocity to x0
1530+ def convert_to_x0 (lat , vel ):
1531+ return lat - vel * sigma_t
1532+
1533+ # Helper to convert x0 back to velocity
1534+ def convert_to_vel (lat , x0 ):
1535+ return (lat - x0 ) / sigma_t
1536+
1537+ if do_cfg and do_stg :
1538+ noise_pred_uncond , noise_pred_text , noise_pred_perturb = jnp .split (noise_pred , 3 , axis = 0 )
1539+
1540+ # Convert to x0
1541+ x0_uncond = convert_to_x0 (latents_step , noise_pred_uncond )
1542+ x0_text = convert_to_x0 (latents_step , noise_pred_text )
1543+ x0_perturb = convert_to_x0 (latents_step , noise_pred_perturb )
1544+
1545+ # Delta formulation
1546+ cfg_delta = (guidance_scale - 1 ) * (x0_text - x0_uncond )
1547+ stg_delta = stg_scale * (x0_text - x0_perturb )
1548+
1549+ x0_combined = x0_text + cfg_delta + stg_delta
1550+
1551+ # Apply guidance rescale if needed
1552+ if guidance_rescale > 0 :
1553+ x0_combined = rescale_noise_cfg (x0_combined , x0_text , guidance_rescale = guidance_rescale )
1554+
1555+ # Convert back to velocity
1556+ noise_pred = convert_to_vel (latents_step , x0_combined )
1557+
1558+ # Audio guidance
1559+ noise_pred_audio_uncond , noise_pred_audio_text , noise_pred_audio_perturb = jnp .split (noise_pred_audio , 3 , axis = 0 )
1560+
1561+ x0_audio_uncond = convert_to_x0 (audio_latents_step , noise_pred_audio_uncond )
1562+ x0_audio_text = convert_to_x0 (audio_latents_step , noise_pred_audio_text )
1563+ x0_audio_perturb = convert_to_x0 (audio_latents_step , noise_pred_audio_perturb )
1564+
1565+ cfg_audio_delta = (audio_guidance_scale - 1 if audio_guidance_scale is not None else guidance_scale - 1 ) * (x0_audio_text - x0_audio_uncond )
1566+ stg_audio_delta = (audio_stg_scale if audio_stg_scale is not None else stg_scale ) * (x0_audio_text - x0_audio_perturb )
1567+
1568+ x0_audio_combined = x0_audio_text + cfg_audio_delta + stg_audio_delta
1569+
1570+ noise_pred_audio = convert_to_vel (audio_latents_step , x0_audio_combined )
1571+
1572+ elif do_cfg :
1573+ noise_pred_uncond , noise_pred_text = jnp .split (noise_pred , 2 , axis = 0 )
1574+
1575+ x0_uncond = convert_to_x0 (latents_step , noise_pred_uncond )
1576+ x0_text = convert_to_x0 (latents_step , noise_pred_text )
1577+
1578+ cfg_delta = (guidance_scale - 1 ) * (x0_text - x0_uncond )
1579+ x0_combined = x0_text + cfg_delta
1580+
1581+ if guidance_rescale > 0 :
1582+ x0_combined = rescale_noise_cfg (x0_combined , x0_text , guidance_rescale = guidance_rescale )
1583+
1584+ noise_pred = convert_to_vel (latents_step , x0_combined )
1585+
1586+ # Audio guidance
1587+ noise_pred_audio_uncond , noise_pred_audio_text = jnp .split (noise_pred_audio , 2 , axis = 0 )
1588+
1589+ x0_audio_uncond = convert_to_x0 (audio_latents_step , noise_pred_audio_uncond )
1590+ x0_audio_text = convert_to_x0 (audio_latents_step , noise_pred_audio_text )
1591+
1592+ cfg_audio_delta = (audio_guidance_scale - 1 if audio_guidance_scale is not None else guidance_scale - 1 ) * (x0_audio_text - x0_audio_uncond )
1593+ x0_audio_combined = x0_audio_text + cfg_audio_delta
1594+
1595+ noise_pred_audio = convert_to_vel (audio_latents_step , x0_audio_combined )
1596+
1597+ elif do_stg :
1598+ noise_pred_text , noise_pred_perturb = jnp .split (noise_pred , 2 , axis = 0 )
1599+
1600+ x0_text = convert_to_x0 (latents_step , noise_pred_text )
1601+ x0_perturb = convert_to_x0 (latents_step , noise_pred_perturb )
1602+
1603+ stg_delta = stg_scale * (x0_text - x0_perturb )
1604+ x0_combined = x0_text + stg_delta
1605+
1606+ noise_pred = convert_to_vel (latents_step , x0_combined )
1607+
1608+ # Audio guidance
1609+ noise_pred_audio_text , noise_pred_audio_perturb = jnp .split (noise_pred_audio , 2 , axis = 0 )
1610+
1611+ x0_audio_text = convert_to_x0 (audio_latents_step , noise_pred_audio_text )
1612+ x0_audio_perturb = convert_to_x0 (audio_latents_step , noise_pred_audio_perturb )
1613+
1614+ stg_audio_delta = (audio_stg_scale if audio_stg_scale is not None else stg_scale ) * (x0_audio_text - x0_audio_perturb )
1615+ x0_audio_combined = x0_audio_text + stg_audio_delta
1616+
1617+ noise_pred_audio = convert_to_vel (audio_latents_step , x0_audio_combined )
1618+
15541619 # Step
15551620 latents_step , _ = self .scheduler .step (scheduler_state , noise_pred , t , latents_step , return_dict = False )
15561621 audio_latents_step , _ = self .scheduler .step (
0 commit comments