File tree Expand file tree Collapse file tree
src/maxdiffusion/models/ltx2 Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -243,6 +243,8 @@ def load_transformer_weights(
243243 if not pt_key .startswith ("model.diffusion_model." ):
244244 continue
245245 pt_key = pt_key .replace ("model.diffusion_model." , "" )
246+ if pt_key .startswith ("audio_embeddings_connector" ) or pt_key .startswith ("video_embeddings_connector" ):
247+ continue
246248
247249 renamed_pt_key = rename_key (pt_key )
248250 renamed_pt_key = rename_for_ltx2_transformer (renamed_pt_key )
Original file line number Diff line number Diff line change @@ -707,18 +707,19 @@ def __init__(
707707 weights_dtype = self .weights_dtype ,
708708 )
709709 # 3. Timestep Modulation Params and Embedding
710+ num_mod_params = 9 if self .cross_attn_mod else 6
710711 self .time_embed = LTX2AdaLayerNormSingle (
711712 rngs = rngs ,
712713 embedding_dim = inner_dim ,
713- num_mod_params = 6 ,
714+ num_mod_params = num_mod_params ,
714715 use_additional_conditions = False ,
715716 dtype = self .dtype ,
716717 weights_dtype = self .weights_dtype ,
717718 )
718719 self .audio_time_embed = LTX2AdaLayerNormSingle (
719720 rngs = rngs ,
720721 embedding_dim = audio_inner_dim ,
721- num_mod_params = 6 ,
722+ num_mod_params = num_mod_params ,
722723 use_additional_conditions = False ,
723724 dtype = self .dtype ,
724725 weights_dtype = self .weights_dtype ,
You can’t perform that action at this time.
0 commit comments