Skip to content

Commit 0d0babb

Browse files
committed
ltx2.3
1 parent 98982c7 commit 0d0babb

3 files changed

Lines changed: 772 additions & 99 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,8 @@ def __call__(
365365
audio_encoder_attention_mask: Optional[jax.Array] = None,
366366
a2v_cross_attention_mask: Optional[jax.Array] = None,
367367
v2a_cross_attention_mask: Optional[jax.Array] = None,
368+
use_a2v_cross_attention: bool = True,
369+
use_v2a_cross_attention: bool = True,
368370
perturbation_mask: Optional[jax.Array] = None,
369371
) -> Tuple[jax.Array, jax.Array]:
370372
batch_size = hidden_states.shape[0]
@@ -497,33 +499,35 @@ def __call__(
497499
mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale) + video_a2v_ca_shift
498500
mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_a2v_ca_scale) + audio_a2v_ca_shift
499501

500-
with jax.named_scope("Audio-to-Video Cross-Attention"):
501-
a2v_attn_hidden_states = self.audio_to_video_attn(
502-
mod_norm_hidden_states,
503-
encoder_hidden_states=mod_norm_audio_hidden_states,
504-
rotary_emb=ca_video_rotary_emb,
505-
k_rotary_emb=ca_audio_rotary_emb,
506-
attention_mask=a2v_cross_attention_mask,
507-
)
508-
509-
if perturbation_mask is not None:
510-
a2v_attn_hidden_states = mod_norm_audio_hidden_states + perturbation_mask * (a2v_attn_hidden_states - mod_norm_audio_hidden_states)
502+
if use_a2v_cross_attention:
503+
with jax.named_scope("Audio-to-Video Cross-Attention"):
504+
a2v_attn_hidden_states = self.audio_to_video_attn(
505+
mod_norm_hidden_states,
506+
encoder_hidden_states=mod_norm_audio_hidden_states,
507+
rotary_emb=ca_video_rotary_emb,
508+
k_rotary_emb=ca_audio_rotary_emb,
509+
attention_mask=a2v_cross_attention_mask,
510+
)
511511

512-
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
512+
if perturbation_mask is not None:
513+
a2v_attn_hidden_states = mod_norm_audio_hidden_states + perturbation_mask * (a2v_attn_hidden_states - mod_norm_audio_hidden_states)
514+
515+
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
513516

514517
# Video-to-Audio Cross Attention: Q: Audio; K,V: Video
515518
mod_norm_hidden_states_v2a = norm_hidden_states * (1 + video_v2a_ca_scale) + video_v2a_ca_shift
516519
mod_norm_audio_hidden_states_v2a = norm_audio_hidden_states * (1 + audio_v2a_ca_scale) + audio_v2a_ca_shift
517520

518-
with jax.named_scope("Video-to-Audio Cross-Attention"):
519-
v2a_attn_hidden_states = self.video_to_audio_attn(
520-
mod_norm_audio_hidden_states_v2a,
521-
encoder_hidden_states=mod_norm_hidden_states_v2a,
522-
rotary_emb=ca_audio_rotary_emb,
523-
k_rotary_emb=ca_video_rotary_emb,
524-
attention_mask=v2a_cross_attention_mask,
525-
)
526-
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
521+
if use_v2a_cross_attention:
522+
with jax.named_scope("Video-to-Audio Cross-Attention"):
523+
v2a_attn_hidden_states = self.video_to_audio_attn(
524+
mod_norm_audio_hidden_states_v2a,
525+
encoder_hidden_states=mod_norm_hidden_states_v2a,
526+
rotary_emb=ca_audio_rotary_emb,
527+
k_rotary_emb=ca_video_rotary_emb,
528+
attention_mask=v2a_cross_attention_mask,
529+
)
530+
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
527531

528532
if modality_mask is not None:
529533
hidden_states = hidden_states * modality_mask
@@ -966,6 +970,7 @@ def __call__(
966970
attention_kwargs: Optional[Dict[str, Any]] = None,
967971
use_cross_timestep: bool = False,
968972
modality_mask: Optional[jax.Array] = None,
973+
isolate_modalities: bool = False,
969974
return_dict: bool = True,
970975
perturbation_mask: Optional[jax.Array] = None,
971976
) -> Any:
@@ -1111,6 +1116,8 @@ def scan_fn(carry, block_and_mask):
11111116
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
11121117
a2v_cross_attention_mask=encoder_attention_mask,
11131118
v2a_cross_attention_mask=audio_encoder_attention_mask,
1119+
use_a2v_cross_attention=not isolate_modalities,
1120+
use_v2a_cross_attention=not isolate_modalities,
11141121
perturbation_mask=mask,
11151122
modality_mask=modality_mask,
11161123
)
@@ -1157,6 +1164,8 @@ def scan_fn(carry, block_and_mask):
11571164
audio_encoder_attention_mask=audio_encoder_attention_mask,
11581165
a2v_cross_attention_mask=encoder_attention_mask,
11591166
v2a_cross_attention_mask=audio_encoder_attention_mask,
1167+
use_a2v_cross_attention=not isolate_modalities,
1168+
use_v2a_cross_attention=not isolate_modalities,
11601169
perturbation_mask=mask,
11611170
)
11621171

0 commit comments

Comments (0)