@@ -37,8 +37,6 @@ def rename_for_ltx2_transformer(key):
3737 if "adaLN_modulation_1" in key :
3838 key = key .replace ("adaLN_modulation_1" , "scale_shift_table" )
3939
40- # Handle video_a2v_cross_attn_scale_shift_table (caption_modulator?)
41- # Checkpoint: caption_modulator.1.weight
4240 if "caption_modulator_1" in key :
4341 key = key .replace ("caption_modulator_1" , "video_a2v_cross_attn_scale_shift_table" )
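+    # (Hypothetical example: a key containing "caption_modulator_1", presumably
+    #  "caption_modulator.1" flattened upstream, becomes
+    #  "video_a2v_cross_attn_scale_shift_table" here.)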
 
@@ -47,6 +45,28 @@ def rename_for_ltx2_transformer(key):
     # Let's inspect checkpoint keys for clues if this guess fails.
     if "audio_caption_modulator_1" in key:
         key = key.replace("audio_caption_modulator_1", "audio_a2v_cross_attn_scale_shift_table")
+
+    # Handle audio_caption_projection
+    # Checkpoint: audio_caption_projection.linear_1.weight
+    # Flax:       audio_caption_projection.linear_1.kernel
+    # rename_key_and_reshape_tensor catches 'weight' -> 'kernel', but is something else renaming it?
+    # No explicit rename is needed if it is already linear_1/linear_2, unless the names mismatch.
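+    # (Minimal sketch of the generic path, assuming rename_key_and_reshape_tensor
+    #  follows the usual PyTorch -> Flax Linear convention:
+    #      if pt_tuple_key[-1] == "weight" and tensor.ndim == 2:
+    #          return pt_tuple_key[:-1] + ("kernel",), tensor.T
+    #  PyTorch stores Linear weights as (out, in); Flax kernels are (in, out).)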
+
+    # Handle global norms (norm_out, audio_norm_out)
+    # Checkpoint: norm_final -> norm_out (already handled)
+    # The checkpoint also has audio_norm_final -> audio_norm_out?
+    if "audio_norm_final" in key:
+        key = key.replace("audio_norm_final", "audio_norm_out")
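+    # (Assumption mirroring the video branch: "audio_norm_final.weight" should then
+    #  resolve to "audio_norm_out.weight"; verify against the actual checkpoint keys.)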
+
+    # Handle time_embed/audio_time_embed
+    # Checkpoint: time_embed.emb.timestep_embedder.linear_1.weight
+    # Flax:       time_embed.emb.timestep_embedder.linear_1.kernel
+    # Does the checkpoint use a different name structure?
+    # time_embed.emb.timestep_embedder -> time_embed.emb.timestep_embedder (seems OK)
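+    # (If this matches the usual diffusers TimestepEmbedding layout, linear_1 and
+    #  linear_2 are plain Linear layers, so the generic weight -> kernel transpose
+    #  should cover them without an explicit rename.)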
+
+    # Handle av_cross_attn...
+    # These seem fine by name, but verify whether they are Linear or Conv. They are Linear.
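+    # (One way to verify at conversion time is to check tensor rank:
+    #  ndim == 2 means Linear, transpose (out, in) -> (in, out);
+    #  ndim == 4 means Conv2d, permute (O, I, H, W) -> (H, W, I, O).
+    #  The av_cross_attn projections should hit the ndim == 2 branch.)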
+
 
 
     # Handle any autoencoder_kl_ltx2-specific renames, though this function is usually for the transformer.
@@ -141,6 +161,12 @@ def replace_suffix(lst, old, new):
141161 if "to_out" in str (flax_key ) and "kernel" in str (flax_key ) and block_index == 18 and "attn1" in str (flax_key ):
142162 print (f"DEBUG: Mapped { pt_tuple_key } -> { flax_key } (Block 18 attn1)" )
143163
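+    # (Debug scaffolding: the prints below spot-check that these key families map
+    #  where expected, like the Block 18 print above; presumably temporary.)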
164+ if "proj_in" in str (flax_key ):
165+ print (f"DEBUG: Mapped { pt_tuple_key } -> { flax_key } (proj_in)" )
166+
167+ if "scale_shift_table" in str (flax_key ):
168+ print (f"DEBUG: Mapped { pt_tuple_key } -> { flax_key } (scale_shift_table)" )
169+
     return flax_key, flax_tensor
 
 def load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device):
@@ -338,6 +364,11 @@ def load_vae_weights(
             # Check if next part is 'conv'
             if i + 1 < len(pt_tuple_key) and pt_tuple_key[i + 1] == "conv":
                 pass  # already has conv
+            elif len(pt_list) >= 2 and pt_list[-2] == "conv":  # check the previous injection (guard short lists)
+                pass  # conv was already injected in a previous step (if part was conv1/conv2/conv)
+            # Also avoid injecting if part ITSELF is 'conv'
+            elif part == "conv":
+                pass
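+            # (Illustrative, assuming the Flax VAE nests each conv in a submodule
+            #  named "conv": a pt path (..., "conv1", "weight") gains the extra
+            #  segment, (..., "conv1", "conv", "weight"), via the append below.)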
             else:
                 pt_list.append("conv")
         else: