Skip to content

Commit b02383b

Browse files
committed
attention head and audio cross attn head dim change
1 parent a4c45d6 commit b02383b

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
145145
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers)
146146

147147
# Transpose back caption projections for LTX-2.3 as they are already in JAX format or shouldn't be transposed
148-
if "caption_projection" in flax_key or "audio_caption_projection" in flax_key:
148+
if ("caption_projection" in flax_key or "audio_caption_projection" in flax_key) and "timestep_embedder" not in flax_key:
149149
if "kernel" in flax_key and flax_tensor.ndim == 2:
150150
flax_tensor = flax_tensor.T
151151

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,11 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
118118
"in_channels": 128,
119119
"num_attention_heads": 32,
120120
"attention_head_dim": 128,
121-
"cross_attention_dim": 4096,
121+
"cross_attention_dim": 8192,
122122
"audio_in_channels": 128,
123123
"audio_num_attention_heads": 32,
124124
"audio_attention_head_dim": 64,
125-
"audio_cross_attention_dim": 2048,
125+
"audio_cross_attention_dim": 4096,
126126
"num_layers": 48,
127127
"caption_channels": 8192,
128128
"audio_caption_channels": 4096,

0 commit comments

Comments (0)