Skip to content

Commit b11bd47

Browse files
committed
fix
1 parent 406aadd commit b11bd47

3 files changed

Lines changed: 16 additions & 51 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,8 @@ def load_audio_vae_weights(
578578
# Map 0 -> 2, 1 -> 1, 2 -> 0
579579
new_stage_idx = 2 - stage_idx
580580
if "upsample" in flax_key:
581-
print(f"DEBUG REVERSAL: {flax_key} -> stage_idx={stage_idx} -> new={new_stage_idx}")
581+
# print(f"DEBUG REVERSAL: {flax_key} -> stage_idx={stage_idx} -> new={new_stage_idx}")
582+
pass
582583
flax_key_parts[up_stages_idx + 1] = new_stage_idx
583584
flax_key = tuple(flax_key_parts)
584585
except ValueError:
@@ -599,43 +600,5 @@ def load_audio_vae_weights(
599600
continue
600601
filtered_eval_shapes[k] = v
601602

602-
print(f"DEBUG: Initial eval_shapes count: {len(flattened_eval)}")
603-
print(f"DEBUG: Filtered eval_shapes count: {len(filtered_eval_shapes)}")
604-
605-
# Check if any rngs remain in filtered
606-
rngs_count = 0
607-
for k in filtered_eval_shapes:
608-
k_str = [str(x) for x in k]
609-
for ks in k_str:
610-
if "rngs" in ks or "dropout" in ks:
611-
rngs_count += 1
612-
break
613-
print(f"DEBUG: Remaining rngs/dropout keys in Expected: {rngs_count}")
614-
615-
# Check flax_state_dict for rngs (New)
616-
rngs_new_count = 0
617-
for k in flax_state_dict:
618-
k_str = [str(x) for x in k]
619-
for ks in k_str:
620-
if "rngs" in ks or "dropout" in ks:
621-
rngs_new_count += 1
622-
break
623-
print(f"DEBUG: rngs/dropout keys in New (loaded): {rngs_new_count}")
624-
625-
# Explicit Set Diffs
626-
expected_keys = set(filtered_eval_shapes.keys())
627-
new_keys = set(flax_state_dict.keys())
628-
629-
missing_keys = expected_keys - new_keys
630-
extra_keys = new_keys - expected_keys
631-
632-
print(f"DEBUG: Truly Missing Keys (in Expected but not New): {len(missing_keys)}")
633-
if len(missing_keys) > 0:
634-
print(f"DEBUG: Sample Missing: {list(missing_keys)[:5]}")
635-
636-
print(f"DEBUG: Truly Extra Keys (in New but not Expected): {len(extra_keys)}")
637-
if len(extra_keys) > 0:
638-
print(f"DEBUG: Sample Extra: {list(extra_keys)[:5]}")
639-
640603
validate_flax_state_dict(unflatten_dict(filtered_eval_shapes), flax_state_dict)
641604
return unflatten_dict(flax_state_dict)

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,6 @@ def validate_flax_state_dict(expected_pytree: dict, new_pytree: dict):
3636
new_pytree: dict - a pytree that has been created from pytorch weights.
3737
"""
3838
expected_pytree = flatten_dict(expected_pytree)
39-
40-
# DEBUG PRINTS
41-
print(f"DEBUG: validate_flax_state_dict called.")
42-
print(f"DEBUG: expected_pytree keys: {len(expected_pytree)}")
43-
print(f"DEBUG: new_pytree keys: {len(new_pytree)}")
44-
45-
dropout_in_expected = [k for k in expected_pytree.keys() if "dropout" in str(k)]
46-
print(f"DEBUG: dropout keys in expected_pytree: {len(dropout_in_expected)}")
47-
48-
dropout_in_new = [k for k in new_pytree.keys() if "dropout" in str(k)]
49-
print(f"DEBUG: dropout keys in new_pytree: {len(dropout_in_new)}")
5039

5140
if len(expected_pytree.keys()) != len(new_pytree.keys()):
5241
set1 = set(expected_pytree.keys())

src/maxdiffusion/tests/test_ltx2_utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,20 @@ def test_load_audio_vae_weights(self):
223223
)
224224

225225
print("Validating Audio VAE Weights...")
226-
validate_flax_state_dict(eval_shapes, flatten_dict(loaded_weights))
226+
# Filter eval_shapes for validation as load_audio_vae_weights returns filtered weights
227+
filtered_eval_shapes = {}
228+
flat_eval = flatten_dict(eval_shapes)
229+
for k, v in flat_eval.items():
230+
k_str = [str(x) for x in k]
231+
is_stat = False
232+
for ks in k_str:
233+
if "dropout" in ks or "rngs" in ks:
234+
is_stat = True
235+
break
236+
if not is_stat:
237+
filtered_eval_shapes[k] = v
238+
239+
validate_flax_state_dict(unflatten_dict(filtered_eval_shapes), flatten_dict(loaded_weights))
227240
print("Audio VAE Weights Validated Successfully!")
228241

229242
if __name__ == "__main__":

0 commit comments

Comments
 (0)