Skip to content

Commit 6ba8606

Browse files
committed
modality mask
1 parent be522c6 commit 6ba8606

2 files changed

Lines changed: 32 additions & 14 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ def __call__(
364364
temb_ca_audio_gate: jax.Array,
365365
temb_prompt: Optional[jax.Array] = None,
366366
temb_prompt_audio: Optional[jax.Array] = None,
367+
modality_mask: Optional[jax.Array] = None,
367368
# RoPE
368369
video_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None,
369370
audio_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None,
@@ -553,6 +554,8 @@ def __call__(
553554
k_rotary_emb=ca_audio_rotary_emb,
554555
attention_mask=a2v_cross_attention_mask,
555556
)
557+
if modality_mask is not None:
558+
a2v_attn_hidden_states = a2v_attn_hidden_states * modality_mask
556559
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
557560

558561
# Video-to-Audio Cross Attention: Q: Audio; K,V: Video
@@ -567,6 +570,8 @@ def __call__(
567570
k_rotary_emb=ca_video_rotary_emb,
568571
attention_mask=v2a_cross_attention_mask,
569572
)
573+
if modality_mask is not None:
574+
v2a_attn_hidden_states = v2a_attn_hidden_states * modality_mask
570575
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
571576

572577
# 4. Feedforward
@@ -1028,6 +1033,7 @@ def __call__(
10281033
audio_coords: Optional[jax.Array] = None,
10291034
attention_kwargs: Optional[Dict[str, Any]] = None,
10301035
use_cross_timestep: bool = False,
1036+
modality_mask: Optional[jax.Array] = None,
10311037
return_dict: bool = True,
10321038
perturbation_mask: Optional[jax.Array] = None,
10331039
) -> Any:
@@ -1171,10 +1177,14 @@ def scan_fn(carry, block_and_mask):
11711177
temb_prompt_audio=temb_prompt_audio,
11721178
video_rotary_emb=video_rotary_emb,
11731179
audio_rotary_emb=audio_rotary_emb,
1174-
ca_video_rotary_emb=video_cross_attn_rotary_emb,
1175-
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
1176-
encoder_attention_mask=encoder_attention_mask,
1177-
audio_encoder_attention_mask=audio_encoder_attention_mask,
1180+
ca_video_rotary_emb=ca_video_rotary_emb,
1181+
ca_audio_rotary_emb=ca_audio_rotary_emb,
1182+
a2v_cross_attention_mask=a2v_cross_attention_mask,
1183+
v2a_cross_attention_mask=v2a_cross_attention_mask,
1184+
attention_mask=mask,
1185+
attention_kwargs=attention_kwargs,
1186+
use_cross_timestep=use_cross_timestep,
1187+
modality_mask=modality_mask,
11781188
)
11791189
return (
11801190
hidden_states_out.astype(hidden_states.dtype),

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,16 +1382,16 @@ def __call__(
13821382
negative_prompt_attention_mask_jax = negative_prompt_attention_mask
13831383

13841384
if isinstance(prompt_embeds_jax, list):
1385-
prompt_embeds_jax = [jnp.concatenate([n, p, p], axis=0) for n, p in zip(negative_prompt_embeds_jax, prompt_embeds_jax)]
1385+
prompt_embeds_jax = [jnp.concatenate([n, p, p, p], axis=0) for n, p in zip(negative_prompt_embeds_jax, prompt_embeds_jax)]
13861386
else:
1387-
prompt_embeds_jax = jnp.concatenate([negative_prompt_embeds_jax, prompt_embeds_jax, prompt_embeds_jax], axis=0)
1387+
prompt_embeds_jax = jnp.concatenate([negative_prompt_embeds_jax, prompt_embeds_jax, prompt_embeds_jax, prompt_embeds_jax], axis=0)
13881388

1389-
prompt_attention_mask_jax = jnp.concatenate([negative_prompt_attention_mask_jax, prompt_attention_mask_jax, prompt_attention_mask_jax], axis=0)
1390-
latents_jax = jnp.concatenate([latents_jax] * 3, axis=0)
1391-
audio_latents_jax = jnp.concatenate([audio_latents_jax] * 3, axis=0)
1389+
prompt_attention_mask_jax = jnp.concatenate([negative_prompt_attention_mask_jax, prompt_attention_mask_jax, prompt_attention_mask_jax, prompt_attention_mask_jax], axis=0)
1390+
latents_jax = jnp.concatenate([latents_jax] * 4, axis=0)
1391+
audio_latents_jax = jnp.concatenate([audio_latents_jax] * 4, axis=0)
13921392

13931393
N = latents.shape[0]
1394-
perturbation_mask = jnp.concatenate([jnp.ones((2 * N, 1, 1), dtype=dtype), jnp.zeros((N, 1, 1), dtype=dtype)], axis=0)
1394+
perturbation_mask = jnp.concatenate([jnp.ones((2 * N, 1, 1), dtype=dtype), jnp.zeros((N, 1, 1), dtype=dtype), jnp.ones((N, 1, 1), dtype=dtype)], axis=0)
13951395

13961396
elif do_cfg:
13971397
negative_prompt_embeds_jax = negative_prompt_embeds
@@ -1528,18 +1528,20 @@ def convert_to_vel(lat, x0):
15281528
return (lat - x0) / sigma_t
15291529

15301530
if do_cfg and do_stg:
1531-
noise_pred_uncond, noise_pred_text, noise_pred_perturb = jnp.split(noise_pred, 3, axis=0)
1531+
noise_pred_uncond, noise_pred_text, noise_pred_perturb, noise_pred_isolated = jnp.split(noise_pred, 4, axis=0)
15321532

15331533
# Convert to x0
15341534
x0_uncond = convert_to_x0(latents_step, noise_pred_uncond)
15351535
x0_text = convert_to_x0(latents_step, noise_pred_text)
15361536
x0_perturb = convert_to_x0(latents_step, noise_pred_perturb)
1537+
x0_isolated = convert_to_x0(latents_step, noise_pred_isolated)
15371538

15381539
# Delta formulation
15391540
cfg_delta = (guidance_scale - 1) * (x0_text - x0_uncond)
15401541
stg_delta = stg_scale * (x0_text - x0_perturb)
1542+
video_modality_delta = (modality_scale - 1) * (x0_text - x0_isolated)
15411543

1542-
x0_combined = x0_text + cfg_delta + stg_delta
1544+
x0_combined = x0_text + cfg_delta + stg_delta + video_modality_delta
15431545

15441546
# Apply guidance rescale if needed
15451547
if guidance_rescale > 0:
@@ -1549,16 +1551,18 @@ def convert_to_vel(lat, x0):
15491551
noise_pred = convert_to_vel(latents_step, x0_combined)
15501552

15511553
# Audio guidance
1552-
noise_pred_audio_uncond, noise_pred_audio_text, noise_pred_audio_perturb = jnp.split(noise_pred_audio, 3, axis=0)
1554+
noise_pred_audio_uncond, noise_pred_audio_text, noise_pred_audio_perturb, noise_pred_audio_isolated = jnp.split(noise_pred_audio, 4, axis=0)
15531555

15541556
x0_audio_uncond = convert_to_x0(audio_latents_step, noise_pred_audio_uncond)
15551557
x0_audio_text = convert_to_x0(audio_latents_step, noise_pred_audio_text)
15561558
x0_audio_perturb = convert_to_x0(audio_latents_step, noise_pred_audio_perturb)
1559+
x0_audio_isolated = convert_to_x0(audio_latents_step, noise_pred_audio_isolated)
15571560

15581561
cfg_audio_delta = (audio_guidance_scale - 1 if audio_guidance_scale is not None else guidance_scale - 1) * (x0_audio_text - x0_audio_uncond)
15591562
stg_audio_delta = (audio_stg_scale if audio_stg_scale is not None else stg_scale) * (x0_audio_text - x0_audio_perturb)
1563+
audio_modality_delta = (audio_modality_scale - 1 if audio_modality_scale is not None else modality_scale - 1) * (x0_audio_text - x0_audio_isolated)
15601564

1561-
x0_audio_combined = x0_audio_text + cfg_audio_delta + stg_audio_delta
1565+
x0_audio_combined = x0_audio_text + cfg_audio_delta + stg_audio_delta + audio_modality_delta
15621566

15631567
noise_pred_audio = convert_to_vel(audio_latents_step, x0_audio_combined)
15641568

@@ -1789,13 +1793,17 @@ def transformer_forward_pass(
17891793
else:
17901794
audio_sigma = jnp.expand_dims(audio_sigma, 0).repeat(latents.shape[0])
17911795

1796+
N = latents.shape[0] // 4
1797+
modality_mask = jnp.concatenate([jnp.ones((3 * N, 1, 1, 1), dtype=latents.dtype), jnp.zeros((N, 1, 1, 1), dtype=latents.dtype)], axis=0)
1798+
17921799
noise_pred, noise_pred_audio = transformer(
17931800
hidden_states=latents,
17941801
encoder_hidden_states=encoder_hidden_states,
17951802
timestep=timestep,
17961803
sigma=sigma,
17971804
audio_sigma=audio_sigma,
17981805
encoder_attention_mask=encoder_attention_mask,
1806+
modality_mask=modality_mask,
17991807
num_frames=latent_num_frames,
18001808
height=latent_height,
18011809
width=latent_width,

0 commit comments

Comments (0)