@@ -578,7 +578,8 @@ def load_audio_vae_weights(
578578 # Map 0 -> 2, 1 -> 1, 2 -> 0
579579 new_stage_idx = 2 - stage_idx
580580 if "upsample" in flax_key :
581- print (f"DEBUG REVERSAL: { flax_key } -> stage_idx={ stage_idx } -> new={ new_stage_idx } " )
581+ # print(f"DEBUG REVERSAL: {flax_key} -> stage_idx={stage_idx} -> new={new_stage_idx}")
582+ pass
582583 flax_key_parts [up_stages_idx + 1 ] = new_stage_idx
583584 flax_key = tuple (flax_key_parts )
584585 except ValueError :
@@ -599,43 +600,5 @@ def load_audio_vae_weights(
599600 continue
600601 filtered_eval_shapes [k ] = v
601602
602- print (f"DEBUG: Initial eval_shapes count: { len (flattened_eval )} " )
603- print (f"DEBUG: Filtered eval_shapes count: { len (filtered_eval_shapes )} " )
604-
605- # Check if any rngs remain in filtered
606- rngs_count = 0
607- for k in filtered_eval_shapes :
608- k_str = [str (x ) for x in k ]
609- for ks in k_str :
610- if "rngs" in ks or "dropout" in ks :
611- rngs_count += 1
612- break
613- print (f"DEBUG: Remaining rngs/dropout keys in Expected: { rngs_count } " )
614-
615- # Check flax_state_dict for rngs (New)
616- rngs_new_count = 0
617- for k in flax_state_dict :
618- k_str = [str (x ) for x in k ]
619- for ks in k_str :
620- if "rngs" in ks or "dropout" in ks :
621- rngs_new_count += 1
622- break
623- print (f"DEBUG: rngs/dropout keys in New (loaded): { rngs_new_count } " )
624-
625- # Explicit Set Diffs
626- expected_keys = set (filtered_eval_shapes .keys ())
627- new_keys = set (flax_state_dict .keys ())
628-
629- missing_keys = expected_keys - new_keys
630- extra_keys = new_keys - expected_keys
631-
632- print (f"DEBUG: Truly Missing Keys (in Expected but not New): { len (missing_keys )} " )
633- if len (missing_keys ) > 0 :
634- print (f"DEBUG: Sample Missing: { list (missing_keys )[:5 ]} " )
635-
636- print (f"DEBUG: Truly Extra Keys (in New but not Expected): { len (extra_keys )} " )
637- if len (extra_keys ) > 0 :
638- print (f"DEBUG: Sample Extra: { list (extra_keys )[:5 ]} " )
639-
640603 validate_flax_state_dict (unflatten_dict (filtered_eval_shapes ), flax_state_dict )
641604 return unflatten_dict (flax_state_dict )
0 commit comments