Skip to content

Commit af91979

Browse files
committed
fix
1 parent 7c881e5 commit af91979

2 files changed

Lines changed: 21 additions & 4 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,17 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
112112
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers)
113113

114114
# Check if we got 'kernel' but expected 'scale' (common for scanned layers where shape check fails)
115+
# Also check 'weight' because rename_key might not have converted it to kernel if it wasn't a known Linear
115116
flax_key_str = [str(k) for k in flax_key]
116117

117-
if flax_key_str[-1] == "kernel":
118+
if flax_key_str[-1] in ["kernel", "weight"]:
118119
# Try replacing with scale and check if it exists in random_flax_state_dict
119120
temp_key_str = flax_key_str[:-1] + ["scale"]
120121
temp_key = tuple(temp_key_str) # Tuple of strings
121122

122123
if temp_key in random_flax_state_dict:
123124
flax_key_str = temp_key_str
125+
pass
124126

125127
# RESTORE LTX-2 specific keys that rename_key_and_reshape_tensor incorrectly maps to standard Flax names
126128
# Fix scale_shift_table mapping if it got 'kernel' appended
@@ -371,8 +373,15 @@ def load_vae_weights(
371373
# _tuple_str_to_int might not be needed if we already injected ints, but it's safe
372374
flax_key = _tuple_str_to_int(flax_key)
373375

374-
if flax_key == ("latents_mean",) or flax_key == ("latents_std",):
375-
continue # Skip stats
376+
flax_key = tuple(flax_key_str)
377+
flax_key = _tuple_str_to_int(flax_key)
378+
379+
# Allow latents_mean/std
380+
381+
# DEBUG
382+
if "conv" in flax_key_str or "bias" in flax_key_str:
383+
# print(f"DEBUG: VAE Key Map: {pt_tuple_key} -> {flax_key}")
384+
pass
376385

377386
if resnet_index is not None:
378387
if flax_key in flax_state_dict:

src/maxdiffusion/tests/test_ltx2_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,15 @@ def test_load_vae_weights(self):
134134
)
135135

136136
print("Validating VAE Weights...")
137-
validate_flax_state_dict(eval_shapes, loaded_weights)
137+
# Filter out dropout/rngs keys from eval_shapes as they are not expected in weights
138+
filtered_eval_shapes = {}
139+
for k, v in eval_shapes.items():
140+
k_str = [str(x) for x in k]
141+
if "dropout" in k_str or "rngs" in k_str:
142+
continue
143+
filtered_eval_shapes[k] = v
144+
145+
validate_flax_state_dict(filtered_eval_shapes, loaded_weights)
138146
print("VAE Weights Validated Successfully!")
139147

140148
if __name__ == "__main__":

0 commit comments

Comments
 (0)