
Commit 4f375aa

fix
1 parent 403e710 commit 4f375aa

2 files changed: 5 additions & 55 deletions


src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 3 additions & 53 deletions
@@ -115,10 +115,6 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict
   # Also check 'weight' because rename_key might not have converted it to kernel if it wasn't a known Linear
   flax_key_str = [str(k) for k in flax_key]

-  # DEBUG: Check specific keys
-  if "norm_k" in flax_key_str or "audio_caption_projection" in flax_key_str:
-    print(f"DEBUG: get_key_and_value mapping: {pt_tuple_key} -> {flax_key_str}")
-
   if flax_key_str[-1] in ["kernel", "weight"]:
     # Try replacing with scale and check if it exists in random_flax_state_dict
     temp_key_str = flax_key_str[:-1] + ["scale"]
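Note: the logic surviving this hunk is the `kernel`/`weight` to `scale` fallback. PyTorch stores a norm layer's gain as `weight`, while Flax norm modules call it `scale`, so a key that misses under its literal name is retried with the last component swapped. A minimal, self-contained sketch of the idea (the helper name and toy key below are hypothetical; the real function also handles transposes and other renames):

```python
# Hypothetical standalone sketch of the kernel/weight -> scale fallback
# kept in get_key_and_value above.
def resolve_flax_key(flax_key_str, random_flax_state_dict):
  # PyTorch names a norm gain `weight`; Flax norm layers name it `scale`.
  # If the literal key is absent from the expected-shape dict but the
  # `scale` variant exists, reroute the tensor there.
  if flax_key_str[-1] in ["kernel", "weight"]:
    temp_key_str = flax_key_str[:-1] + ["scale"]
    if tuple(temp_key_str) in random_flax_state_dict:
      return tuple(temp_key_str)
  return tuple(flax_key_str)

expected = {("norm_k", "scale"): (128,)}  # toy expected-shape dict
print(resolve_flax_key(["norm_k", "weight"], expected))  # ('norm_k', 'scale')
```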
@@ -298,47 +294,10 @@ def load_transformer_weights(
     string_tuple = tuple([str(item) for item in key])
     random_flax_state_dict[string_tuple] = flattened_dict[key]

-  # DEBUG: Print keys to understand mapping
-  print("DEBUG: Top 20 keys from Checkpoint (tensors):")
-  for k in list(tensors.keys())[:20]:
-    print(k)
-
-  print("DEBUG: NON-BLOCK keys in Checkpoint:")
-  for k in tensors.keys():
-    if "transformer_blocks" not in k:
-      print(k)
-
-  print("\nDEBUG: Top 20 keys from Flax Model (eval_shapes):")
-  for k in list(random_flax_state_dict.keys())[:20]:
-    print(k)
-
-  print("\nDEBUG: Transformer Block 0 keys from Checkpoint:")
-  found_block_0 = False
-  for k in tensors.keys():
-    if "transformer_blocks.0." in k or "transformer_blocks_0." in k:
-      print(k)
-      found_block_0 = True
-
-  if not found_block_0:
-    # Try looking for any block
-    for k in tensors.keys():
-      if "transformer_blocks" in k:
-        print(f"Sample block key: {k}")
-        break
-
-  print("\nDEBUG: Global Norm/LN candidates in Checkpoint:")
-  for k in tensors.keys():
-    if "norm" in k.lower() or "ln" in k.lower():
-      if "transformer_blocks" not in k:
-        print(k)
+  for key in flattened_dict:
+    string_tuple = tuple([str(item) for item in key])
+    random_flax_state_dict[string_tuple] = flattened_dict[key]

-  print("\nDEBUG: Transformer Block keys from Flax Model (eval_shapes):")
-  for k in list(random_flax_state_dict.keys()):
-    k_str = str(k)
-    if "transformer_blocks" in k_str and ("attn1" in k_str or "ff" in k_str):
-      print(f"EVAL_SHAPE: {k}")
-  pass
-
   for pt_key, tensor in tensors.items():
     renamed_pt_key = rename_key(pt_key)
     renamed_pt_key = rename_for_ltx2_transformer(renamed_pt_key)
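Note: the three lines added here rebuild `random_flax_state_dict` by flattening the nested eval-shape tree and keying it with tuples of strings, which is what the renamed checkpoint keys are matched against. A rough sketch with a toy nested dict, assuming `flax.traverse_util.flatten_dict` as the flattener (the real `flattened_dict` comes from the model's eval shapes):

```python
# Rough sketch: index a nested Flax shape tree by string tuples so that
# renamed checkpoint keys can be looked up directly, as the added loop does.
from flax.traverse_util import flatten_dict

eval_shapes = {  # toy stand-in for the model's eval shapes
    "transformer_blocks_0": {"attn1": {"to_q": {"kernel": (64, 64)}}},
}

flattened_dict = flatten_dict(eval_shapes)
random_flax_state_dict = {}
for key in flattened_dict:
  string_tuple = tuple([str(item) for item in key])
  random_flax_state_dict[string_tuple] = flattened_dict[key]

print(random_flax_state_dict)
# {('transformer_blocks_0', 'attn1', 'to_q', 'kernel'): (64, 64)}
```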
@@ -392,15 +351,6 @@ def load_vae_weights(
   cpu = jax.local_devices(backend="cpu")[0]
   flattened_eval = flatten_dict(eval_shapes)

-  # DEBUG: Print keys to understand mapping
-  print("DEBUG: Top 20 keys from VAE Checkpoint (tensors):")
-  for k in list(tensors.keys())[:20]:
-    print(k)
-
-  flax_state_dict = {}
-  cpu = jax.local_devices(backend="cpu")[0]
-  flattened_eval = flatten_dict(eval_shapes)
-
   random_flax_state_dict = {}
   for key in flattened_eval:
     string_tuple = tuple([str(item) for item in key])
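Note: besides the debug prints, this hunk drops re-declarations of `flax_state_dict`, `cpu`, and `flattened_eval` that redid work done a few lines earlier. The surviving `cpu = jax.local_devices(backend="cpu")[0]` follows the usual pattern of keeping converted checkpoint tensors in host memory during loading, presumably so the full checkpoint never has to fit on an accelerator. A sketch of that placement pattern with a made-up tensor:

```python
# Sketch of the host-placement pattern used during weight loading: put
# each converted checkpoint tensor on the CPU device, not the accelerator.
import numpy as np
import jax

cpu = jax.local_devices(backend="cpu")[0]
tensor = np.ones((4, 4), dtype=np.float32)  # stand-in for a checkpoint tensor
on_host = jax.device_put(tensor, device=cpu)  # lives in host RAM, not HBM
print(on_host.shape, on_host.dtype)  # (4, 4) float32
```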

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 2 additions & 2 deletions
@@ -807,7 +807,7 @@ def init_block(rngs):
     # 6. Output layers
     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
     self.norm_out = nnx.LayerNorm(
-      inner_dim, epsilon=1e-6, use_scale=False, use_bias=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
+      inner_dim, epsilon=1e-6, use_scale=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
     )
     self.proj_out = nnx.Linear(
       inner_dim,
@@ -820,7 +820,7 @@ def init_block(rngs):
     )

     self.audio_norm_out = nnx.LayerNorm(
-      audio_inner_dim, epsilon=1e-6, use_scale=False, use_bias=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
+      audio_inner_dim, epsilon=1e-6, use_scale=False, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
     )
     self.audio_proj_out = nnx.Linear(
       audio_inner_dim,
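Note: dropping `use_bias=False` means the two output LayerNorms fall back to Flax NNX's default `use_bias=True`: they still learn no scale, but now create a bias parameter of shape `(inner_dim,)` and `(audio_inner_dim,)` respectively. A toy-sized sketch of which parameters `nnx.LayerNorm` creates under these flags:

```python
# Toy-size sketch: with use_scale=False and the default use_bias=True,
# nnx.LayerNorm learns a bias but no scale.
import jax.numpy as jnp
from flax import nnx

ln = nnx.LayerNorm(
    8, epsilon=1e-6, use_scale=False, rngs=nnx.Rngs(0), dtype=jnp.float32, param_dtype=jnp.float32
)
print(ln.bias.value.shape)  # (8,): bias parameter now created by this change
print(ln.scale)             # no learned scale, since use_scale=False

y = ln(jnp.ones((2, 8)))  # normalizes over the last axis, then adds the bias
```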
