Skip to content

Commit 91600bf

Browse files
committed
connectors debug
1 parent 58347cd commit 91600bf

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_3_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,12 @@ def rename_for_ltx2_3_vocoder(key):
9696
"model.diffusion_model.": "",
9797
"connectors.": "",
9898
"transformer_1d_blocks": "stacked_blocks",
99+
"video_embeddings_connector": "video_connector",
100+
"audio_embeddings_connector": "audio_connector",
101+
"ff.net.0.proj.weight": "ff.net_0.kernel",
102+
"ff.net.0.proj.bias": "ff.net_0.bias",
103+
"ff.net.2.weight": "ff.net_2.kernel",
104+
"ff.net.2.bias": "ff.net_2.bias",
99105
"text_embedding_projection.audio_aggregate_embed.weight": "audio_text_proj_in.kernel",
100106
"text_embedding_projection.audio_aggregate_embed.bias": "audio_text_proj_in.bias",
101107
"text_embedding_projection.video_aggregate_embed.weight": "video_text_proj_in.kernel",
@@ -392,6 +398,8 @@ def load_vocoder_weights_2_3(
392398

393399
flax_state_dict[flax_key] = jax.device_put(tensor, device=cpu)
394400

401+
print(f"DEBUG Vocoder eval_shapes keys: {list(flatten_dict(eval_shapes).keys())[:20]}")
402+
print(f"DEBUG Vocoder flax_state_dict keys: {list(flax_state_dict.keys())[:20]}")
395403
validate_flax_state_dict(eval_shapes, flax_state_dict)
396404
return unflatten_dict(flax_state_dict)
397405

0 commit comments

Comments
 (0)