@@ -29,10 +29,15 @@ def rename_for_ltx2_transformer(key):
2929 """
3030 key = key .replace ("patchify_proj" , "proj_in" )
3131 key = key .replace ("audio_patchify_proj" , "audio_proj_in" )
32-
33- # if "caption_projection" in key:
34- # key = key.replace("caption_projection", "audio_caption_projection")
35-
32+ key = key .replace ("norm_final" , "norm_out" )
33+
34+ # Handle scale_shift_table
35+ # PyTorch: adaLN_modulation.1.weight/bias -> scale_shift_table
36+ if "adaLN_modulation.1" in key :
37+ key = key .replace ("adaLN_modulation.1" , "scale_shift_table" )
38+
39+ # Handle autoencoder_kl_ltx2 specific renames if any, but this is for transformer usually.
40+
3641 # Handle audio_ff.net_0.proj -> audio_ff.net_0
3742 if "audio_ff" in key and "proj" in key :
3843 key = key .replace (".proj" , "" )
@@ -203,7 +208,7 @@ def load_vae_weights(
203208 tensors [k ] = torch2jax (f .get_tensor (k ))
204209 else :
205210 loaded_state_dict = torch .load (ckpt_path , map_location = "cpu" )
206- for k , v in loaded_state_dict .items ():
211+ for k , v in loaded_state_dict .items ():
207212 tensors [k ] = torch2jax (v )
208213
209214 print ("\n DEBUG: Top 20 keys from VAE Checkpoint (tensors):" )
0 commit comments