Skip to content

Commit 98982c7

Browse files
committed
transformer
1 parent 85df9d2 commit 98982c7

1 file changed

Lines changed: 135 additions & 31 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 135 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def __init__(
9999
norm_elementwise_affine: bool = False,
100100
norm_eps: float = 1e-6,
101101
rope_type: str = "interleaved",
102+
video_gated_attn: bool = False,
103+
audio_gated_attn: bool = False,
102104
dtype: jnp.dtype = jnp.float32,
103105
weights_dtype: jnp.dtype = jnp.float32,
104106
mesh: jax.sharding.Mesh = None,
@@ -142,6 +144,7 @@ def __init__(
142144
rope_type=rope_type,
143145
flash_block_sizes=flash_block_sizes,
144146
flash_min_seq_length=flash_min_seq_length,
147+
gated_attn=video_gated_attn,
145148
)
146149

147150
self.audio_norm1 = nnx.RMSNorm(
@@ -168,6 +171,7 @@ def __init__(
168171
rope_type=rope_type,
169172
flash_block_sizes=flash_block_sizes,
170173
flash_min_seq_length=flash_min_seq_length,
174+
gated_attn=audio_gated_attn,
171175
)
172176

173177
# 2. Prompt Cross-Attention
@@ -195,6 +199,7 @@ def __init__(
195199
attention_kernel=self.attention_kernel,
196200
rope_type=rope_type,
197201
flash_block_sizes=flash_block_sizes,
202+
gated_attn=video_gated_attn,
198203
)
199204

200205
self.audio_norm2 = nnx.RMSNorm(
@@ -222,6 +227,7 @@ def __init__(
222227
rope_type=rope_type,
223228
flash_block_sizes=flash_block_sizes,
224229
flash_min_seq_length=flash_min_seq_length,
230+
gated_attn=audio_gated_attn,
225231
)
226232

227233
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -250,6 +256,7 @@ def __init__(
250256
rope_type=rope_type,
251257
flash_block_sizes=flash_block_sizes,
252258
flash_min_seq_length=0,
259+
gated_attn=video_gated_attn,
253260
)
254261

255262
self.video_to_audio_norm = nnx.RMSNorm(
@@ -277,6 +284,7 @@ def __init__(
277284
rope_type=rope_type,
278285
flash_block_sizes=flash_block_sizes,
279286
flash_min_seq_length=flash_min_seq_length,
287+
gated_attn=audio_gated_attn,
280288
)
281289

282290
# 4. Feed Forward
@@ -344,6 +352,9 @@ def __call__(
344352
temb_ca_audio_scale_shift: jax.Array,
345353
temb_ca_gate: jax.Array,
346354
temb_ca_audio_gate: jax.Array,
355+
temb_prompt: Optional[jax.Array] = None,
356+
temb_prompt_audio: Optional[jax.Array] = None,
357+
modality_mask: Optional[jax.Array] = None,
347358
# RoPE
348359
video_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None,
349360
audio_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None,
@@ -354,6 +365,7 @@ def __call__(
354365
audio_encoder_attention_mask: Optional[jax.Array] = None,
355366
a2v_cross_attention_mask: Optional[jax.Array] = None,
356367
v2a_cross_attention_mask: Optional[jax.Array] = None,
368+
perturbation_mask: Optional[jax.Array] = None,
357369
) -> Tuple[jax.Array, jax.Array]:
358370
batch_size = hidden_states.shape[0]
359371

@@ -493,6 +505,10 @@ def __call__(
493505
k_rotary_emb=ca_audio_rotary_emb,
494506
attention_mask=a2v_cross_attention_mask,
495507
)
508+
509+
if perturbation_mask is not None:
510+
a2v_attn_hidden_states = mod_norm_audio_hidden_states + perturbation_mask * (a2v_attn_hidden_states - mod_norm_audio_hidden_states)
511+
496512
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
497513

498514
# Video-to-Audio Cross Attention: Q: Audio; K,V: Video
@@ -509,6 +525,10 @@ def __call__(
509525
)
510526
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
511527

528+
if modality_mask is not None:
529+
hidden_states = hidden_states * modality_mask
530+
audio_hidden_states = audio_hidden_states * modality_mask
531+
512532
# 4. Feedforward
513533
norm_hidden_states = self.norm3(hidden_states)
514534
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
@@ -541,6 +561,8 @@ def __init__(
541561
pos_embed_max_pos: int = 20,
542562
base_height: int = 2048,
543563
base_width: int = 2048,
564+
gated_attn: bool = False,
565+
cross_attn_mod: bool = False,
544566
audio_in_channels: int = 128, # Audio Arguments
545567
audio_out_channels: Optional[int] = 128,
546568
audio_patch_size: int = 1,
@@ -552,11 +574,14 @@ def __init__(
552574
audio_pos_embed_max_pos: int = 20,
553575
audio_sampling_rate: int = 16000,
554576
audio_hop_length: int = 160,
577+
audio_gated_attn: bool = False,
578+
audio_cross_attn_mod: bool = False,
555579
num_layers: int = 48, # Shared arguments
556580
activation_fn: str = "gelu",
557581
norm_elementwise_affine: bool = False,
558582
norm_eps: float = 1e-6,
559583
caption_channels: int = 3840,
584+
audio_caption_channels: Optional[int] = None,
560585
attention_bias: bool = True,
561586
attention_out_bias: bool = True,
562587
rope_theta: float = 10000.0,
@@ -565,6 +590,8 @@ def __init__(
565590
timestep_scale_multiplier: int = 1000,
566591
cross_attn_timestep_scale_multiplier: int = 1000,
567592
rope_type: str = "interleaved",
593+
use_prompt_embeddings: bool = True,
594+
perturbed_attn: bool = False,
568595
dtype: jnp.dtype = jnp.float32,
569596
weights_dtype: jnp.dtype = jnp.float32,
570597
mesh: jax.sharding.Mesh = None,
@@ -655,33 +682,59 @@ def __init__(
655682
)
656683

657684
# 2. Prompt embeddings
658-
self.caption_projection = NNXPixArtAlphaTextProjection(
659-
rngs=rngs,
660-
in_features=self.caption_channels,
661-
hidden_size=inner_dim,
662-
dtype=self.dtype,
663-
weights_dtype=self.weights_dtype,
664-
)
665-
self.audio_caption_projection = NNXPixArtAlphaTextProjection(
666-
rngs=rngs,
667-
in_features=self.caption_channels,
668-
hidden_size=audio_inner_dim,
669-
dtype=self.dtype,
670-
weights_dtype=self.weights_dtype,
671-
)
685+
if self.use_prompt_embeddings:
686+
self.caption_projection = NNXPixArtAlphaTextProjection(
687+
rngs=rngs,
688+
in_features=self.caption_channels,
689+
hidden_size=inner_dim,
690+
dtype=self.dtype,
691+
weights_dtype=self.weights_dtype,
692+
)
693+
self.audio_caption_projection = NNXPixArtAlphaTextProjection(
694+
rngs=rngs,
695+
in_features=self.caption_channels,
696+
hidden_size=audio_inner_dim,
697+
dtype=self.dtype,
698+
weights_dtype=self.weights_dtype,
699+
)
700+
else:
701+
self.caption_projection = None
702+
self.audio_caption_projection = None
703+
704+
if self.cross_attn_mod:
705+
self.prompt_adaln = LTX2AdaLayerNormSingle(
706+
rngs=rngs,
707+
embedding_dim=inner_dim,
708+
num_mod_params=2,
709+
use_additional_conditions=False,
710+
dtype=self.dtype,
711+
weights_dtype=self.weights_dtype,
712+
)
713+
self.audio_prompt_adaln = LTX2AdaLayerNormSingle(
714+
rngs=rngs,
715+
embedding_dim=audio_inner_dim,
716+
num_mod_params=2,
717+
use_additional_conditions=False,
718+
dtype=self.dtype,
719+
weights_dtype=self.weights_dtype,
720+
)
721+
672722
# 3. Timestep Modulation Params and Embedding
723+
video_time_emb_mod_params = 9 if cross_attn_mod else 6
724+
audio_time_emb_mod_params = 9 if audio_cross_attn_mod else 6
725+
673726
self.time_embed = LTX2AdaLayerNormSingle(
674727
rngs=rngs,
675728
embedding_dim=inner_dim,
676-
num_mod_params=6,
729+
num_mod_params=video_time_emb_mod_params,
677730
use_additional_conditions=False,
678731
dtype=self.dtype,
679732
weights_dtype=self.weights_dtype,
680733
)
681734
self.audio_time_embed = LTX2AdaLayerNormSingle(
682735
rngs=rngs,
683736
embedding_dim=audio_inner_dim,
684-
num_mod_params=6,
737+
num_mod_params=audio_time_emb_mod_params,
685738
use_additional_conditions=False,
686739
dtype=self.dtype,
687740
weights_dtype=self.weights_dtype,
@@ -722,11 +775,11 @@ def __init__(
722775
# 3. Output Layer Scale/Shift Modulation parameters
723776
param_rng = rngs.params()
724777
self.scale_shift_table = nnx.Param(
725-
jax.random.normal(param_rng, (2, inner_dim), dtype=self.weights_dtype) / jnp.sqrt(inner_dim),
778+
jax.random.normal(param_rng, (6, inner_dim), dtype=self.weights_dtype) / jnp.sqrt(inner_dim),
726779
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, "embed")),
727780
)
728781
self.audio_scale_shift_table = nnx.Param(
729-
jax.random.normal(param_rng, (2, audio_inner_dim), dtype=self.weights_dtype) / jnp.sqrt(audio_inner_dim),
782+
jax.random.normal(param_rng, (6, audio_inner_dim), dtype=self.weights_dtype) / jnp.sqrt(audio_inner_dim),
730783
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, "embed")),
731784
)
732785

@@ -899,6 +952,8 @@ def __call__(
899952
audio_encoder_hidden_states: jax.Array,
900953
timestep: jax.Array,
901954
audio_timestep: Optional[jax.Array] = None,
955+
sigma: Optional[jax.Array] = None,
956+
audio_sigma: Optional[jax.Array] = None,
902957
encoder_attention_mask: Optional[jax.Array] = None,
903958
audio_encoder_attention_mask: Optional[jax.Array] = None,
904959
num_frames: Optional[int] = None,
@@ -909,7 +964,10 @@ def __call__(
909964
video_coords: Optional[jax.Array] = None,
910965
audio_coords: Optional[jax.Array] = None,
911966
attention_kwargs: Optional[Dict[str, Any]] = None,
967+
use_cross_timestep: bool = False,
968+
modality_mask: Optional[jax.Array] = None,
912969
return_dict: bool = True,
970+
perturbation_mask: Optional[jax.Array] = None,
913971
) -> Any:
914972
# Determine timestep for audio.
915973
audio_timestep = audio_timestep if audio_timestep is not None else timestep
@@ -961,12 +1019,36 @@ def __call__(
9611019
temb_audio = temb_audio.reshape(batch_size, -1, temb_audio.shape[-1])
9621020
audio_embedded_timestep = audio_embedded_timestep.reshape(batch_size, -1, audio_embedded_timestep.shape[-1])
9631021

1022+
if self.cross_attn_mod and sigma is not None:
1023+
audio_sigma = audio_sigma if audio_sigma is not None else sigma
1024+
temb_prompt, _ = self.prompt_adaln(
1025+
sigma.flatten(),
1026+
hidden_dtype=hidden_states.dtype,
1027+
)
1028+
temb_prompt_audio, _ = self.audio_prompt_adaln(
1029+
audio_sigma.flatten(),
1030+
hidden_dtype=audio_hidden_states.dtype,
1031+
)
1032+
temb_prompt = temb_prompt.reshape(batch_size, -1, temb_prompt.shape[-1])
1033+
temb_prompt_audio = temb_prompt_audio.reshape(batch_size, -1, temb_prompt_audio.shape[-1])
1034+
else:
1035+
temb_prompt = None
1036+
temb_prompt_audio = None
1037+
1038+
if use_cross_timestep:
1039+
assert sigma is not None and audio_sigma is not None, "sigma and audio_sigma must be provided when use_cross_timestep is True"
1040+
video_ca_timestep = audio_sigma.flatten()
1041+
audio_ca_timestep = sigma.flatten()
1042+
else:
1043+
video_ca_timestep = timestep.flatten()
1044+
audio_ca_timestep = audio_timestep.flatten() if audio_timestep is not None else timestep.flatten()
1045+
9641046
video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
965-
timestep.flatten(),
1047+
video_ca_timestep,
9661048
hidden_dtype=hidden_states.dtype,
9671049
)
9681050
video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
969-
timestep.flatten() * timestep_cross_attn_gate_scale_factor,
1051+
video_ca_timestep * timestep_cross_attn_gate_scale_factor,
9701052
hidden_dtype=hidden_states.dtype,
9711053
)
9721054
video_cross_attn_scale_shift = video_cross_attn_scale_shift.reshape(
@@ -975,11 +1057,11 @@ def __call__(
9751057
video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.reshape(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])
9761058

9771059
audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
978-
audio_timestep.flatten(),
1060+
audio_ca_timestep,
9791061
hidden_dtype=audio_hidden_states.dtype,
9801062
)
9811063
audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
982-
audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
1064+
audio_ca_timestep * timestep_cross_attn_gate_scale_factor,
9831065
hidden_dtype=audio_hidden_states.dtype,
9841066
)
9851067
audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.reshape(
@@ -988,14 +1070,26 @@ def __call__(
9881070
audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.reshape(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
9891071

9901072
# 4. Prepare prompt embeddings
991-
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
992-
encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1])
1073+
if self.use_prompt_embeddings and self.caption_projection is not None:
1074+
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
1075+
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
9931076

994-
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
995-
audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1])
1077+
encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1])
1078+
audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1])
1079+
1080+
# Construct perturbation_mask_per_layer for STG
1081+
if perturbation_mask is None:
1082+
perturbation_mask_per_layer = jnp.ones((self.num_layers, batch_size, 1, 1), dtype=self.dtype)
1083+
else:
1084+
masks = jnp.ones((self.num_layers, batch_size, 1, 1), dtype=self.dtype)
1085+
for i in self.spatio_temporal_guidance_blocks:
1086+
if i < self.num_layers:
1087+
masks = masks.at[i].set(perturbation_mask)
1088+
perturbation_mask_per_layer = masks
9961089

