@@ -109,9 +109,30 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
109109
110110 flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict , scan_layers )
111111
112- # RESTORE LTX-2 specific keys that rename_key_and_reshape_tensor incorrectly maps to standard Flax names
112+ # Check if we got 'kernel' but expected 'scale' (common for scanned layers where shape check fails)
113113 flax_key_str = [str (k ) for k in flax_key ]
114114
115+ if flax_key_str [- 1 ] == "kernel" :
116+ # Try replacing with scale and check if it exists in random_flax_state_dict
117+ temp_key_str = flax_key_str [:- 1 ] + ["scale" ]
118+ temp_key = tuple (temp_key_str ) # Tuple of strings
119+
120+ # random_flax_state_dict keys are tuples of STRINGS
121+ if temp_key in random_flax_state_dict :
122+ flax_key_str = temp_key_str
123+ # Mapping weight -> scale: scale parameters are 1D, whereas Linear
124+ # weights are 2D (and are transposed during conversion).
125+ # rename_key_and_reshape_tensor treated this tensor as a Linear weight:
126+ # its Linear logic checks `if pt_tuple_key[-1] == "weight"` and then
127+ # applies `pt_tensor = pt_tensor.T`.
128+ # For a tensor that was originally a 1D 'weight' (like LayerNorm), the
129+ # transpose is a no-op, so it is harmless and no reshape is needed here.
130+ # NOTE(review): a 2D tensor reaching this branch would stay transposed - TODO confirm this cannot happen.
131+ pass
132+
133+ # RESTORE LTX-2 specific keys that rename_key_and_reshape_tensor incorrectly maps to standard Flax names
134+ # flax_key_str = [str(k) for k in flax_key] # Already have it
135+
115136 # Fix scale_shift_table mapping if it got 'kernel' appended
116137 if "scale_shift_table" in flax_key_str :
117138 # if last is kernel/weight, remove it
@@ -217,10 +238,25 @@ def load_transformer_weights(
217238 cpu = jax .local_devices (backend = "cpu" )[0 ]
218239 flattened_dict = flatten_dict (eval_shapes )
219240
220- random_flax_state_dict = {}
221- for key in flattened_dict :
222- string_tuple = tuple ([str (item ) for item in key ])
223- random_flax_state_dict [string_tuple ] = flattened_dict [key ]
241+ # DEBUG: Print keys to understand mapping
242+ print ("DEBUG: Top 20 keys from Checkpoint (tensors):" )
243+ for k in list (tensors .keys ())[:20 ]:
244+ print (k )
245+
246+ print ("DEBUG: NON-BLOCK keys in Checkpoint:" )
247+ for k in tensors .keys ():
248+ if "transformer_blocks" not in k :
249+ print (k )
250+
251+ print ("\n DEBUG: Top 20 keys from Flax Model (eval_shapes):" )
252+ for k in list (random_flax_state_dict .keys ())[:20 ]:
253+ print (k )
254+
255+ print ("\n DEBUG: Transformer Block keys from Flax Model (eval_shapes):" )
256+ for k in list (random_flax_state_dict .keys ()):
257+ k_str = str (k )
258+ if "transformer_blocks" in k_str and ("attn1" in k_str or "ff" in k_str ):
259+ print (f"EVAL_SHAPE: { k } " )
224260
225261 for pt_key , tensor in tensors .items ():
226262 renamed_pt_key = rename_key (pt_key )
@@ -275,6 +311,15 @@ def load_vae_weights(
275311 cpu = jax .local_devices (backend = "cpu" )[0 ]
276312 flattened_eval = flatten_dict (eval_shapes )
277313
314+ # DEBUG: Print keys to understand mapping
315+ print ("DEBUG: Top 20 keys from VAE Checkpoint (tensors):" )
316+ for k in list (tensors .keys ())[:20 ]:
317+ print (k )
318+
319+ flax_state_dict = {}
320+ cpu = jax .local_devices (backend = "cpu" )[0 ]
321+ flattened_eval = flatten_dict (eval_shapes )
322+
278323 random_flax_state_dict = {}
279324 for key in flattened_eval :
280325 string_tuple = tuple ([str (item ) for item in key ])
@@ -302,17 +347,19 @@ def load_vae_weights(
302347 pt_list .append (str (idx ))
303348 else :
304349 pt_list .append (part )
350+ elif part == "upsampler" :
351+ pt_list .append ("upsamplers" )
352+ pt_list .append ("0" )
305353 elif part in ["conv1" , "conv2" , "conv" ]:
306354 pt_list .append (part )
307355 # Inject 'conv' if it's not already there AND not just added
308356 if i + 1 < len (pt_tuple_key ) and pt_tuple_key [i + 1 ] == "conv" :
309357 pass # already has conv
310358 elif pt_list [- 1 ] == "conv" :
311359 pass # already has conv
360+ elif len (pt_list ) >= 2 and pt_list [- 2 ] == "conv" :
361+ pass
312362 elif part == "conv" :
313- # It IS conv, so we appended it. Do we need another one?
314- # If part is 'conv', we appended it.
315- # The original logic skipped it. We kept it.
316363 pass
317364 else :
318365 # If part is conv1/conv2, append 'conv'
@@ -342,7 +389,7 @@ def load_vae_weights(
342389 current_tensor = jnp .zeros (target_shape , dtype = flax_tensor .dtype )
343390 else :
344391 # Fallback if key missing (shouldn't happen with correct mapping)
345- print (f"Warning: Key { str_flax_key } not found in random_flax_state_dict, cannot stack." )
392+ # print(f"Warning: Key {str_flax_key} not found in random_flax_state_dict, cannot stack.")
346393 current_tensor = flax_tensor # Might fail shape check later
347394
348395 # Place the tensor at the correct index
0 commit comments