Skip to content

Commit 406aadd

Browse files
committed
debug in modelling_flax_pytorch_utils.py
1 parent 7056046 commit 406aadd

1 file changed

Lines changed: 12 additions & 0 deletions

File tree

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict):
3636
new_pytree: dict - a pytree that has been created from pytorch weights.
3737
"""
3838
expected_pytree = flatten_dict(expected_pytree)
39+
40+
# DEBUG PRINTS
41+
print(f"DEBUG: validate_flax_state_dict called.")
42+
print(f"DEBUG: expected_pytree keys: {len(expected_pytree)}")
43+
print(f"DEBUG: new_pytree keys: {len(new_pytree)}")
44+
45+
dropout_in_expected = [k for k in expected_pytree.keys() if "dropout" in str(k)]
46+
print(f"DEBUG: dropout keys in expected_pytree: {len(dropout_in_expected)}")
47+
48+
dropout_in_new = [k for k in new_pytree.keys() if "dropout" in str(k)]
49+
print(f"DEBUG: dropout keys in new_pytree: {len(dropout_in_new)}")
50+
3951
if len(expected_pytree.keys()) != len(new_pytree.keys()):
4052
set1 = set(expected_pytree.keys())
4153
set2 = set(new_pytree.keys())

0 commit comments

Comments
 (0)