Skip to content

Commit 3579b56

Browse files
committed
fix
1 parent e2361cf commit 3579b56

1 file changed

Lines changed: 10 additions & 5 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,15 @@ def rename_for_ltx2_transformer(key):
2929
"""
3030
key = key.replace("patchify_proj", "proj_in")
3131
key = key.replace("audio_patchify_proj", "audio_proj_in")
32-
33-
# if "caption_projection" in key:
34-
# key = key.replace("caption_projection", "audio_caption_projection")
35-
32+
key = key.replace("norm_final", "norm_out")
33+
34+
# Handle scale_shift_table
35+
# PyTorch: adaLN_modulation.1.weight/bias -> scale_shift_table
36+
if "adaLN_modulation.1" in key:
37+
key = key.replace("adaLN_modulation.1", "scale_shift_table")
38+
39+
# Handle autoencoder_kl_ltx2 specific renames if any, but this is for transformer usually.
40+
3641
# Handle audio_ff.net_0.proj -> audio_ff.net_0
3742
if "audio_ff" in key and "proj" in key:
3843
key = key.replace(".proj", "")
@@ -203,7 +208,7 @@ def load_vae_weights(
203208
tensors[k] = torch2jax(f.get_tensor(k))
204209
else:
205210
loaded_state_dict = torch.load(ckpt_path, map_location="cpu")
206-
for k, v in loaded_state_dict.items():
211+
for k, v in loaded_state_dict.items():
207212
tensors[k] = torch2jax(v)
208213

209214
print("\nDEBUG: Top 20 keys from VAE Checkpoint (tensors):")

0 commit comments

Comments
 (0)