Commit 3808925: fix
1 parent ff2c164

1 file changed: 33 additions & 2 deletions

src/maxdiffusion/models/ltx2/ltx2_utils.py
@@ -37,8 +37,6 @@ def rename_for_ltx2_transformer(key):
   if "adaLN_modulation_1" in key:
     key = key.replace("adaLN_modulation_1", "scale_shift_table")
 
-  # Handle video_a2v_cross_attn_scale_shift_table (caption_modulator?)
-  # Checkpoint: caption_modulator.1.weight
   if "caption_modulator_1" in key:
     key = key.replace("caption_modulator_1", "video_a2v_cross_attn_scale_shift_table")
 
@@ -47,6 +45,28 @@ def rename_for_ltx2_transformer(key):
   # Let's inspect checkpoint keys for clues if this guess fails.
   if "audio_caption_modulator_1" in key:
     key = key.replace("audio_caption_modulator_1", "audio_a2v_cross_attn_scale_shift_table")
+
+  # Handle audio_caption_projection
+  # Checkpoint: audio_caption_projection.linear_1.weight
+  # Flax: audio_caption_projection.linear_1.kernel
+  # rename_key_and_reshape_tensor catches 'weight' -> 'kernel', but maybe something else renaming it?
+  # No explicit rename needed if it's already linear_1/linear_2 unless name mismatch.
+
+  # Handle global norms (norm_out, audio_norm_out)
+  # Checkpoint: norm_final -> norm_out (already handled)
+  # Checkpoint also has audio_norm_final -> audio_norm_out?
+  if "audio_norm_final" in key:
+    key = key.replace("audio_norm_final", "audio_norm_out")
+
+  # Handle time_embed/audio_time_embed
+  # Checkpoint: time_embed.emb.timestep_embedder.linear_1.weight
+  # Flax: time_embed.emb.timestep_embedder.linear_1.kernel
+  # If checkpoint uses different name structure?
+  # time_embed.emb.timestep_embedder -> time_embed.emb.timestep_embedder (seems OK)
+
+  # Handle av_cross_attn...
+  # These seem fine in name but verify if they are Linear or Conv? Linear.
+
 
 
   # Handle autoencoder_kl_ltx2 specific renames if any, but this is for transformer usually.
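Taken together, the first two hunks extend the string-level rename pass in rename_for_ltx2_transformer. Below is a minimal standalone sketch of that pass (hypothetical, not the module's actual code), assuming dotted checkpoint segments such as caption_modulator.1 have already been flattened to caption_modulator_1. Unlike the committed code, the sketch tests the audio variant first, since "caption_modulator_1" is a substring of "audio_caption_modulator_1" and would otherwise shadow it:

# Minimal sketch of the rename pass (hypothetical standalone rewrite).
def rename_for_ltx2_transformer_sketch(key: str) -> str:
  if "adaLN_modulation_1" in key:
    key = key.replace("adaLN_modulation_1", "scale_shift_table")
  # Order matters: "caption_modulator_1" is a substring of
  # "audio_caption_modulator_1", so the audio check must run first.
  if "audio_caption_modulator_1" in key:
    key = key.replace("audio_caption_modulator_1", "audio_a2v_cross_attn_scale_shift_table")
  elif "caption_modulator_1" in key:
    key = key.replace("caption_modulator_1", "video_a2v_cross_attn_scale_shift_table")
  # New in this commit: map the audio-side final norm.
  if "audio_norm_final" in key:
    key = key.replace("audio_norm_final", "audio_norm_out")
  return key

assert rename_for_ltx2_transformer_sketch("audio_norm_final.weight") == "audio_norm_out.weight"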
@@ -141,6 +161,12 @@ def replace_suffix(lst, old, new):
   if "to_out" in str(flax_key) and "kernel" in str(flax_key) and block_index == 18 and "attn1" in str(flax_key):
     print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (Block 18 attn1)")
 
+  if "proj_in" in str(flax_key):
+    print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (proj_in)")
+
+  if "scale_shift_table" in str(flax_key):
+    print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (scale_shift_table)")
+
   return flax_key, flax_tensor
 
 def load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device):
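The comments in the second hunk lean on rename_key_and_reshape_tensor to turn 'weight' into 'kernel'. For reference, that is the standard PyTorch-to-Flax step for a Linear layer; the helper below is a hypothetical simplification, not maxdiffusion's actual implementation:

import numpy as np

# Hypothetical simplification of the weight -> kernel step:
# torch.nn.Linear stores weight with shape (out_features, in_features),
# while flax.linen.Dense stores kernel as (in_features, out_features),
# so a 2-D weight is renamed and transposed.
def rename_and_reshape_linear(pt_tuple_key, tensor):
  if pt_tuple_key[-1] == "weight" and tensor.ndim == 2:
    return pt_tuple_key[:-1] + ("kernel",), tensor.T
  return pt_tuple_key, tensor

flax_key, flax_tensor = rename_and_reshape_linear(
    ("audio_caption_projection", "linear_1", "weight"), np.zeros((64, 32)))
assert flax_key[-1] == "kernel" and flax_tensor.shape == (32, 64)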
@@ -338,6 +364,11 @@ def load_vae_weights(
       # Check if next part is 'conv'
       if i + 1 < len(pt_tuple_key) and pt_tuple_key[i+1] == "conv":
         pass  # already has conv
+      elif pt_list[-2] == "conv":  # Check previous injection
+        pass  # already injected conv in previous step (if part was conv1/conv2/conv)
+      # Also avoid injecting if part ITSELF is 'conv'
+      elif part == "conv":
+        pass
       else:
         pt_list.append("conv")
     else:
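The three new branches in the last hunk guard the VAE loader's "conv" injection against duplicates. A toy reconstruction follows (hypothetical: the real loop in load_vae_weights carries more surrounding state, and the trigger condition on part is an assumption here):

# Toy reconstruction of the guarded "conv" injection (hypothetical).
def inject_conv(pt_tuple_key):
  pt_list = []
  for i, part in enumerate(pt_tuple_key):
    pt_list.append(part)
    if part in ("conv1", "conv2", "conv"):  # assumed trigger for the injection
      if i + 1 < len(pt_tuple_key) and pt_tuple_key[i + 1] == "conv":
        pass  # already has conv
      elif len(pt_list) >= 2 and pt_list[-2] == "conv":
        pass  # conv was already injected just before this part
      elif part == "conv":
        pass  # the part itself is conv
      else:
        pt_list.append("conv")
  return pt_list

assert inject_conv(("res", "conv1", "weight")) == ["res", "conv1", "conv", "weight"]
assert inject_conv(("res", "conv", "weight")) == ["res", "conv", "weight"]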
