Skip to content

Commit 9f2d778

Browse files
committed
fix
1 parent a0795b1 commit 9f2d778

1 file changed

Lines changed: 21 additions & 2 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1472,8 +1472,16 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
14721472
noise_pred_audio_text, noise_pred_audio_perturb = jnp.split(noise_pred_audio, 2, axis=0)
14731473
noise_pred_audio = noise_pred_audio_text + self.config.stg_scale * (noise_pred_audio_text - noise_pred_audio_perturb)
14741474

1475+
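# Select the slice of the stacked latents that is passed to the scheduler step; the slice depends on the stacking (CFG+STG, CFG only, STG only, or unstacked)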
1476+
if do_cfg and do_stg:
1477+
latents_step = latents_jax[batch_size:2*batch_size]
1478+
audio_latents_step = audio_latents_jax[batch_size:2*batch_size]
1479+
elif do_cfg:
14751480
latents_step = latents_jax[batch_size:]
14761481
audio_latents_step = audio_latents_jax[batch_size:]
1482+
elif do_stg:
1483+
latents_step = latents_jax[:batch_size]
1484+
audio_latents_step = audio_latents_jax[:batch_size]
14771485
else:
14781486
latents_step = latents_jax
14791487
audio_latents_step = audio_latents_jax
@@ -1484,17 +1492,28 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
14841492
scheduler_state, noise_pred_audio, t, audio_latents_step, return_dict=False
14851493
)
14861494

1487-
if guidance_scale > 1.0:
1495+
# Re-stack the stepped latents for the next iteration: 3 copies for CFG+STG, 2 copies for CFG or STG alone, unchanged otherwise
1496+
if do_cfg and do_stg:
1497+
latents_jax = jnp.concatenate([latents_step] * 3, axis=0)
1498+
audio_latents_jax = jnp.concatenate([audio_latents_step] * 3, axis=0)
1499+
elif do_cfg or do_stg:
14881500
latents_jax = jnp.concatenate([latents_step] * 2, axis=0)
14891501
audio_latents_jax = jnp.concatenate([audio_latents_step] * 2, axis=0)
14901502
else:
14911503
latents_jax = latents_step
14921504
audio_latents_jax = audio_latents_step
14931505

14941506
# 8. Decode Latents
1495-
if guidance_scale > 1.0:
1507+
# Keep only the conditional branch of the stacked latents before decoding
1508+
if do_cfg and do_stg:
1509+
latents_jax = latents_jax[batch_size:2*batch_size]
1510+
audio_latents_jax = audio_latents_jax[batch_size:2*batch_size]
1511+
elif do_cfg:
14961512
latents_jax = latents_jax[batch_size:]
14971513
audio_latents_jax = audio_latents_jax[batch_size:]
1514+
elif do_stg:
1515+
latents_jax = latents_jax[:batch_size]
1516+
audio_latents_jax = audio_latents_jax[:batch_size]
14981517

14991518
# Unpack and Denormalize Video
15001519
latents = self._unpack_latents(

0 commit comments

Comments
 (0)