Skip to content

Commit a100826

Browse files
committed
fix
1 parent 359daa8 commit a100826

1 file changed

Lines changed: 7 additions & 76 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 7 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -144,16 +144,6 @@ def replace_suffix(lst, old, new):
144144
flax_key = tuple(flax_key_str)
145145
flax_key = _tuple_str_to_int(flax_key)
146146

147-
if "scale_shift_table" in str(flax_key):
148-
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (scale_shift_table)")
149-
150-
if "audio_caption_projection" in str(flax_key):
151-
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (audio_caption_projection)")
152-
if "audio_time_embed" in str(flax_key):
153-
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (audio_time_embed)")
154-
155-
return flax_key, flax_tensor
156-
157147
if scan_layers and block_index is not None:
158148
if "transformer_blocks" in flax_key:
159149
if flax_key in flax_state_dict:
@@ -165,23 +155,6 @@ def replace_suffix(lst, old, new):
165155
new_tensor = new_tensor.at[block_index].set(flax_tensor)
166156
flax_tensor = new_tensor
167157

168-
# DEBUG TRACE
169-
if "audio_ff" in str(flax_key) and "kernel" in str(flax_key) and block_index == 18:
170-
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (Block 18)")
171-
if "to_out" in str(flax_key) and "kernel" in str(flax_key) and block_index == 18 and "attn1" in str(flax_key):
172-
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (Block 18 attn1)")
173-
174-
if "proj_in" in str(flax_key):
175-
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (proj_in)")
176-
177-
if "scale_shift_table" in str(flax_key):
178-
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (scale_shift_table)")
179-
180-
if "audio_caption_projection" in str(flax_key):
181-
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (audio_caption_projection)")
182-
if "audio_time_embed" in str(flax_key):
183-
print(f"DEBUG: Mapped {pt_tuple_key} -> {flax_key} (audio_time_embed)")
184-
185158
return flax_key, flax_tensor
186159

187160
def load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device):
@@ -248,57 +221,17 @@ def load_transformer_weights(
248221
for key in flattened_dict:
249222
string_tuple = tuple([str(item) for item in key])
250223
random_flax_state_dict[string_tuple] = flattened_dict[key]
251-
252-
# DEBUG: Print keys to understand mapping
253-
print("DEBUG: Top 20 keys from Checkpoint (tensors):")
254-
for k in list(tensors.keys())[:20]:
255-
print(k)
256-
257-
print("DEBUG: NON-BLOCK keys in Checkpoint:")
258-
for k in tensors.keys():
259-
if "transformer_blocks" not in k:
260-
print(k)
261-
262-
263-
print("\nDEBUG: Top 20 keys from Flax Model (eval_shapes):")
264-
for k in list(random_flax_state_dict.keys())[:20]:
265-
print(k)
266-
267-
print("\nDEBUG: Transformer Block keys from Flax Model (eval_shapes):")
268-
for k in list(random_flax_state_dict.keys()):
269-
k_str = str(k)
270-
if "transformer_blocks" in k_str and ("attn1" in k_str or "ff" in k_str):
271-
print(f"EVAL_SHAPE: {k}")
272-
if "proj_out" in k_str or "norm_out" in k_str:
273-
print(f"EVAL_SHAPE GLOBAL: {k}")
274-
275-
# Search for norm in tensors
276-
print("\nDEBUG: Search 'norm' in checkpoint keys:")
277-
for k in tensors.keys():
278-
if "norm" in k and "transformer_blocks" not in k:
279-
print(f"CKPT norm: {k}")
280224

281225
for pt_key, tensor in tensors.items():
282226
renamed_pt_key = rename_key(pt_key)
283227
renamed_pt_key = rename_for_ltx2_transformer(renamed_pt_key)
284228

285-
# DEBUG: Check intermediate rename
286-
if "audio_ff.net.0.proj" in pt_key:
287-
pass
288-
289229
pt_tuple_key = tuple(renamed_pt_key.split("."))
290230

291231
flax_key, flax_tensor = get_key_and_value(
292232
pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers
293233
)
294234

295-
# DEBUG: Trace proj_out
296-
if "proj_out" in str(flax_key) and "bias" in str(flax_key):
297-
print(f"DEBUG: Trace proj_out: {pt_key} -> {flax_key}")
298-
# Check if added to dict
299-
# It acts global so it should be added below
300-
301-
302235
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
303236

304237
validate_flax_state_dict(eval_shapes, flax_state_dict)
@@ -337,10 +270,6 @@ def load_vae_weights(
337270
loaded_state_dict = torch.load(ckpt_path, map_location="cpu")
338271
for k, v in loaded_state_dict.items():
339272
tensors[k] = torch2jax(v)
340-
341-
print("\nDEBUG: Top 20 keys from VAE Checkpoint (tensors):")
342-
for k in list(tensors.keys())[:20]:
343-
print(k)
344273

345274
flax_state_dict = {}
346275
cpu = jax.local_devices(backend="cpu")[0]
@@ -375,16 +304,18 @@ def load_vae_weights(
375304
pt_list.append(part)
376305
elif part in ["conv1", "conv2", "conv"]:
377306
pt_list.append(part)
378-
# Only inject 'conv' if it's not already there
379-
# Check if next part is 'conv'
307+
# Inject 'conv' if it's not already there AND not just added
380308
if i + 1 < len(pt_tuple_key) and pt_tuple_key[i+1] == "conv":
381309
pass # already has conv
382-
elif pt_list[-2] == "conv": # Check previous injection
383-
pass # already injected conv in previous step (if part was conv1/conv2/conv)
384-
# Also avoid injecting if part ITSELF is 'conv'
310+
elif pt_list[-1] == "conv":
311+
pass # already has conv
385312
elif part == "conv":
313+
# It IS conv, so we appended it. Do we need another one?
314+
# If part is 'conv', we appended it.
315+
# The original logic skipped it. We kept it.
386316
pass
387317
else:
318+
# If part is conv1/conv2, append 'conv'
388319
pt_list.append("conv")
389320
else:
390321
pt_list.append(part)

0 commit comments

Comments
 (0)