Skip to content

Commit 1c0d221

Browse files
committed
Fix transformer weight loading bug when scan_layers is false
1 parent 0a4549b commit 1c0d221

2 files changed

Lines changed: 3 additions & 6 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -144,10 +144,7 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
144144

145145
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers)
146146

147-
# Transpose back caption projections for LTX-2.3 as they are already in JAX format or shouldn't be transposed
148-
if ("caption_projection" in flax_key or "audio_caption_projection" in flax_key) and "timestep_embedder" not in flax_key:
149-
if "kernel" in flax_key and flax_tensor.ndim == 2:
150-
flax_tensor = flax_tensor.T
147+
151148

152149
flax_key_str = [str(k) for k in flax_key]
153150

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -124,8 +124,8 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
124124
"audio_attention_head_dim": 64,
125125
"audio_cross_attention_dim": 4096,
126126
"num_layers": 48,
127-
"caption_channels": 8192,
128-
"audio_caption_channels": 4096,
127+
"caption_channels": 4096,
128+
"audio_caption_channels": 2048,
129129
}
130130
else:
131131
ltx2_config = LTX2VideoTransformer3DModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder)

0 commit comments

Comments
 (0)