9971090
# 5. Run transformer blocks
998-
def scan_fn(carry, block):
1091+
def scan_fn(carry, block_and_mask):
1092+
block, mask = block_and_mask
9991093
hidden_states, audio_hidden_states, rngs_carry = carry
10001094
with jax.named_scope("Transformer Layer"):
10011095
hidden_states_out, audio_hidden_states_out = block(
@@ -1009,12 +1103,16 @@ def scan_fn(carry, block):
10091103
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
10101104
temb_ca_gate=video_cross_attn_a2v_gate,
10111105
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
1106+
temb_prompt=temb_prompt,
1107+
temb_prompt_audio=temb_prompt_audio,
10121108
video_rotary_emb=video_rotary_emb,
10131109
audio_rotary_emb=audio_rotary_emb,
10141110
ca_video_rotary_emb=video_cross_attn_rotary_emb,
10151111
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
1016-
encoder_attention_mask=encoder_attention_mask,
1017-
audio_encoder_attention_mask=audio_encoder_attention_mask,
1112+
a2v_cross_attention_mask=encoder_attention_mask,
1113+
v2a_cross_attention_mask=audio_encoder_attention_mask,
1114+
perturbation_mask=mask,
1115+
modality_mask=modality_mask,
10181116
)
10191117
return (
10201118
hidden_states_out.astype(hidden_states.dtype),
@@ -1034,9 +1132,10 @@ def scan_fn(carry, block):
10341132
in_axes=(nnx.Carry, 0),
10351133
out_axes=(nnx.Carry, 0),
10361134
transform_metadata={nnx.PARTITION_NAME: "layers"},
1037-
)(carry, self.transformer_blocks)
1135+
)(carry, (self.transformer_blocks, perturbation_mask_per_layer))
10381136
else:
1039-
for block in self.transformer_blocks:
1137+
for i, block in enumerate(self.transformer_blocks):
1138+
mask = perturbation_mask_per_layer[i] if perturbation_mask_per_layer is not None else None
10401139
hidden_states, audio_hidden_states = block(
10411140
hidden_states=hidden_states,
10421141
audio_hidden_states=audio_hidden_states,
@@ -1048,12 +1147,17 @@ def scan_fn(carry, block):
10481147
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
10491148
temb_ca_gate=video_cross_attn_a2v_gate,
10501149
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
1150+
temb_prompt=temb_prompt,
1151+
temb_prompt_audio=temb_prompt_audio,
10511152
video_rotary_emb=video_rotary_emb,
10521153
audio_rotary_emb=audio_rotary_emb,
10531154
ca_video_rotary_emb=video_cross_attn_rotary_emb,
10541155
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
10551156
encoder_attention_mask=encoder_attention_mask,
10561157
audio_encoder_attention_mask=audio_encoder_attention_mask,
1158+
a2v_cross_attention_mask=encoder_attention_mask,
1159+
v2a_cross_attention_mask=audio_encoder_attention_mask,
1160+
perturbation_mask=mask,
10571161
)
10581162

10591163
# 6. Output layers

0 commit comments

Comments
 (0)