Skip to content

Commit 7b5d50a

Browse files
committed
revert padding across batch dim
1 parent af53543 commit 7b5d50a

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,16 +1381,16 @@ def __call__(
13811381
negative_prompt_attention_mask_jax = negative_prompt_attention_mask
13821382

13831383
if isinstance(prompt_embeds_jax, list):
1384-
prompt_embeds_jax = [jnp.concatenate([n, p, p, p], axis=0) for n, p in zip(negative_prompt_embeds_jax, prompt_embeds_jax)]
1384+
prompt_embeds_jax = [jnp.concatenate([n, p, p], axis=0) for n, p in zip(negative_prompt_embeds_jax, prompt_embeds_jax)]
13851385
else:
1386-
prompt_embeds_jax = jnp.concatenate([negative_prompt_embeds_jax, prompt_embeds_jax, prompt_embeds_jax, prompt_embeds_jax], axis=0)
1386+
prompt_embeds_jax = jnp.concatenate([negative_prompt_embeds_jax, prompt_embeds_jax, prompt_embeds_jax], axis=0)
13871387

1388-
prompt_attention_mask_jax = jnp.concatenate([negative_prompt_attention_mask_jax, prompt_attention_mask_jax, prompt_attention_mask_jax, prompt_attention_mask_jax], axis=0)
1389-
latents_jax = jnp.concatenate([latents_jax] * 4, axis=0)
1390-
audio_latents_jax = jnp.concatenate([audio_latents_jax] * 4, axis=0)
1388+
prompt_attention_mask_jax = jnp.concatenate([negative_prompt_attention_mask_jax, prompt_attention_mask_jax, prompt_attention_mask_jax], axis=0)
1389+
latents_jax = jnp.concatenate([latents_jax] * 3, axis=0)
1390+
audio_latents_jax = jnp.concatenate([audio_latents_jax] * 3, axis=0)
13911391

13921392
N = latents.shape[0]
1393-
perturbation_mask = jnp.concatenate([jnp.ones((2 * N, 1, 1), dtype=dtype), jnp.zeros((N, 1, 1), dtype=dtype), jnp.ones((N, 1, 1), dtype=dtype)], axis=0)
1393+
perturbation_mask = jnp.concatenate([jnp.ones((2 * N, 1, 1), dtype=dtype), jnp.zeros((N, 1, 1), dtype=dtype)], axis=0)
13941394

13951395
elif do_cfg:
13961396
negative_prompt_embeds_jax = negative_prompt_embeds
@@ -1527,7 +1527,7 @@ def convert_to_vel(lat, x0):
15271527
return (lat - x0) / sigma_t
15281528

15291529
if do_cfg and do_stg:
1530-
noise_pred_uncond, noise_pred_text, noise_pred_perturb, _ = jnp.split(noise_pred, 4, axis=0)
1530+
noise_pred_uncond, noise_pred_text, noise_pred_perturb = jnp.split(noise_pred, 3, axis=0)
15311531

15321532
# Convert to x0
15331533
x0_uncond = convert_to_x0(latents_step, noise_pred_uncond)
@@ -1548,7 +1548,7 @@ def convert_to_vel(lat, x0):
15481548
noise_pred = convert_to_vel(latents_step, x0_combined)
15491549

15501550
# Audio guidance
1551-
noise_pred_audio_uncond, noise_pred_audio_text, noise_pred_audio_perturb, _ = jnp.split(noise_pred_audio, 4, axis=0)
1551+
noise_pred_audio_uncond, noise_pred_audio_text, noise_pred_audio_perturb = jnp.split(noise_pred_audio, 3, axis=0)
15521552

15531553
x0_audio_uncond = convert_to_x0(audio_latents_step, noise_pred_audio_uncond)
15541554
x0_audio_text = convert_to_x0(audio_latents_step, noise_pred_audio_text)

0 commit comments

Comments
 (0)