@@ -602,16 +602,40 @@ def load_audio_vae_weights(
602602 print (f"DEBUG: Initial eval_shapes count: { len (flattened_eval )} " )
603603 print (f"DEBUG: Filtered eval_shapes count: { len (filtered_eval_shapes )} " )
604604
605- # Check if any rngs remain
605+ # Check if any rngs remain in filtered
606606 rngs_count = 0
607607 for k in filtered_eval_shapes :
608608 k_str = [str (x ) for x in k ]
609609 for ks in k_str :
610610 if "rngs" in ks or "dropout" in ks :
611611 rngs_count += 1
612- print (f"DEBUG: Found unexpected rng/dropout key in filtered: { k } " )
613612 break
614- print (f"DEBUG: Remaining rngs/dropout keys: { rngs_count } " )
613+ print (f"DEBUG: Remaining rngs/dropout keys in Expected: { rngs_count } " )
614+
615+ # Check flax_state_dict for rngs (New)
616+ rngs_new_count = 0
617+ for k in flax_state_dict :
618+ k_str = [str (x ) for x in k ]
619+ for ks in k_str :
620+ if "rngs" in ks or "dropout" in ks :
621+ rngs_new_count += 1
622+ break
623+ print (f"DEBUG: rngs/dropout keys in New (loaded): { rngs_new_count } " )
624+
625+ # Explicit Set Diffs
626+ expected_keys = set (filtered_eval_shapes .keys ())
627+ new_keys = set (flax_state_dict .keys ())
628+
629+ missing_keys = expected_keys - new_keys
630+ extra_keys = new_keys - expected_keys
631+
632+ print (f"DEBUG: Truly Missing Keys (in Expected but not New): { len (missing_keys )} " )
633+ if len (missing_keys ) > 0 :
634+ print (f"DEBUG: Sample Missing: { list (missing_keys )[:5 ]} " )
635+
636+ print (f"DEBUG: Truly Extra Keys (in New but not Expected): { len (extra_keys )} " )
637+ if len (extra_keys ) > 0 :
638+ print (f"DEBUG: Sample Extra: { list (extra_keys )[:5 ]} " )
615639
616640 validate_flax_state_dict (unflatten_dict (filtered_eval_shapes ), flax_state_dict )
617641 return unflatten_dict (flax_state_dict )
0 commit comments