Skip to content

Commit d7d7424

Browse files
committed
debug_audio_vae
1 parent 6d6e227 commit d7d7424

1 file changed

Lines changed: 52 additions & 6 deletions

File tree

debug_audio_vae.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -212,20 +212,66 @@ def flatten(d, parent_key=()):
212212
flax_keys_set = set(flax_keys)
213213
missing = flax_keys_set - final_keys
214214

215-
# Filter stats
215+
# Filter stats logic check
216+
print("\nDebugging Filtering Logic...")
216217
filtered_missing = []
218+
skipped_count = 0
217219
for k in missing:
218220
k_str = [str(x) for x in k]
219-
if "dropout" in k_str or "rngs" in k_str:
221+
is_stat = False
222+
for ks in k_str:
223+
if "dropout" in ks or "rngs" in ks:
224+
is_stat = True
225+
break
226+
if is_stat:
227+
skipped_count += 1
220228
continue
221229
filtered_missing.append(k)
222230

223-
print(f"Missing Keys (Count: {len(filtered_missing)}):")
224-
for k in sorted(filtered_missing)[:20]:
231+
print(f"Skipped {skipped_count} keys due to dropout/rngs filtering.")
232+
print(f"Remaining Missing Keys (Count: {len(filtered_missing)}):")
233+
for k in sorted(filtered_missing):
225234
print(k)
226235

227-
print("\nExtra Keys (Count: {len(final_keys - flax_keys_set)}):")
228-
for k in sorted(list(final_keys - flax_keys_set))[:20]:
236+
# Also check if validation function itself is behaving as expected
237+
from flax.traverse_util import unflatten_dict, flatten_dict
238+
from maxdiffusion.modeling_flax_pytorch_utils import validate_flax_state_dict
239+
240+
# Construct a dummy flax_state_dict with only the keys we found
241+
# We need to map our final_keys back to a dict
242+
# This is hard because we don't have the values here.
243+
# But we can check if the filtering removes the keys from eval_shapes
244+
245+
print("\nChecking if eval_shapes still has dropout keys after filtering:")
246+
filtered_eval_shapes = {}
247+
for k, v in eval_shapes.items(): # eval_shapes is already flattened if from to_pure_dict()?
248+
# Wait, to_pure_dict returns a nested dict or flat?
249+
# nnx.state(model).to_pure_dict() returns a nested dict structure usually compatible with unflatten_dict?
250+
# Let's check type of eval_shapes
251+
pass
252+
253+
# flatten_dict(eval_shapes)
254+
flat_eval = flatten_dict(eval_shapes)
255+
filtered_flat = {}
256+
for k, v in flat_eval.items():
257+
k_str = [str(x) for x in k]
258+
is_stat = False
259+
for ks in k_str:
260+
if "dropout" in ks or "rngs" in ks:
261+
is_stat = True
262+
break
263+
if is_stat:
264+
continue
265+
filtered_flat[k] = v
266+
267+
# Now check if the missing keys are in filtered_flat
268+
print(f"Filtered Flat Eval Shapes Count: {len(filtered_flat)}")
269+
print(f"Original Flat Eval Shapes Count: {len(flat_eval)}")
270+
271+
# Check if 'rngs' keys are in filtered_flat
272+
rngs_keys = [k for k in filtered_flat.keys() if "rngs" in str(k) or "dropout" in str(k)]
273+
print(f"Keys with 'rngs' or 'dropout' remaining in filtered dict: {len(rngs_keys)}")
274+
for k in rngs_keys[:10]:
229275
print(k)
230276

231277
if __name__ == "__main__":

0 commit comments

Comments
 (0)