@@ -212,20 +212,66 @@ def flatten(d, parent_key=()):
212212 flax_keys_set = set (flax_keys )
213213 missing = flax_keys_set - final_keys
214214
215- # Filter stats
215+ # Filter stats logic check
216+ print ("\n Debugging Filtering Logic..." )
216217 filtered_missing = []
218+ skipped_count = 0
217219 for k in missing :
218220 k_str = [str (x ) for x in k ]
219- if "dropout" in k_str or "rngs" in k_str :
221+ is_stat = False
222+ for ks in k_str :
223+ if "dropout" in ks or "rngs" in ks :
224+ is_stat = True
225+ break
226+ if is_stat :
227+ skipped_count += 1
220228 continue
221229 filtered_missing .append (k )
222230
223- print (f"Missing Keys (Count: { len (filtered_missing )} ):" )
224- for k in sorted (filtered_missing )[:20 ]:
231+ print (f"Skipped { skipped_count } keys due to dropout/rngs filtering." )
232+ print (f"Remaining Missing Keys (Count: { len (filtered_missing )} ):" )
233+ for k in sorted (filtered_missing ):
225234 print (k )
226235
227- print ("\n Extra Keys (Count: {len(final_keys - flax_keys_set)}):" )
228- for k in sorted (list (final_keys - flax_keys_set ))[:20 ]:
236+ # Also check if validation function itself is behaving as expected
237+ from flax .traverse_util import unflatten_dict , flatten_dict
238+ from maxdiffusion .modeling_flax_pytorch_utils import validate_flax_state_dict
239+
240+ # Construct a dummy flax_state_dict with only the keys we found
241+ # We need to map our final_keys back to a dict
242+ # This is hard because we don't have the values here.
243+ # But we can check if the filtering removes the keys from eval_shapes
244+
245+ print ("\n Checking if eval_shapes still has dropout keys after filtering:" )
246+ filtered_eval_shapes = {}
247+ for k , v in eval_shapes .items (): # eval_shapes is already flattened if from to_pure_dict()?
248+ # Wait, to_pure_dict returns a nested dict or flat?
249+ # nnx.state(model).to_pure_dict() returns a nested dict structure usually compatible with unflatten_dict?
250+ # Let's check type of eval_shapes
251+ pass
252+
253+ # flatten_dict(eval_shapes)
254+ flat_eval = flatten_dict (eval_shapes )
255+ filtered_flat = {}
256+ for k , v in flat_eval .items ():
257+ k_str = [str (x ) for x in k ]
258+ is_stat = False
259+ for ks in k_str :
260+ if "dropout" in ks or "rngs" in ks :
261+ is_stat = True
262+ break
263+ if is_stat :
264+ continue
265+ filtered_flat [k ] = v
266+
267+ # Now check if the missing keys are in filtered_flat
268+ print (f"Filtered Flat Eval Shapes Count: { len (filtered_flat )} " )
269+ print (f"Original Flat Eval Shapes Count: { len (flat_eval )} " )
270+
271+ # Check if 'rngs' keys are in filtered_flat
272+ rngs_keys = [k for k in filtered_flat .keys () if "rngs" in str (k ) or "dropout" in str (k )]
273+ print (f"Keys with 'rngs' or 'dropout' remaining in filtered dict: { len (rngs_keys )} " )
274+ for k in rngs_keys [:10 ]:
229275 print (k )
230276
231277if __name__ == "__main__" :
0 commit comments