Skip to content

Commit 80e5b52

Browse files
committed
fix
1 parent a100826 commit 80e5b52

1 file changed

Lines changed: 56 additions & 9 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 56 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -109,9 +109,30 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
109109

110110
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers)
111111

112-
# RESTORE LTX-2 specific keys that rename_key_and_reshape_tensor incorrectly maps to standard Flax names
112+
# Check if we got 'kernel' but expected 'scale' (common for scanned layers where shape check fails)
113113
flax_key_str = [str(k) for k in flax_key]
114114

115+
if flax_key_str[-1] == "kernel":
116+
# Try replacing with scale and check if it exists in random_flax_state_dict
117+
temp_key_str = flax_key_str[:-1] + ["scale"]
118+
temp_key = tuple(temp_key_str) # Tuple of strings
119+
120+
# random_flax_state_dict keys are tuples of STRINGS
121+
if temp_key in random_flax_state_dict:
122+
flax_key_str = temp_key_str
123+
# When mapping weight -> scale, the tensor should be 1D.
124+
# Linear weights are 2D (and get transposed); scale weights are 1D.
125+
# Does rename_key_and_reshape_tensor keep a 1D input tensor as 1D?
126+
# If it treated the key as Linear, it may have transposed the tensor (when 2D).
127+
# But for an originally 1D 'weight' (like LayerNorm), rename_key_and_reshape_tensor's Linear logic
128+
# checks `if pt_tuple_key[-1] == "weight"`
129+
# and then applies `pt_tensor = pt_tensor.T`.
130+
# For a 1D tensor, .T leaves it unchanged, so this is harmless.
131+
pass
132+
133+
# RESTORE LTX-2 specific keys that rename_key_and_reshape_tensor incorrectly maps to standard Flax names
134+
# flax_key_str = [str(k) for k in flax_key] # Already have it
135+
115136
# Fix scale_shift_table mapping if it got 'kernel' appended
116137
if "scale_shift_table" in flax_key_str:
117138
# if last is kernel/weight, remove it
@@ -217,10 +238,25 @@ def load_transformer_weights(
217238
cpu = jax.local_devices(backend="cpu")[0]
218239
flattened_dict = flatten_dict(eval_shapes)
219240

220-
random_flax_state_dict = {}
221-
for key in flattened_dict:
222-
string_tuple = tuple([str(item) for item in key])
223-
random_flax_state_dict[string_tuple] = flattened_dict[key]
241+
# DEBUG: Print keys to understand mapping
242+
print("DEBUG: Top 20 keys from Checkpoint (tensors):")
243+
for k in list(tensors.keys())[:20]:
244+
print(k)
245+
246+
print("DEBUG: NON-BLOCK keys in Checkpoint:")
247+
for k in tensors.keys():
248+
if "transformer_blocks" not in k:
249+
print(k)
250+
251+
print("\nDEBUG: Top 20 keys from Flax Model (eval_shapes):")
252+
for k in list(random_flax_state_dict.keys())[:20]:
253+
print(k)
254+
255+
print("\nDEBUG: Transformer Block keys from Flax Model (eval_shapes):")
256+
for k in list(random_flax_state_dict.keys()):
257+
k_str = str(k)
258+
if "transformer_blocks" in k_str and ("attn1" in k_str or "ff" in k_str):
259+
print(f"EVAL_SHAPE: {k}")
224260

225261
for pt_key, tensor in tensors.items():
226262
renamed_pt_key = rename_key(pt_key)
@@ -275,6 +311,15 @@ def load_vae_weights(
275311
cpu = jax.local_devices(backend="cpu")[0]
276312
flattened_eval = flatten_dict(eval_shapes)
277313

314+
# DEBUG: Print keys to understand mapping
315+
print("DEBUG: Top 20 keys from VAE Checkpoint (tensors):")
316+
for k in list(tensors.keys())[:20]:
317+
print(k)
318+
319+
flax_state_dict = {}
320+
cpu = jax.local_devices(backend="cpu")[0]
321+
flattened_eval = flatten_dict(eval_shapes)
322+
278323
random_flax_state_dict = {}
279324
for key in flattened_eval:
280325
string_tuple = tuple([str(item) for item in key])
@@ -302,17 +347,19 @@ def load_vae_weights(
302347
pt_list.append(str(idx))
303348
else:
304349
pt_list.append(part)
350+
elif part == "upsampler":
351+
pt_list.append("upsamplers")
352+
pt_list.append("0")
305353
elif part in ["conv1", "conv2", "conv"]:
306354
pt_list.append(part)
307355
# Inject 'conv' if it's not already there AND not just added
308356
if i + 1 < len(pt_tuple_key) and pt_tuple_key[i+1] == "conv":
309357
pass # already has conv
310358
elif pt_list[-1] == "conv":
311359
pass # already has conv
360+
elif len(pt_list) >= 2 and pt_list[-2] == "conv":
361+
pass
312362
elif part == "conv":
313-
# It IS conv, so we appended it. Do we need another one?
314-
# If part is 'conv', we appended it.
315-
# The original logic skipped it. We kept it.
316363
pass
317364
else:
318365
# If part is conv1/conv2, append 'conv'
@@ -342,7 +389,7 @@ def load_vae_weights(
342389
current_tensor = jnp.zeros(target_shape, dtype=flax_tensor.dtype)
343390
else:
344391
# Fallback if key missing (shouldn't happen with correct mapping)
345-
print(f"Warning: Key {str_flax_key} not found in random_flax_state_dict, cannot stack.")
392+
# print(f"Warning: Key {str_flax_key} not found in random_flax_state_dict, cannot stack.")
346393
current_tensor = flax_tensor # Might fail shape check later
347394

348395
# Place the tensor at the correct index

0 commit comments

Comments
 (0)