
Commit fb20038

fixing conversion script

1 parent b18f8f5

1 file changed: 67 additions & 29 deletions

src/maxdiffusion/scripts/convert_ltx2_vae_weights.py
@@ -12,6 +12,7 @@
 from huggingface_hub import hf_hub_download
 from flax import nnx
 from flax import traverse_util
+
 def convert_ltx2_vae(hf_repo, output_path):
   # Load weights directly from Safetensors
   print(f"Downloading/Loading weights from {hf_repo}...")
@@ -48,40 +49,77 @@ def convert_ltx2_vae(hf_repo, output_path):
       dtype=jnp.float32,
       rngs=nnx.Rngs(0)
   )
-
-  # Get PyTorch state dict
-  pt_state_dict = load_file(ckpt_path)
 
-  # Define mapping
-  # We will need to map PT keys to Flax keys
-  # Helper to print PT keys
-  print("PyTorch Keys:")
-  sorted_pt_keys = sorted(pt_state_dict.keys())
-  for k in sorted_pt_keys:
-    v = pt_state_dict[k]
-    print(f"{k}: {v.shape}")
+  print("\nMapping weights...")
+  graphdef, state = nnx.split(model)
+  params = state.filter(nnx.Param)
 
-  print("\nMaxDiffusion Keys (initialization):")
-  # Get MaxDiffusion keys from initialized model
-  # We need to run a dummy forward or init to get parameters if they are lazy,
-  # but nnx.Module usually has them after init if shape is provided?
-  # Wait, nnx modules need to be split to see params.
-  graphdef, state = nnx.split(model); params = state.filter(nnx.Param)
-  flat_params = traverse_util.flatten_dict(params.to_pure_dict())
-  sorted_flat_keys = sorted(flat_params.keys())
-  for k in sorted_flat_keys:
-    v = flat_params[k]
-    print(f"{k}: {v.shape}")
-
-  params = {}
+  params_dict = params.to_pure_dict()
+  flat_params = traverse_util.flatten_dict(params_dict)
 
-  # TODO: Implement the mapping logic here
-  # This acts as a template for now
+  new_params = {}
+
+  for key_tuple, value in flat_params.items():
+    # Skip Rngs if any leak through
+    if "rngs" in key_tuple or "count" in key_tuple or "key" in key_tuple:
+      continue
+
+    # Construct PyTorch key
+    pt_key_parts = []
+    for p in key_tuple:
+      if isinstance(p, int):
+        pt_key_parts.append(str(p))
+      else:
+        pt_key_parts.append(p)
+
+    # Adjust property names from MaxDiffusion to Diffusers
+    if pt_key_parts[-1] == "kernel":
+      pt_key_parts[-1] = "weight"
+    elif pt_key_parts[-1] == "scale":
+      pt_key_parts[-1] = "weight"
+
+    pt_key = ".".join(pt_key_parts)
+
+    # Check if key exists in PT dict
+    if pt_key not in pt_state_dict:
+      print(f"Warning: {pt_key} not found in PyTorch state dict. Checking alternatives...")
+      continue
+
+    pt_tensor = pt_state_dict[pt_key]
+    np_array = pt_tensor.numpy()
+
+    # Handle shape mismatch (Transpose Conv3d weights)
+    is_conv_weight = "conv" in pt_key and "weight" in pt_key and len(np_array.shape) == 5
+
+    if is_conv_weight:
+      # PyTorch Conv3d: (Out, In, T, H, W)
+      # JAX Conv: (T, H, W, In, Out)
+      # Permutation: 0, 1, 2, 3, 4 -> 2, 3, 4, 1, 0
+      np_array = np_array.transpose(2, 3, 4, 1, 0)
+
+    # Verify shape
+    if np_array.shape != value.shape:
+      # Handle singleton dimensions if they match in total size or one dim
+      if np_array.shape == (1,) + value.shape:
+        np_array = np_array.squeeze(0)
+      elif value.shape == (1,) + np_array.shape:
+        np_array = np_array[None]
+      else:
+        print(f"Shape mismatch for {pt_key}: PT {np_array.shape} vs Max {value.shape}")
+        continue
+
+    new_params[key_tuple] = jnp.array(np_array)
+
+  print(f"Mapped {len(new_params)} out of {len(flat_params)} parameters.")
+
+  # Reconstruct nested dictionary
+  params_nested = traverse_util.unflatten_dict(new_params)
 
+  # Save checkpoint
   print(f"Saving converted weights to {output_path}...")
-  checkpointer = orbax.checkpoint.PyTreeCheckpointer()
-  save_args = orbax_utils.save_args_from_target(params)
-  checkpointer.save(output_path, params, save_args=save_args)
+  checkpointer = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
+  save_args = orbax_utils.save_args_from_target(params_nested)
+  checkpointer.save(output_path, params_nested, save_args=save_args)
   print("Done!")
 
 if __name__ == "__main__":

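The Conv3d transpose is the one place the conversion touches data layout rather than names: PyTorch stores 3D conv weights as (Out, In, T, H, W), while JAX/Flax convolutions expect (T, H, W, In, Out). A quick shape check, with made-up dimensions:

import numpy as np

pt_weight = np.zeros((128, 64, 3, 3, 3))         # PyTorch: (Out, In, T, H, W)
jax_weight = pt_weight.transpose(2, 3, 4, 1, 0)  # JAX:     (T, H, W, In, Out)
assert jax_weight.shape == (3, 3, 3, 64, 128)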

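For completeness, the reverse path, restoring the converted checkpoint into an nnx model, would look roughly like the sketch below. It assumes the same orbax Checkpointer API the script uses and a recent flax.nnx that provides State.replace_by_pure_dict; model is the freshly initialized VAE from convert_ltx2_vae:

import orbax.checkpoint
from flax import nnx

checkpointer = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
restored = checkpointer.restore(output_path)  # nested dict of arrays

graphdef, state = nnx.split(model)
state.replace_by_pure_dict(restored)          # assumes a recent flax.nnx
model = nnx.merge(graphdef, state)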