@@ -112,15 +112,17 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
112112 flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict , scan_layers )
113113
114114 # Check if we got 'kernel' but expected 'scale' (common for scanned layers where shape check fails)
115+ # Also check 'weight' because rename_key might not have converted it to kernel if it wasn't a known Linear
115116 flax_key_str = [str (k ) for k in flax_key ]
116117
117- if flax_key_str [- 1 ] == "kernel" :
118+ if flax_key_str [- 1 ] in [ "kernel" , "weight" ] :
118119 # Try replacing with scale and check if it exists in random_flax_state_dict
119120 temp_key_str = flax_key_str [:- 1 ] + ["scale" ]
120121 temp_key = tuple (temp_key_str ) # Tuple of strings
121122
122123 if temp_key in random_flax_state_dict :
123124 flax_key_str = temp_key_str
125+ pass
124126
125127 # RESTORE LTX-2 specific keys that rename_key_and_reshape_tensor incorrectly maps to standard Flax names
126128 # Fix scale_shift_table mapping if it got 'kernel' appended
@@ -371,8 +373,15 @@ def load_vae_weights(
371373 # _tuple_str_to_int might not be needed if we already injected ints, but it's safe
372374 flax_key = _tuple_str_to_int (flax_key )
373375
374- if flax_key == ("latents_mean" ,) or flax_key == ("latents_std" ,):
375- continue # Skip stats
376+ flax_key = tuple (flax_key_str )
377+ flax_key = _tuple_str_to_int (flax_key )
378+
379+ # Allow latents_mean/std
380+
381+ # DEBUG
382+ if "conv" in flax_key_str or "bias" in flax_key_str :
383+ # print(f"DEBUG: VAE Key Map: {pt_tuple_key} -> {flax_key}")
384+ pass
376385
377386 if resnet_index is not None :
378387 if flax_key in flax_state_dict :
0 commit comments