Skip to content

Commit ac9e51c

Browse files
committed
fixing conversion script
1 parent b6ae0f2 commit ac9e51c

1 file changed

Lines changed: 23 additions & 9 deletions

File tree

src/maxdiffusion/scripts/convert_ltx2_vae_weights.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
from flax import traverse_util
1515

1616
def convert_ltx2_vae(hf_repo, output_path):
17+
# Ensure output path is absolute
18+
output_path = os.path.abspath(output_path)
19+
1720
# Load weights directly from Safetensors
1821
print(f"Downloading/Loading weights from {hf_repo}...")
1922
try:
@@ -38,14 +41,15 @@ def convert_ltx2_vae(hf_repo, output_path):
3841
out_channels=3,
3942
latent_channels=128,
4043
block_out_channels=(256, 512, 1024, 2048),
41-
decoder_block_out_channels=(2048, 1024, 512, 256),
44+
# Corrected Decoder Config based on PyTorch weights
45+
decoder_block_out_channels=(256, 512, 1024), # 3 blocks
4246
layers_per_block=(4, 6, 6, 2, 2),
43-
decoder_layers_per_block=(2, 2, 6, 6, 4),
47+
decoder_layers_per_block=(5, 5, 5, 5), # Mid + 3 Up, 5 layers each
4448
spatio_temporal_scaling=(True, True, True, True),
45-
decoder_spatio_temporal_scaling=(True, True, True, True),
46-
decoder_inject_noise=(False, False, False, False, False),
47-
upsample_factor=(2, 2, 2, 2),
48-
upsample_residual=(False, False, False, False),
49+
decoder_spatio_temporal_scaling=(True, True, True),
50+
decoder_inject_noise=(False, False, False, False),
51+
upsample_factor=(2, 2, 2),
52+
upsample_residual=(False, False, False),
4953
dtype=jnp.float32,
5054
rngs=nnx.Rngs(0)
5155
)
@@ -59,6 +63,7 @@ def convert_ltx2_vae(hf_repo, output_path):
5963

6064
new_params = {}
6165

66+
mapped_count = 0
6267
for key_tuple, value in flat_params.items():
6368
# Skip Rngs if any leak through
6469
if "rngs" in key_tuple or "count" in key_tuple or "key" in key_tuple:
@@ -82,11 +87,19 @@ def convert_ltx2_vae(hf_repo, output_path):
8287

8388
# Check if key exists in PT dict
8489
if pt_key not in pt_state_dict:
85-
print(f"Warning: {pt_key} not found in PyTorch state dict. Checking alternatives...")
90+
# Check for specific mismatches
91+
# Example: MaxDiffusion uses 'scale' for RMSNorm, PT uses 'weight'
92+
# (Handled above)
93+
print(f"Warning: {pt_key} not found in PyTorch state dict.")
8694
continue
8795

8896
pt_tensor = pt_state_dict[pt_key]
89-
np_array = pt_tensor.float().numpy()
97+
98+
# Handle BFloat16
99+
if pt_tensor.dtype == torch.bfloat16:
100+
pt_tensor = pt_tensor.float()
101+
102+
np_array = pt_tensor.numpy()
90103

91104
# Handle shape mismatch (Transpose Conv3d weights)
92105
is_conv_weight = "conv" in pt_key and "weight" in pt_key and len(np_array.shape) == 5
@@ -109,8 +122,9 @@ def convert_ltx2_vae(hf_repo, output_path):
109122
continue
110123

111124
new_params[key_tuple] = jnp.array(np_array)
125+
mapped_count += 1
112126

113-
print(f"Mapped {len(new_params)} out of {len(flat_params)} parameters.")
127+
print(f"Mapped {mapped_count} out of {len(flat_params)} parameters.")
114128

115129
# Reconstruct nested dictionary
116130
params_nested = traverse_util.unflatten_dict(new_params)

0 commit comments

Comments
 (0)