Skip to content

Commit bcf4a4b

Browse files
committed
transformer weight
1 parent b9ac9eb commit bcf4a4b

2 files changed

Lines changed: 5 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -243,6 +243,8 @@ def load_transformer_weights(
243243
if not pt_key.startswith("model.diffusion_model."):
244244
continue
245245
pt_key = pt_key.replace("model.diffusion_model.", "")
246+
if pt_key.startswith("audio_embeddings_connector") or pt_key.startswith("video_embeddings_connector"):
247+
continue
246248

247249
renamed_pt_key = rename_key(pt_key)
248250
renamed_pt_key = rename_for_ltx2_transformer(renamed_pt_key)

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -707,18 +707,19 @@ def __init__(
707707
weights_dtype=self.weights_dtype,
708708
)
709709
# 3. Timestep Modulation Params and Embedding
710+
num_mod_params = 9 if self.cross_attn_mod else 6
710711
self.time_embed = LTX2AdaLayerNormSingle(
711712
rngs=rngs,
712713
embedding_dim=inner_dim,
713-
num_mod_params=6,
714+
num_mod_params=num_mod_params,
714715
use_additional_conditions=False,
715716
dtype=self.dtype,
716717
weights_dtype=self.weights_dtype,
717718
)
718719
self.audio_time_embed = LTX2AdaLayerNormSingle(
719720
rngs=rngs,
720721
embedding_dim=audio_inner_dim,
721-
num_mod_params=6,
722+
num_mod_params=num_mod_params,
722723
use_additional_conditions=False,
723724
dtype=self.dtype,
724725
weights_dtype=self.weights_dtype,

0 commit comments

Comments (0)