Skip to content

Commit ddfa80c

Browse files
committed
fix
1 parent 1878eab commit ddfa80c

1 file changed

Lines changed: 18 additions & 4 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -171,17 +171,24 @@ def load_transformer_weights(
171171

172172
print("\nDEBUG: Transformer Block keys from Flax Model (eval_shapes):")
173173
for k in list(random_flax_state_dict.keys()):
174-
if "transformer_blocks" in k and "attn1" in k:
175-
print(k)
176-
break
174+
k_str = str(k)
175+
if "transformer_blocks" in k_str and ("attn1" in k_str or "ff" in k_str):
176+
print(f"EVAL_SHAPE: {k}")
177+
if "proj_out" in k_str or "norm_out" in k_str:
178+
print(f"EVAL_SHAPE GLOBAL: {k}")
179+
180+
# Search for norm in tensors
181+
print("\nDEBUG: Search 'norm' in checkpoint keys:")
182+
for k in tensors.keys():
183+
if "norm" in k and "transformer_blocks" not in k:
184+
print(f"CKPT norm: {k}")
177185

178186
for pt_key, tensor in tensors.items():
179187
renamed_pt_key = rename_key(pt_key)
180188
renamed_pt_key = rename_for_ltx2_transformer(renamed_pt_key)
181189

182190
# DEBUG: Check intermediate rename
183191
if "audio_ff.net.0.proj" in pt_key:
184-
# This might spam, but good to see once
185192
pass
186193

187194
pt_tuple_key = tuple(renamed_pt_key.split("."))
@@ -190,6 +197,13 @@ def load_transformer_weights(
190197
pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers
191198
)
192199

200+
# DEBUG: Trace proj_out
201+
if "proj_out" in str(flax_key) and "bias" in str(flax_key):
202+
print(f"DEBUG: Trace proj_out: {pt_key} -> {flax_key}")
203+
# Check if added to dict
204+
# It acts global so it should be added below
205+
206+
193207
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
194208

195209
validate_flax_state_dict(eval_shapes, flax_state_dict)

0 commit comments

Comments (0)