1414from flax import traverse_util
1515
1616def convert_ltx2_vae (hf_repo , output_path ):
17+ # Ensure output path is absolute
18+ output_path = os .path .abspath (output_path )
19+
1720 # Load weights directly from Safetensors
1821 print (f"Downloading/Loading weights from { hf_repo } ..." )
1922 try :
@@ -38,14 +41,15 @@ def convert_ltx2_vae(hf_repo, output_path):
3841 out_channels = 3 ,
3942 latent_channels = 128 ,
4043 block_out_channels = (256 , 512 , 1024 , 2048 ),
41- decoder_block_out_channels = (2048 , 1024 , 512 , 256 ),
44+ # Corrected Decoder Config based on PyTorch weights
45+ decoder_block_out_channels = (256 , 512 , 1024 ), # 3 blocks
4246 layers_per_block = (4 , 6 , 6 , 2 , 2 ),
43- decoder_layers_per_block = (2 , 2 , 6 , 6 , 4 ),
47+ decoder_layers_per_block = (5 , 5 , 5 , 5 ), # Mid + 3 Up, 5 layers each
4448 spatio_temporal_scaling = (True , True , True , True ),
45- decoder_spatio_temporal_scaling = (True , True , True , True ),
46- decoder_inject_noise = (False , False , False , False , False ),
47- upsample_factor = (2 , 2 , 2 , 2 ),
48- upsample_residual = (False , False , False , False ),
49+ decoder_spatio_temporal_scaling = (True , True , True ),
50+ decoder_inject_noise = (False , False , False , False ),
51+ upsample_factor = (2 , 2 , 2 ),
52+ upsample_residual = (False , False , False ),
4953 dtype = jnp .float32 ,
5054 rngs = nnx .Rngs (0 )
5155 )
@@ -59,6 +63,7 @@ def convert_ltx2_vae(hf_repo, output_path):
5963
6064 new_params = {}
6165
66+ mapped_count = 0
6267 for key_tuple , value in flat_params .items ():
6368 # Skip Rngs if any leak through
6469 if "rngs" in key_tuple or "count" in key_tuple or "key" in key_tuple :
@@ -82,11 +87,19 @@ def convert_ltx2_vae(hf_repo, output_path):
8287
8388 # Check if key exists in PT dict
8489 if pt_key not in pt_state_dict :
85- print (f"Warning: { pt_key } not found in PyTorch state dict. Checking alternatives..." )
90+ # Check for specific mismatches
91+ # Example: MaxDiffusion uses 'scale' for RMSNorm, PT uses 'weight'
92+ # (Handled above)
93+ print (f"Warning: { pt_key } not found in PyTorch state dict." )
8694 continue
8795
8896 pt_tensor = pt_state_dict [pt_key ]
89- np_array = pt_tensor .float ().numpy ()
97+
98+ # Handle BFloat16
99+ if pt_tensor .dtype == torch .bfloat16 :
100+ pt_tensor = pt_tensor .float ()
101+
102+ np_array = pt_tensor .numpy ()
90103
91104 # Handle shape mismatch (Transpose Conv3d weights)
92105 is_conv_weight = "conv" in pt_key and "weight" in pt_key and len (np_array .shape ) == 5
@@ -109,8 +122,9 @@ def convert_ltx2_vae(hf_repo, output_path):
109122 continue
110123
111124 new_params [key_tuple ] = jnp .array (np_array )
125+ mapped_count += 1
112126
113- print (f"Mapped { len ( new_params ) } out of { len (flat_params )} parameters." )
127+ print (f"Mapped { mapped_count } out of { len (flat_params )} parameters." )
114128
115129 # Reconstruct nested dictionary
116130 params_nested = traverse_util .unflatten_dict (new_params )