Skip to content

Commit d79b6dd

Browse files
committed
transformer weights
1 parent 326e6be commit d79b6dd

1 file changed

Lines changed: 6 additions & 10 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_3_utils.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ def rename_for_ltx2_3_transformer(key):
6060
# Handle substrings before they are replaced by shorter patterns below
6161
key = key.replace("audio_prompt_adaln_single", "audio_prompt_adaln")
6262
key = key.replace("prompt_adaln_single", "prompt_adaln")
63-
key = key.replace("audio_prompt_scale_shift_table", "audio_scale_shift_table")
64-
key = key.replace("prompt_scale_shift_table", "scale_shift_table")
63+
# key = key.replace("audio_prompt_scale_shift_table", "audio_scale_shift_table")
64+
# key = key.replace("prompt_scale_shift_table", "scale_shift_table")
6565

66-
if "prompt_adaln" in key:
67-
key = key.replace("prompt_adaln", "caption_projection")
68-
if "audio_prompt_adaln" in key:
69-
key = key.replace("audio_prompt_adaln", "audio_caption_projection")
66+
# if "prompt_adaln" in key:
67+
# key = key.replace("prompt_adaln", "caption_projection")
68+
# if "audio_prompt_adaln" in key:
69+
# key = key.replace("audio_prompt_adaln", "audio_caption_projection")
7070
if "video_text_proj_in" in key:
7171
key = key.replace("video_text_proj_in", "feature_extractor.video_linear")
7272
if "audio_text_proj_in" in key:
@@ -398,8 +398,6 @@ def load_vocoder_weights_2_3(
398398

399399
flax_state_dict[flax_key] = jax.device_put(tensor, device=cpu)
400400

401-
print(f"DEBUG Vocoder eval_shapes keys: {list(flatten_dict(eval_shapes).keys())[:20]}")
402-
print(f"DEBUG Vocoder flax_state_dict keys: {list(flax_state_dict.keys())[:20]}")
403401
validate_flax_state_dict(eval_shapes, flax_state_dict)
404402
return unflatten_dict(flax_state_dict)
405403

@@ -481,8 +479,6 @@ def load_connectors_weights_2_3(
481479
stacked_tensor = jnp.stack(sorted_tensors, axis=0)
482480
flax_state_dict[base_key] = jax.device_put(stacked_tensor, device=cpu)
483481

484-
print(f"DEBUG Connectors eval_shapes keys: {list(flattened_eval.keys())[:20]}")
485-
print(f"DEBUG Connectors flax_state_dict keys: {list(flax_state_dict.keys())[:20]}")
486482
filtered_eval_shapes = {
487483
k: v for k, v in flattened_eval.items() if not any("dropout" in str(x) or "rngs" in str(x) for x in k)
488484
}

0 commit comments

Comments
 (0)