@@ -101,6 +101,7 @@ def __init__(
101101 rope_type : str = "interleaved" ,
102102 video_gated_attn : bool = False ,
103103 audio_gated_attn : bool = False ,
104+ cross_attn_mod : bool = False ,
104105 dtype : jnp .dtype = jnp .float32 ,
105106 weights_dtype : jnp .dtype = jnp .float32 ,
106107 mesh : jax .sharding .Mesh = None ,
@@ -320,14 +321,17 @@ def __init__(
320321 )
321322
322323 key = rngs .params ()
323- k1 , k2 , k3 , k4 = jax .random .split (key , 4 )
324+ k1 , k2 , k3 , k4 , k5 , k6 = jax .random .split (key , 6 )
324325
326+ # LTX 2.3 (cross_attn_mod=True) uses 9 scale/shift parameters per block; LTX2 uses 6.
327+ num_block_scale_shift_params = 9 if cross_attn_mod else 6
328+
325329 self .scale_shift_table = nnx .Param (
326- jax .random .normal (k1 , (6 , self .dim ), dtype = weights_dtype ) / jnp .sqrt (self .dim ),
330+ jax .random .normal (k1 , (num_block_scale_shift_params , self .dim ), dtype = weights_dtype ) / jnp .sqrt (self .dim ),
327331 kernel_init = nnx .with_partitioning (nnx .initializers .zeros , (None , "embed" )),
328332 )
329333 self .audio_scale_shift_table = nnx .Param (
330- jax .random .normal (k2 , (6 , audio_dim ), dtype = weights_dtype ) / jnp .sqrt (audio_dim ),
334+ jax .random .normal (k2 , (num_block_scale_shift_params , audio_dim ), dtype = weights_dtype ) / jnp .sqrt (audio_dim ),
331335 kernel_init = nnx .with_partitioning (nnx .initializers .zeros , (None , "embed" )),
332336 )
333337 self .video_a2v_cross_attn_scale_shift_table = nnx .Param (
@@ -339,6 +343,16 @@ def __init__(
339343 kernel_init = nnx .with_partitioning (nnx .initializers .zeros , (None , "embed" )),
340344 )
341345
346+ if cross_attn_mod :
347+ self .prompt_scale_shift_table = nnx .Param (
348+ jax .random .normal (k5 , (2 , self .dim ), dtype = weights_dtype ) / jnp .sqrt (self .dim ),
349+ kernel_init = nnx .with_partitioning (nnx .initializers .zeros , (None , "embed" )),
350+ )
351+ self .audio_prompt_scale_shift_table = nnx .Param (
352+ jax .random .normal (k6 , (2 , audio_dim ), dtype = weights_dtype ) / jnp .sqrt (audio_dim ),
353+ kernel_init = nnx .with_partitioning (nnx .initializers .zeros , (None , "embed" )),
354+ )
355+
342356 def __call__ (
343357 self ,
344358 hidden_states : jax .Array , # Video
@@ -610,6 +624,9 @@ def __init__(
610624 qk_norm : str = "rms_norm_across_heads" ,
611625 flash_block_sizes : BlockSizes = None ,
612626 flash_min_seq_length : int = 4096 ,
627+ video_gated_attn : bool = False ,
628+ audio_gated_attn : bool = False ,
629+ cross_attn_mod : bool = False ,
613630 ** kwargs ,
614631 ):
615632 self .in_channels = in_channels
@@ -659,6 +676,9 @@ def __init__(
659676 self .a2v_attention_kernel = a2v_attention_kernel
660677 self .v2a_attention_kernel = v2a_attention_kernel
661678 self .flash_min_seq_length = flash_min_seq_length
679+ self .video_gated_attn = video_gated_attn
680+ self .audio_gated_attn = audio_gated_attn
681+ self .cross_attn_mod = cross_attn_mod
662682
663683 _out_channels = self .out_channels or self .in_channels
664684 _audio_out_channels = self .audio_out_channels or self .audio_in_channels
@@ -881,6 +901,9 @@ def init_block(rngs):
881901 v2a_attention_kernel = self .v2a_attention_kernel ,
882902 flash_block_sizes = flash_block_sizes ,
883903 flash_min_seq_length = self .flash_min_seq_length ,
904+ video_gated_attn = self .video_gated_attn ,
905+ audio_gated_attn = self .audio_gated_attn ,
906+ cross_attn_mod = self .cross_attn_mod ,
884907 )
885908
886909 if self .scan_layers :
@@ -916,6 +939,9 @@ def init_block(rngs):
916939 v2a_attention_kernel = self .v2a_attention_kernel ,
917940 flash_block_sizes = flash_block_sizes ,
918941 flash_min_seq_length = self .flash_min_seq_length ,
942+ video_gated_attn = self .video_gated_attn ,
943+ audio_gated_attn = self .audio_gated_attn ,
944+ cross_attn_mod = self .cross_attn_mod ,
919945 )
920946 blocks .append (block )
921947 self .transformer_blocks = nnx .List (blocks )
0 commit comments