Skip to content

Commit 5fd6196

Browse files
committed
fix
1 parent c30e0ec commit 5fd6196

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/maxdiffusion/tests/test_ltx2_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ def test_load_vae_weights(self):
122122
continue
123123
filtered_eval_shapes[k] = v
124124

125-
validate_flax_state_dict(filtered_eval_shapes, flatten_dict(loaded_weights))
125+
from flax.traverse_util import unflatten_dict
126+
validate_flax_state_dict(unflatten_dict(filtered_eval_shapes), flatten_dict(loaded_weights))
126127
print("VAE Weights Validated Successfully!")
127128

128129
if __name__ == "__main__":

0 commit comments

Comments
 (0)