Skip to content

Commit a914786

Browse files
committed
Pad the batch dimension from 3 to 4 entries by appending an extra copy of the positive prompt as a dummy entry (its prediction is discarded)
1 parent 6650242 commit a914786

1 file changed

Lines changed: 7 additions & 7 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1381,16 +1381,16 @@ def __call__(
13811381
negative_prompt_attention_mask_jax = negative_prompt_attention_mask
13821382

13831383
if isinstance(prompt_embeds_jax, list):
1384-
prompt_embeds_jax = [jnp.concatenate([n, p, p], axis=0) for n, p in zip(negative_prompt_embeds_jax, prompt_embeds_jax)]
1384+
prompt_embeds_jax = [jnp.concatenate([n, p, p, p], axis=0) for n, p in zip(negative_prompt_embeds_jax, prompt_embeds_jax)]
13851385
else:
1386-
prompt_embeds_jax = jnp.concatenate([negative_prompt_embeds_jax, prompt_embeds_jax, prompt_embeds_jax], axis=0)
1386+
prompt_embeds_jax = jnp.concatenate([negative_prompt_embeds_jax, prompt_embeds_jax, prompt_embeds_jax, prompt_embeds_jax], axis=0)
13871387

1388-
prompt_attention_mask_jax = jnp.concatenate([negative_prompt_attention_mask_jax, prompt_attention_mask_jax, prompt_attention_mask_jax], axis=0)
1389-
latents_jax = jnp.concatenate([latents_jax] * 3, axis=0)
1390-
audio_latents_jax = jnp.concatenate([audio_latents_jax] * 3, axis=0)
1388+
prompt_attention_mask_jax = jnp.concatenate([negative_prompt_attention_mask_jax, prompt_attention_mask_jax, prompt_attention_mask_jax, prompt_attention_mask_jax], axis=0)
1389+
latents_jax = jnp.concatenate([latents_jax] * 4, axis=0)
1390+
audio_latents_jax = jnp.concatenate([audio_latents_jax] * 4, axis=0)
13911391

13921392
N = latents.shape[0]
1393-
perturbation_mask = jnp.concatenate([jnp.ones((2 * N, 1, 1), dtype=dtype), jnp.zeros((N, 1, 1), dtype=dtype)], axis=0)
1393+
perturbation_mask = jnp.concatenate([jnp.ones((2 * N, 1, 1), dtype=dtype), jnp.zeros((N, 1, 1), dtype=dtype), jnp.ones((N, 1, 1), dtype=dtype)], axis=0)
13941394

13951395
elif do_cfg:
13961396
negative_prompt_embeds_jax = negative_prompt_embeds
@@ -1527,7 +1527,7 @@ def convert_to_vel(lat, x0):
15271527
return (lat - x0) / sigma_t
15281528

15291529
if do_cfg and do_stg:
1530-
noise_pred_uncond, noise_pred_text, noise_pred_perturb = jnp.split(noise_pred, 3, axis=0)
1530+
noise_pred_uncond, noise_pred_text, noise_pred_perturb, _ = jnp.split(noise_pred, 4, axis=0)
15311531

15321532
# Convert to x0
15331533
x0_uncond = convert_to_x0(latents_step, noise_pred_uncond)

0 commit comments

Comments
 (0)