Skip to content

Commit 4646119

Browse files
committed
transformer weight loading
1 parent c66fd56 commit 4646119

1 file changed

Lines changed: 14 additions & 14 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -113,20 +113,20 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
113113
# 1. Load config.
114114
if restored_checkpoint:
115115
ltx2_config = restored_checkpoint["ltx2_config"]
116-
elif getattr(config, "model_name", "") == "ltx2.3":
117-
ltx2_config = {
118-
"in_channels": 128,
119-
"num_attention_heads": 32,
120-
"attention_head_dim": 128,
121-
"cross_attention_dim": 4096,
122-
"audio_in_channels": 128,
123-
"audio_num_attention_heads": 32,
124-
"audio_attention_head_dim": 64,
125-
"audio_cross_attention_dim": 2048,
126-
"num_layers": 48,
127-
"caption_channels": 8192,
128-
"audio_caption_channels": 4096,
129-
}
116+
elif getattr(config, "model_name", "") == "ltx2.3":
117+
ltx2_config = {
118+
"in_channels": 128,
119+
"num_attention_heads": 32,
120+
"attention_head_dim": 128,
121+
"cross_attention_dim": 4096,
122+
"audio_in_channels": 128,
123+
"audio_num_attention_heads": 32,
124+
"audio_attention_head_dim": 64,
125+
"audio_cross_attention_dim": 2048,
126+
"num_layers": 48,
127+
"caption_channels": 8192,
128+
"audio_caption_channels": 4096,
129+
}
130130
else:
131131
ltx2_config = LTX2VideoTransformer3DModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder)
132132

0 commit comments

Comments (0)