Skip to content

Commit c66fd56

Browse files
committed
transformer weight loading
1 parent 407b0c6 commit c66fd56

3 files changed

Lines changed: 20 additions & 14 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ def rename_for_ltx2_transformer(key):
117117
if "audio_text_proj_in" in key:
118118
key = key.replace("audio_text_proj_in", "feature_extractor.audio_linear")
119119

120+
key = key.replace("k_norm", "norm_k")
121+
key = key.replace("q_norm", "norm_q")
120122
key = key.replace("adaln_single", "time_embed")
121123
return key
122124

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,7 @@ def __init__(
590590
norm_elementwise_affine: bool = False,
591591
norm_eps: float = 1e-6,
592592
caption_channels: int = 3840,
593+
audio_caption_channels: int = None,
593594
attention_bias: bool = True,
594595
attention_out_bias: bool = True,
595596
rope_theta: float = 10000.0,
@@ -643,6 +644,7 @@ def __init__(
643644
self.norm_elementwise_affine = norm_elementwise_affine
644645
self.norm_eps = norm_eps
645646
self.caption_channels = caption_channels
647+
self.audio_caption_channels = audio_caption_channels or caption_channels
646648
self.attention_bias = attention_bias
647649
self.attention_out_bias = attention_out_bias
648650
self.rope_theta = rope_theta
@@ -703,7 +705,7 @@ def __init__(
703705
)
704706
self.audio_caption_projection = NNXCombinedTimestepTextProjEmbeddings(
705707
rngs=rngs,
706-
in_features=self.caption_channels,
708+
in_features=self.audio_caption_channels,
707709
hidden_size=audio_inner_dim,
708710
embedding_dim=audio_inner_dim,
709711
dtype=self.dtype,
@@ -719,7 +721,7 @@ def __init__(
719721
)
720722
self.audio_caption_projection = NNXPixArtAlphaTextProjection(
721723
rngs=rngs,
722-
in_features=self.caption_channels,
724+
in_features=self.audio_caption_channels,
723725
hidden_size=audio_inner_dim,
724726
dtype=self.dtype,
725727
weights_dtype=self.weights_dtype,

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -113,18 +113,20 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
113113
# 1. Load config.
114114
if restored_checkpoint:
115115
ltx2_config = restored_checkpoint["ltx2_config"]
116-
elif getattr(config, "model_name", "") == "ltx2.3":
117-
ltx2_config = {
118-
"in_channels": 128,
119-
"num_attention_heads": 32,
120-
"attention_head_dim": 128,
121-
"cross_attention_dim": 4096,
122-
"audio_in_channels": 128,
123-
"audio_num_attention_heads": 32,
124-
"audio_attention_head_dim": 64,
125-
"audio_cross_attention_dim": 2048,
126-
"num_layers": 48,
127-
}
116+
elif getattr(config, "model_name", "") == "ltx2.3":
117+
ltx2_config = {
118+
"in_channels": 128,
119+
"num_attention_heads": 32,
120+
"attention_head_dim": 128,
121+
"cross_attention_dim": 4096,
122+
"audio_in_channels": 128,
123+
"audio_num_attention_heads": 32,
124+
"audio_attention_head_dim": 64,
125+
"audio_cross_attention_dim": 2048,
126+
"num_layers": 48,
127+
"caption_channels": 8192,
128+
"audio_caption_channels": 4096,
129+
}
128130
else:
129131
ltx2_config = LTX2VideoTransformer3DModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder)
130132

0 commit comments

Comments (0)