Skip to content

Commit b9ac9eb

Browse files
committed
transformer weight
1 parent f7e5102 commit b9ac9eb

2 files changed

Lines changed: 7 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def rename_for_ltx2_transformer(key):
 98  98      # Add missing mappings
 99  99      key = key.replace("av_ca_video_scale_shift_adaln_single", "av_cross_attn_video_scale_shift")
100 100      key = key.replace("av_ca_a2v_gate_adaln_single", "av_cross_attn_video_a2v_gate")
    101 +    key = key.replace("adaln_single", "time_embed")
101 102      key = key.replace("av_ca_audio_scale_shift_adaln_single", "av_cross_attn_audio_scale_shift")
102 103      key = key.replace("av_ca_v2a_gate_adaln_single", "av_cross_attn_audio_v2a_gate")
103 104      key = key.replace("scale_shift_table_a2v_ca_video", "video_a2v_cross_attn_scale_shift_table")

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,8 @@ def __init__(
612 612      qk_norm: str = "rms_norm_across_heads",
613 613      flash_block_sizes: BlockSizes = None,
614 614      flash_min_seq_length: int = 4096,
    615 +    gated_attn: bool = False,
    616 +    cross_attn_mod: bool = False,
615 617      **kwargs,
616 618    ):
617 619      self.in_channels = in_channels
@@ -658,6 +660,8 @@ def __init__(
658 660      self.names_which_can_be_offloaded = names_which_can_be_offloaded
659 661      self.scan_layers = scan_layers
660 662      self.attention_kernel = attention_kernel
    663 +    self.gated_attn = gated_attn
    664 +    self.cross_attn_mod = cross_attn_mod
661 665      self.a2v_attention_kernel = a2v_attention_kernel
662 666      self.v2a_attention_kernel = v2a_attention_kernel
663 667      self.flash_min_seq_length = flash_min_seq_length
@@ -845,6 +849,8 @@ def init_block(rngs):
845 849      norm_elementwise_affine=self.norm_elementwise_affine,
846 850      norm_eps=self.norm_eps,
847 851      rope_type=self.rope_type,
    852 +    gated_attn=self.gated_attn,
    853 +    cross_attn_mod=self.cross_attn_mod,
848 854      dtype=self.dtype,
849 855      weights_dtype=self.weights_dtype,
850 856      mesh=self.mesh,

0 commit comments

Comments (0)