We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 5b6fe91 commit 780b004Copy full SHA for 780b004
1 file changed
src/maxdiffusion/models/ltx2/ltx2_utils.py
@@ -599,5 +599,19 @@ def load_audio_vae_weights(
599
continue
600
filtered_eval_shapes[k] = v
601
602
+ print(f"DEBUG: Initial eval_shapes count: {len(flattened_eval)}")
603
+ print(f"DEBUG: Filtered eval_shapes count: {len(filtered_eval_shapes)}")
604
+
605
+ # Check if any rngs remain
606
+ rngs_count = 0
607
+ for k in filtered_eval_shapes:
608
+ k_str = [str(x) for x in k]
609
+ for ks in k_str:
610
+ if "rngs" in ks or "dropout" in ks:
611
+ rngs_count += 1
612
+ print(f"DEBUG: Found unexpected rng/dropout key in filtered: {k}")
613
+ break
614
+ print(f"DEBUG: Remaining rngs/dropout keys: {rngs_count}")
615
616
validate_flax_state_dict(unflatten_dict(filtered_eval_shapes), flax_state_dict)
617
return unflatten_dict(flax_state_dict)
0 commit comments