Skip to content

Commit 887eb8b

Browse files
committed
fix
1 parent 5fd6196 commit 887eb8b

1 file changed

Lines changed: 13 additions & 10 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -266,16 +266,9 @@ def load_vae_weights(
266266
flattened_eval = flatten_dict(eval_shapes)
267267

268268
random_flax_state_dict = {}
269-
print(f"DEBUG: eval_shapes length (flattened): {len(flattened_eval)}")
270-
sample_target_found = False
271269
for key in flattened_eval:
272270
string_tuple = tuple([str(item) for item in key])
273271
random_flax_state_dict[string_tuple] = flattened_eval[key]
274-
275-
if not sample_target_found and "decoder" in string_tuple and "up_blocks" in string_tuple and "0" in string_tuple and "resnets" in string_tuple and "2" in string_tuple and "conv2" in string_tuple:
276-
print(f"DEBUG: Found target key in eval_shapes: {key}")
277-
print(f"DEBUG: Key types: {[type(x) for x in key]}")
278-
sample_target_found = True
279272

280273
for pt_key, tensor in tensors.items():
281274
renamed_pt_key = rename_key(pt_key)
@@ -347,11 +340,21 @@ def load_vae_weights(
347340
else:
348341
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
349342

350-
if "decoder" in flax_key_str and "up_blocks" in flax_key_str and "0" in flax_key_str and "resnets" in flax_key_str and "2" in flax_key_str and "conv2" in flax_key_str:
351-
print(f"DEBUG: Processing target key. Final flax_key: {flax_key}")
343+
344+
345+
# Filter out non-parameter keys for validation
346+
filtered_eval_shapes = {}
347+
for k, v in flattened_eval.items():
348+
# flax key is a tuple of strings/ints
349+
k_str = [str(x) for x in k]
350+
if "dropout" in k_str or "rngs" in k_str:
351+
continue
352+
filtered_eval_shapes[k] = v
352353

353354
print(f"Total VAE keys loaded: {len(flax_state_dict)}")
354-
validate_flax_state_dict(eval_shapes, flax_state_dict)
355+
356+
# Unflatten to pass to validate_flax_state_dict which expects a pytree
357+
validate_flax_state_dict(unflatten_dict(filtered_eval_shapes), flax_state_dict)
355358
flax_state_dict = unflatten_dict(flax_state_dict)
356359
del tensors
357360
jax.clear_caches()

0 commit comments

Comments
 (0)