Skip to content

Commit 0aab82a

Browse files
committed
refactor for adaln modulation
1 parent b9b4392 commit 0aab82a

3 files changed

Lines changed: 80 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_3_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ def load_ltx2_3_checkpoint(pretrained_model_name_or_path: str, subfolder: str, d
2525
with safe_open(ckpt_path, framework="pt") as f:
2626
for k in f.keys():
2727
tensors[k] = torch2jax(f.get_tensor(k))
28-
return tensorsdef rename_for_ltx2_3_transformer(key):
28+
return tensors
29+
30+
31+
def rename_for_ltx2_3_transformer(key):
2932
"""
3033
Renames Diffusers LTX-2.3 keys to MaxDiffusion Flax LTX-2.3 keys.
3134
"""

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def __init__(
319319
)
320320

321321
key = rngs.params()
322-
k1, k2, k3, k4 = jax.random.split(key, 4)
322+
k1, k2, k3, k4, k5, k6 = jax.random.split(key, 6)
323323

324324
self.cross_attn_mod = cross_attn_mod
325325
table_size = 9 if cross_attn_mod else 6
@@ -339,6 +339,15 @@ def __init__(
339339
jax.random.normal(k4, (5, audio_dim), dtype=weights_dtype),
340340
kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
341341
)
342+
if self.cross_attn_mod:
343+
self.prompt_scale_shift_table = nnx.Param(
344+
jax.random.normal(k5, (2, self.dim), dtype=weights_dtype) / jnp.sqrt(self.dim),
345+
kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
346+
)
347+
self.audio_prompt_scale_shift_table = nnx.Param(
348+
jax.random.normal(k6, (2, audio_dim), dtype=weights_dtype) / jnp.sqrt(audio_dim),
349+
kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
350+
)
342351

343352
def __call__(
344353
self,
@@ -353,6 +362,8 @@ def __call__(
353362
temb_ca_audio_scale_shift: jax.Array,
354363
temb_ca_gate: jax.Array,
355364
temb_ca_audio_gate: jax.Array,
365+
temb_prompt: Optional[jax.Array] = None,
366+
temb_prompt_audio: Optional[jax.Array] = None,
356367
# RoPE
357368
video_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None,
358369
audio_rotary_emb: Optional[Tuple[jax.Array, jax.Array]] = None,
@@ -445,6 +456,14 @@ def __call__(
445456
if getattr(self, "cross_attn_mod", False):
446457
norm_hidden_states = norm_hidden_states * (1 + scale_q) + shift_q
447458

459+
if getattr(self, "cross_attn_mod", False) and temb_prompt is not None:
460+
prompt_table_reshaped = jnp.expand_dims(self.prompt_scale_shift_table, axis=(0, 1))
461+
temb_prompt_reshaped = temb_prompt.reshape(batch_size, 1, 2, -1)
462+
prompt_ada_values = prompt_table_reshaped + temb_prompt_reshaped
463+
shift_text_kv = prompt_ada_values[:, :, 0, :]
464+
scale_text_kv = prompt_ada_values[:, :, 1, :]
465+
encoder_hidden_states = encoder_hidden_states * (1 + scale_text_kv) + shift_text_kv
466+
448467
attn_hidden_states = self.attn2(
449468
norm_hidden_states,
450469
encoder_hidden_states=encoder_hidden_states,
@@ -461,6 +480,14 @@ def __call__(
461480
if getattr(self, "cross_attn_mod", False):
462481
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_q) + audio_shift_q
463482

483+
if getattr(self, "cross_attn_mod", False) and temb_prompt_audio is not None:
484+
audio_prompt_table_reshaped = jnp.expand_dims(self.audio_prompt_scale_shift_table, axis=(0, 1))
485+
temb_prompt_audio_reshaped = temb_prompt_audio.reshape(batch_size, 1, 2, -1)
486+
audio_prompt_ada_values = audio_prompt_table_reshaped + temb_prompt_audio_reshaped
487+
audio_shift_text_kv = audio_prompt_ada_values[:, :, 0, :]
488+
audio_scale_text_kv = audio_prompt_ada_values[:, :, 1, :]
489+
audio_encoder_hidden_states = audio_encoder_hidden_states * (1 + audio_scale_text_kv) + audio_shift_text_kv
490+
464491
attn_audio_hidden_states = self.audio_attn2(
465492
norm_audio_hidden_states,
466493
encoder_hidden_states=audio_encoder_hidden_states,
@@ -785,6 +812,25 @@ def __init__(
785812
weights_dtype=self.weights_dtype,
786813
)
787814

815+
# 3.4. Prompt Scale/Shift Modulation parameters (LTX-2.3)
816+
if self.cross_attn_mod:
817+
self.prompt_adaln = LTX2AdaLayerNormSingle(
818+
rngs=rngs,
819+
embedding_dim=inner_dim,
820+
num_mod_params=2,
821+
use_additional_conditions=False,
822+
dtype=self.dtype,
823+
weights_dtype=self.weights_dtype,
824+
)
825+
self.audio_prompt_adaln = LTX2AdaLayerNormSingle(
826+
rngs=rngs,
827+
embedding_dim=audio_inner_dim,
828+
num_mod_params=2,
829+
use_additional_conditions=False,
830+
dtype=self.dtype,
831+
weights_dtype=self.weights_dtype,
832+
)
833+
788834
# 3. Output Layer Scale/Shift Modulation parameters
789835
param_rng = rngs.params()
790836
self.scale_shift_table = nnx.Param(
@@ -969,6 +1015,8 @@ def __call__(
9691015
audio_encoder_hidden_states: jax.Array,
9701016
timestep: jax.Array,
9711017
audio_timestep: Optional[jax.Array] = None,
1018+
sigma: Optional[jax.Array] = None,
1019+
audio_sigma: Optional[jax.Array] = None,
9721020
encoder_attention_mask: Optional[jax.Array] = None,
9731021
audio_encoder_attention_mask: Optional[jax.Array] = None,
9741022
num_frames: Optional[int] = None,
@@ -1032,6 +1080,22 @@ def __call__(
10321080
temb_audio = temb_audio.reshape(batch_size, -1, temb_audio.shape[-1])
10331081
audio_embedded_timestep = audio_embedded_timestep.reshape(batch_size, -1, audio_embedded_timestep.shape[-1])
10341082

1083+
if self.cross_attn_mod and sigma is not None:
1084+
audio_sigma = audio_sigma if audio_sigma is not None else sigma
1085+
temb_prompt, _ = self.prompt_adaln(
1086+
sigma.flatten(),
1087+
hidden_dtype=hidden_states.dtype,
1088+
)
1089+
temb_prompt_audio, _ = self.audio_prompt_adaln(
1090+
audio_sigma.flatten(),
1091+
hidden_dtype=audio_hidden_states.dtype,
1092+
)
1093+
temb_prompt = temb_prompt.reshape(batch_size, -1, temb_prompt.shape[-1])
1094+
temb_prompt_audio = temb_prompt_audio.reshape(batch_size, -1, temb_prompt_audio.shape[-1])
1095+
else:
1096+
temb_prompt = None
1097+
temb_prompt_audio = None
1098+
10351099
video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
10361100
timestep.flatten(),
10371101
hidden_dtype=hidden_states.dtype,
@@ -1094,6 +1158,8 @@ def scan_fn(carry, block_and_mask):
10941158
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
10951159
temb_ca_gate=video_cross_attn_a2v_gate,
10961160
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
1161+
temb_prompt=temb_prompt,
1162+
temb_prompt_audio=temb_prompt_audio,
10971163
video_rotary_emb=video_rotary_emb,
10981164
audio_rotary_emb=audio_rotary_emb,
10991165
ca_video_rotary_emb=video_cross_attn_rotary_emb,
@@ -1135,6 +1201,8 @@ def scan_fn(carry, block_and_mask):
11351201
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
11361202
temb_ca_gate=video_cross_attn_a2v_gate,
11371203
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
1204+
temb_prompt=temb_prompt,
1205+
temb_prompt_audio=temb_prompt_audio,
11381206
video_rotary_emb=video_rotary_emb,
11391207
audio_rotary_emb=audio_rotary_emb,
11401208
ca_video_rotary_emb=video_cross_attn_rotary_emb,

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1692,16 +1692,23 @@ def transformer_forward_pass(
16921692
audio_num_frames,
16931693
fps,
16941694
perturbation_mask=None,
1695+
sigma=None,
16951696
):
16961697
transformer = nnx.merge(graphdef, state)
16971698

16981699
# Expand timestep to batch size
16991700
timestep = jnp.expand_dims(timestep, 0).repeat(latents.shape[0])
1701+
1702+
if sigma is None:
1703+
sigma = timestep
1704+
else:
1705+
sigma = jnp.expand_dims(sigma, 0).repeat(latents.shape[0])
17001706

17011707
noise_pred, noise_pred_audio = transformer(
17021708
hidden_states=latents,
17031709
encoder_hidden_states=encoder_hidden_states,
17041710
timestep=timestep,
1711+
sigma=sigma,
17051712
encoder_attention_mask=encoder_attention_mask,
17061713
num_frames=latent_num_frames,
17071714
height=latent_height,

0 commit comments

Comments (0)