Skip to content

Commit c2d9473

Browse files
committed
prediction space conversion
1 parent da0a909 commit c2d9473

1 file changed

Lines changed: 94 additions & 29 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 94 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1508,39 +1508,14 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
15081508
use_cross_timestep=use_cross_timestep,
15091509
)
15101510

1511+
do_cfg = guidance_scale > 1.0
15111512
do_stg = stg_scale > 0.0
1512-
1513-
if guidance_scale > 1.0 and do_stg:
1514-
noise_pred_uncond, noise_pred_text, noise_pred_perturb = jnp.split(noise_pred, 3, axis=0)
1515-
noise_pred = (
1516-
noise_pred_uncond
1517-
+ guidance_scale * (noise_pred_text - noise_pred_uncond)
1518-
+ stg_scale * (noise_pred_text - noise_pred_perturb)
1519-
)
1520-
# Audio guidance
1521-
noise_pred_audio_uncond, noise_pred_audio_text, noise_pred_audio_perturb = jnp.split(noise_pred_audio, 3, axis=0)
1522-
noise_pred_audio = (
1523-
noise_pred_audio_uncond
1524-
+ guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
1525-
+ stg_scale * (noise_pred_audio_text - noise_pred_audio_perturb)
1526-
)
1527-
elif guidance_scale > 1.0:
1528-
noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
1529-
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1530-
# Audio guidance
1531-
noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0)
1532-
noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
1533-
elif do_stg:
1534-
noise_pred_text, noise_pred_perturb = jnp.split(noise_pred, 2, axis=0)
1535-
noise_pred = noise_pred_text + stg_scale * (noise_pred_text - noise_pred_perturb)
1536-
1537-
noise_pred_audio_text, noise_pred_audio_perturb = jnp.split(noise_pred_audio, 2, axis=0)
1538-
noise_pred_audio = noise_pred_audio_text + stg_scale * (noise_pred_audio_text - noise_pred_audio_perturb)
1513+
sigma_t = sigmas[i]
15391514

15401515
# Extract latents_step based on stacking strategy
15411516
if do_cfg and do_stg:
1542-
latents_step = latents_jax[batch_size:2*batch_size]
1543-
audio_latents_step = audio_latents_jax[batch_size:2*batch_size]
1517+
latents_step = latents_jax[batch_size : 2 * batch_size]
1518+
audio_latents_step = audio_latents_jax[batch_size : 2 * batch_size]
15441519
elif do_cfg:
15451520
latents_step = latents_jax[batch_size:]
15461521
audio_latents_step = audio_latents_jax[batch_size:]
@@ -1551,6 +1526,96 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
15511526
latents_step = latents_jax
15521527
audio_latents_step = audio_latents_jax
15531528

1529+
# Helper to convert velocity to x0
1530+
# Nested helper: turn a velocity prediction `vel`, made at latents `lat`, into an
# x0 estimate via x0 = lat - vel * sigma_t.
# `sigma_t` is not a parameter; it is read from the enclosing scope, where it is
# assigned from `sigmas[i]` before the guidance branches run.
def convert_to_x0(lat, vel):
1531+
return lat - vel * sigma_t
1532+
1533+
# Helper to convert x0 back to velocity
1534+
# Nested helper: algebraic inverse of convert_to_x0 — recover the velocity from
# latents `lat` and an x0 estimate via vel = (lat - x0) / sigma_t.
# `sigma_t` is read from the enclosing scope (assigned from `sigmas[i]`).
# NOTE(review): this divides by sigma_t; confirm sigmas[i] is never 0 for any
# step on which guidance runs, otherwise the result is inf/NaN.
def convert_to_vel(lat, x0):
1535+
return (lat - x0) / sigma_t
1536+
1537+
if do_cfg and do_stg:
1538+
noise_pred_uncond, noise_pred_text, noise_pred_perturb = jnp.split(noise_pred, 3, axis=0)
1539+
1540+
# Convert to x0
1541+
x0_uncond = convert_to_x0(latents_step, noise_pred_uncond)
1542+
x0_text = convert_to_x0(latents_step, noise_pred_text)
1543+
x0_perturb = convert_to_x0(latents_step, noise_pred_perturb)
1544+
1545+
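# Delta formulation: start from the text-conditioned x0 and add the CFG delta (text - uncond) and the STG delta (text - perturbed)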
1546+
cfg_delta = (guidance_scale - 1) * (x0_text - x0_uncond)
1547+
stg_delta = stg_scale * (x0_text - x0_perturb)
1548+
1549+
x0_combined = x0_text + cfg_delta + stg_delta
1550+
1551+
# Apply guidance rescale if needed
1552+
if guidance_rescale > 0:
1553+
x0_combined = rescale_noise_cfg(x0_combined, x0_text, guidance_rescale=guidance_rescale)
1554+
1555+
# Convert back to velocity
1556+
noise_pred = convert_to_vel(latents_step, x0_combined)
1557+
1558+
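# Audio guidance: same x0-space delta formulation as video, using the audio-specific scales when provided (no guidance rescale is applied on this path)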
1559+
noise_pred_audio_uncond, noise_pred_audio_text, noise_pred_audio_perturb = jnp.split(noise_pred_audio, 3, axis=0)
1560+
1561+
x0_audio_uncond = convert_to_x0(audio_latents_step, noise_pred_audio_uncond)
1562+
x0_audio_text = convert_to_x0(audio_latents_step, noise_pred_audio_text)
1563+
x0_audio_perturb = convert_to_x0(audio_latents_step, noise_pred_audio_perturb)
1564+
1565+
cfg_audio_delta = (audio_guidance_scale - 1 if audio_guidance_scale is not None else guidance_scale - 1) * (x0_audio_text - x0_audio_uncond)
1566+
stg_audio_delta = (audio_stg_scale if audio_stg_scale is not None else stg_scale) * (x0_audio_text - x0_audio_perturb)
1567+
1568+
x0_audio_combined = x0_audio_text + cfg_audio_delta + stg_audio_delta
1569+
1570+
noise_pred_audio = convert_to_vel(audio_latents_step, x0_audio_combined)
1571+
1572+
elif do_cfg:
1573+
noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
1574+
1575+
x0_uncond = convert_to_x0(latents_step, noise_pred_uncond)
1576+
x0_text = convert_to_x0(latents_step, noise_pred_text)
1577+
1578+
cfg_delta = (guidance_scale - 1) * (x0_text - x0_uncond)
1579+
x0_combined = x0_text + cfg_delta
1580+
1581+
if guidance_rescale > 0:
1582+
x0_combined = rescale_noise_cfg(x0_combined, x0_text, guidance_rescale=guidance_rescale)
1583+
1584+
noise_pred = convert_to_vel(latents_step, x0_combined)
1585+
1586+
# Audio guidance
1587+
noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0)
1588+
1589+
x0_audio_uncond = convert_to_x0(audio_latents_step, noise_pred_audio_uncond)
1590+
x0_audio_text = convert_to_x0(audio_latents_step, noise_pred_audio_text)
1591+
1592+
cfg_audio_delta = (audio_guidance_scale - 1 if audio_guidance_scale is not None else guidance_scale - 1) * (x0_audio_text - x0_audio_uncond)
1593+
x0_audio_combined = x0_audio_text + cfg_audio_delta
1594+
1595+
noise_pred_audio = convert_to_vel(audio_latents_step, x0_audio_combined)
1596+
1597+
elif do_stg:
1598+
noise_pred_text, noise_pred_perturb = jnp.split(noise_pred, 2, axis=0)
1599+
1600+
x0_text = convert_to_x0(latents_step, noise_pred_text)
1601+
x0_perturb = convert_to_x0(latents_step, noise_pred_perturb)
1602+
1603+
stg_delta = stg_scale * (x0_text - x0_perturb)
1604+
x0_combined = x0_text + stg_delta
1605+
1606+
noise_pred = convert_to_vel(latents_step, x0_combined)
1607+
1608+
# Audio guidance
1609+
noise_pred_audio_text, noise_pred_audio_perturb = jnp.split(noise_pred_audio, 2, axis=0)
1610+
1611+
x0_audio_text = convert_to_x0(audio_latents_step, noise_pred_audio_text)
1612+
x0_audio_perturb = convert_to_x0(audio_latents_step, noise_pred_audio_perturb)
1613+
1614+
stg_audio_delta = (audio_stg_scale if audio_stg_scale is not None else stg_scale) * (x0_audio_text - x0_audio_perturb)
1615+
x0_audio_combined = x0_audio_text + stg_audio_delta
1616+
1617+
noise_pred_audio = convert_to_vel(audio_latents_step, x0_audio_combined)
1618+
15541619
# Step
15551620
latents_step, _ = self.scheduler.step(scheduler_state, noise_pred, t, latents_step, return_dict=False)
15561621
audio_latents_step, _ = self.scheduler.step(

0 commit comments

Comments
 (0)