
Commit 3ec5421
Commit message: fix
Parent: 62634b0

2 files changed: 35 additions & 2 deletions

src/maxdiffusion/models/ltx2/ltx2_utils.py (25 additions & 0 deletions)
@@ -36,6 +36,18 @@ def rename_for_ltx2_transformer(key):
   # rename_key changes adaLN_modulation.1 -> adaLN_modulation_1
   if "adaLN_modulation_1" in key:
     key = key.replace("adaLN_modulation_1", "scale_shift_table")
+
+  # Handle video_a2v_cross_attn_scale_shift_table (caption_modulator?)
+  # Checkpoint key: caption_modulator.1.weight
+  if "caption_modulator_1" in key:
+    key = key.replace("caption_modulator_1", "video_a2v_cross_attn_scale_shift_table")
+
+  # Audio caption modulator?
+  # Checkpoint key: audio_caption_modulator.1.weight (guessed name)
+  # Inspect the checkpoint keys for clues if this guess fails.
+  if "audio_caption_modulator_1" in key:
+    key = key.replace("audio_caption_modulator_1", "audio_a2v_cross_attn_scale_shift_table")
+
 
   # Handle autoencoder_kl_ltx2 specific renames if any, but this is for transformer usually.
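For context, a minimal standalone sketch of the rename pass this hunk extends. The helper name rename_modulation_keys and the sample key are illustrative, not from the repo; the three replacements are taken from the diff, and the earlier rename_key step that turns "." into "_" is assumed to have already run:

    def rename_modulation_keys(key: str) -> str:
        # rename_key has already turned e.g. "adaLN_modulation.1" into "adaLN_modulation_1"
        if "adaLN_modulation_1" in key:
            key = key.replace("adaLN_modulation_1", "scale_shift_table")
        if "caption_modulator_1" in key:
            key = key.replace("caption_modulator_1", "video_a2v_cross_attn_scale_shift_table")
        if "audio_caption_modulator_1" in key:
            key = key.replace("audio_caption_modulator_1", "audio_a2v_cross_attn_scale_shift_table")
        return key

    # Illustrative key shape, not taken from a real LTX-2 checkpoint:
    assert (rename_modulation_keys("transformer_blocks_0_caption_modulator_1_weight")
            == "transformer_blocks_0_video_a2v_cross_attn_scale_shift_table_weight")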

@@ -68,11 +80,24 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict
   if "transformer_blocks" in pt_tuple_key:
     pass  # Already handled above or matches standard format
 
+  # Handle scale_shift_table keys - they are Params, not Linear layers, so no 'kernel' suffix is needed
+  # (we may have renamed them to scale_shift_table already in rename_for_ltx2_transformer).
+  if "scale_shift_table" in pt_tuple_key[-1] or "scale_shift_table" in pt_tuple_key:
+    # If the key was renamed to end with scale_shift_table, use it directly;
+    # rename_key_and_reshape_tensor below might still append 'kernel'.
+    pass
+
   flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers)
 
   # RESTORE LTX-2 specific keys that rename_key_and_reshape_tensor incorrectly maps to standard Flax names
   flax_key_str = [str(k) for k in flax_key]
 
+  # Fix the scale_shift_table mapping if it got 'kernel' appended:
+  if "scale_shift_table" in flax_key_str:
+    # if the last element is kernel/weight, remove it
+    if flax_key_str[-1] in ["kernel", "weight"]:
+      flax_key_str.pop()
+
   # Helper to replace last occurrence
   def replace_suffix(lst, old, new):
     if lst and lst[-1] == old:
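A hedged sketch of the cleanup this hunk performs: scale_shift_table entries are nnx.Param arrays rather than Linear layers, so a trailing "kernel"/"weight" path element appended by the generic renamer must be dropped. The function strip_param_suffix and the sample key are illustrative; the membership test and pop mirror the diff:

    def strip_param_suffix(flax_key_str: list) -> list:
        # Drop a spurious kernel/weight leaf under a scale_shift_table entry.
        if "scale_shift_table" in flax_key_str and flax_key_str[-1] in ("kernel", "weight"):
            flax_key_str = flax_key_str[:-1]
        return flax_key_str

    assert (strip_param_suffix(["transformer_blocks", "0", "scale_shift_table", "weight"])
            == ["transformer_blocks", "0", "scale_shift_table"])

One caveat worth checking: "scale_shift_table" in flax_key_str tests exact list elements, so the longer names introduced above (video_a2v_cross_attn_scale_shift_table, audio_a2v_cross_attn_scale_shift_table) would not match; a substring test such as any("scale_shift_table" in k for k in flax_key_str) would cover them too.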

src/maxdiffusion/models/ltx2/transformer_ltx2.py (10 additions & 2 deletions)
@@ -158,6 +158,14 @@ def __init__(
         rope_type=rope_type,
     )
 
+    # Scale Shift Tables
+    self.scale_shift_table = nnx.Param(jax.random.normal(rngs.params(), (6, self.dim), dtype=weights_dtype) / jnp.sqrt(self.dim))
+    self.audio_scale_shift_table = nnx.Param(
+        jax.random.normal(rngs.params(), (6, audio_dim), dtype=weights_dtype) / jnp.sqrt(audio_dim)
+    )
+    self.video_a2v_cross_attn_scale_shift_table = nnx.Param(jax.random.normal(rngs.params(), (5, self.dim), dtype=weights_dtype))
+    self.audio_a2v_cross_attn_scale_shift_table = nnx.Param(jax.random.normal(rngs.params(), (5, audio_dim), dtype=weights_dtype))
+
     # 2. Prompt Cross-Attention
     self.norm2 = nnx.RMSNorm(
         self.dim,
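For intuition, a hedged sketch (not the repo's forward pass) of how a (6, dim) scale/shift table of this kind is typically consumed in adaLN-style blocks: the table, broadcast against a per-sample conditioning embedding, splits into shift/scale/gate pairs for the attention and MLP sub-blocks. All names and shapes below are illustrative:

    import jax
    import jax.numpy as jnp

    dim = 8
    table = jax.random.normal(jax.random.key(0), (6, dim)) / jnp.sqrt(dim)  # like scale_shift_table
    emb = jax.random.normal(jax.random.key(1), (1, 6, dim))                 # per-sample conditioning

    # Split into shift/scale/gate for the attention and MLP sub-blocks.
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = jnp.split(table[None] + emb, 6, axis=1)

    x = jnp.ones((1, 16, dim))               # stand-in hidden states (batch, tokens, dim)
    x_mod = x * (1 + scale_msa) + shift_msa  # modulate the normalized activations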
@@ -807,7 +815,7 @@ def init_block(rngs):
     # 6. Output layers
     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
     self.norm_out = nnx.LayerNorm(
-        inner_dim, epsilon=1e-6, use_scale=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
+        inner_dim, epsilon=1e-6, use_scale=False, use_bias=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
     )
     self.proj_out = nnx.Linear(
         inner_dim,

@@ -820,7 +828,7 @@ def init_block(rngs):
     )
 
     self.audio_norm_out = nnx.LayerNorm(
-        audio_inner_dim, epsilon=1e-6, use_scale=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
+        audio_inner_dim, epsilon=1e-6, use_scale=False, use_bias=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
     )
     self.audio_proj_out = nnx.Linear(
         audio_inner_dim,
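On the use_bias=False change: together with use_scale=False, the LayerNorm then carries no learnable parameters at all, so there is no scale or bias entry for checkpoint loading to miss; the modulation is applied externally via the scale/shift tables instead. A quick self-contained check, assuming the Flax NNX LayerNorm signature the diff itself uses:

    import jax
    import jax.numpy as jnp
    from flax import nnx

    norm = nnx.LayerNorm(
        8, epsilon=1e-6, use_scale=False, use_bias=False,
        rngs=nnx.Rngs(0), dtype=jnp.float32, param_dtype=jnp.float32,
    )
    params = nnx.state(norm, nnx.Param)
    assert len(jax.tree_util.tree_leaves(params)) == 0  # no scale/bias to load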
