@@ -99,6 +99,8 @@ def __init__(
9999 norm_elementwise_affine : bool = False ,
100100 norm_eps : float = 1e-6 ,
101101 rope_type : str = "interleaved" ,
102+ video_gated_attn : bool = False ,
103+ audio_gated_attn : bool = False ,
102104 dtype : jnp .dtype = jnp .float32 ,
103105 weights_dtype : jnp .dtype = jnp .float32 ,
104106 mesh : jax .sharding .Mesh = None ,
@@ -142,6 +144,7 @@ def __init__(
142144 rope_type = rope_type ,
143145 flash_block_sizes = flash_block_sizes ,
144146 flash_min_seq_length = flash_min_seq_length ,
147+ gated_attn = video_gated_attn ,
145148 )
146149
147150 self .audio_norm1 = nnx .RMSNorm (
@@ -168,6 +171,7 @@ def __init__(
168171 rope_type = rope_type ,
169172 flash_block_sizes = flash_block_sizes ,
170173 flash_min_seq_length = flash_min_seq_length ,
174+ gated_attn = audio_gated_attn ,
171175 )
172176
173177 # 2. Prompt Cross-Attention
@@ -195,6 +199,7 @@ def __init__(
195199 attention_kernel = self .attention_kernel ,
196200 rope_type = rope_type ,
197201 flash_block_sizes = flash_block_sizes ,
202+ gated_attn = video_gated_attn ,
198203 )
199204
200205 self .audio_norm2 = nnx .RMSNorm (
@@ -222,6 +227,7 @@ def __init__(
222227 rope_type = rope_type ,
223228 flash_block_sizes = flash_block_sizes ,
224229 flash_min_seq_length = flash_min_seq_length ,
230+ gated_attn = audio_gated_attn ,
225231 )
226232
227233 # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -250,6 +256,7 @@ def __init__(
250256 rope_type = rope_type ,
251257 flash_block_sizes = flash_block_sizes ,
252258 flash_min_seq_length = 0 ,
259+ gated_attn = video_gated_attn ,
253260 )
254261
255262 self .video_to_audio_norm = nnx .RMSNorm (
@@ -277,6 +284,7 @@ def __init__(
277284 rope_type = rope_type ,
278285 flash_block_sizes = flash_block_sizes ,
279286 flash_min_seq_length = flash_min_seq_length ,
287+ gated_attn = audio_gated_attn ,
280288 )
281289
282290 # 4. Feed Forward
@@ -344,6 +352,9 @@ def __call__(
344352 temb_ca_audio_scale_shift : jax .Array ,
345353 temb_ca_gate : jax .Array ,
346354 temb_ca_audio_gate : jax .Array ,
355+ temb_prompt : Optional [jax .Array ] = None ,
356+ temb_prompt_audio : Optional [jax .Array ] = None ,
357+ modality_mask : Optional [jax .Array ] = None ,
347358 # RoPE
348359 video_rotary_emb : Optional [Tuple [jax .Array , jax .Array ]] = None ,
349360 audio_rotary_emb : Optional [Tuple [jax .Array , jax .Array ]] = None ,
@@ -354,6 +365,7 @@ def __call__(
354365 audio_encoder_attention_mask : Optional [jax .Array ] = None ,
355366 a2v_cross_attention_mask : Optional [jax .Array ] = None ,
356367 v2a_cross_attention_mask : Optional [jax .Array ] = None ,
368+ perturbation_mask : Optional [jax .Array ] = None ,
357369 ) -> Tuple [jax .Array , jax .Array ]:
358370 batch_size = hidden_states .shape [0 ]
359371
@@ -493,6 +505,10 @@ def __call__(
493505 k_rotary_emb = ca_audio_rotary_emb ,
494506 attention_mask = a2v_cross_attention_mask ,
495507 )
508+
509+ if perturbation_mask is not None :
510+ a2v_attn_hidden_states = mod_norm_audio_hidden_states + perturbation_mask * (a2v_attn_hidden_states - mod_norm_audio_hidden_states )
511+
496512 hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
497513
498514 # Video-to-Audio Cross Attention: Q: Audio; K,V: Video
@@ -509,6 +525,10 @@ def __call__(
509525 )
510526 audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
511527
528+ if modality_mask is not None :
529+ hidden_states = hidden_states * modality_mask
530+ audio_hidden_states = audio_hidden_states * modality_mask
531+
512532 # 4. Feedforward
513533 norm_hidden_states = self .norm3 (hidden_states )
514534 norm_hidden_states = norm_hidden_states * (1 + scale_mlp ) + shift_mlp
@@ -541,6 +561,8 @@ def __init__(
541561 pos_embed_max_pos : int = 20 ,
542562 base_height : int = 2048 ,
543563 base_width : int = 2048 ,
564+ gated_attn : bool = False ,
565+ cross_attn_mod : bool = False ,
544566 audio_in_channels : int = 128 , # Audio Arguments
545567 audio_out_channels : Optional [int ] = 128 ,
546568 audio_patch_size : int = 1 ,
@@ -552,11 +574,14 @@ def __init__(
552574 audio_pos_embed_max_pos : int = 20 ,
553575 audio_sampling_rate : int = 16000 ,
554576 audio_hop_length : int = 160 ,
577+ audio_gated_attn : bool = False ,
578+ audio_cross_attn_mod : bool = False ,
555579 num_layers : int = 48 , # Shared arguments
556580 activation_fn : str = "gelu" ,
557581 norm_elementwise_affine : bool = False ,
558582 norm_eps : float = 1e-6 ,
559583 caption_channels : int = 3840 ,
584+ audio_caption_channels : Optional [int ] = None ,
560585 attention_bias : bool = True ,
561586 attention_out_bias : bool = True ,
562587 rope_theta : float = 10000.0 ,
@@ -565,6 +590,8 @@ def __init__(
565590 timestep_scale_multiplier : int = 1000 ,
566591 cross_attn_timestep_scale_multiplier : int = 1000 ,
567592 rope_type : str = "interleaved" ,
593+ use_prompt_embeddings : bool = True ,
594+ perturbed_attn : bool = False ,
568595 dtype : jnp .dtype = jnp .float32 ,
569596 weights_dtype : jnp .dtype = jnp .float32 ,
570597 mesh : jax .sharding .Mesh = None ,
@@ -655,33 +682,59 @@ def __init__(
655682 )
656683
657684 # 2. Prompt embeddings
658- self .caption_projection = NNXPixArtAlphaTextProjection (
659- rngs = rngs ,
660- in_features = self .caption_channels ,
661- hidden_size = inner_dim ,
662- dtype = self .dtype ,
663- weights_dtype = self .weights_dtype ,
664- )
665- self .audio_caption_projection = NNXPixArtAlphaTextProjection (
666- rngs = rngs ,
667- in_features = self .caption_channels ,
668- hidden_size = audio_inner_dim ,
669- dtype = self .dtype ,
670- weights_dtype = self .weights_dtype ,
671- )
685+ if self .use_prompt_embeddings :
686+ self .caption_projection = NNXPixArtAlphaTextProjection (
687+ rngs = rngs ,
688+ in_features = self .caption_channels ,
689+ hidden_size = inner_dim ,
690+ dtype = self .dtype ,
691+ weights_dtype = self .weights_dtype ,
692+ )
693+ self .audio_caption_projection = NNXPixArtAlphaTextProjection (
694+ rngs = rngs ,
695+ in_features = self .caption_channels ,
696+ hidden_size = audio_inner_dim ,
697+ dtype = self .dtype ,
698+ weights_dtype = self .weights_dtype ,
699+ )
700+ else :
701+ self .caption_projection = None
702+ self .audio_caption_projection = None
703+
704+ if self .cross_attn_mod :
705+ self .prompt_adaln = LTX2AdaLayerNormSingle (
706+ rngs = rngs ,
707+ embedding_dim = inner_dim ,
708+ num_mod_params = 2 ,
709+ use_additional_conditions = False ,
710+ dtype = self .dtype ,
711+ weights_dtype = self .weights_dtype ,
712+ )
713+ self .audio_prompt_adaln = LTX2AdaLayerNormSingle (
714+ rngs = rngs ,
715+ embedding_dim = audio_inner_dim ,
716+ num_mod_params = 2 ,
717+ use_additional_conditions = False ,
718+ dtype = self .dtype ,
719+ weights_dtype = self .weights_dtype ,
720+ )
721+
672722 # 3. Timestep Modulation Params and Embedding
723+ video_time_emb_mod_params = 9 if cross_attn_mod else 6
724+ audio_time_emb_mod_params = 9 if audio_cross_attn_mod else 6
725+
673726 self .time_embed = LTX2AdaLayerNormSingle (
674727 rngs = rngs ,
675728 embedding_dim = inner_dim ,
676- num_mod_params = 6 ,
729+ num_mod_params = video_time_emb_mod_params ,
677730 use_additional_conditions = False ,
678731 dtype = self .dtype ,
679732 weights_dtype = self .weights_dtype ,
680733 )
681734 self .audio_time_embed = LTX2AdaLayerNormSingle (
682735 rngs = rngs ,
683736 embedding_dim = audio_inner_dim ,
684- num_mod_params = 6 ,
737+ num_mod_params = audio_time_emb_mod_params ,
685738 use_additional_conditions = False ,
686739 dtype = self .dtype ,
687740 weights_dtype = self .weights_dtype ,
@@ -722,11 +775,11 @@ def __init__(
722775 # 3. Output Layer Scale/Shift Modulation parameters
723776 param_rng = rngs .params ()
724777 self .scale_shift_table = nnx .Param (
725- jax .random .normal (param_rng , (2 , inner_dim ), dtype = self .weights_dtype ) / jnp .sqrt (inner_dim ),
778+ jax .random .normal (param_rng , (6 , inner_dim ), dtype = self .weights_dtype ) / jnp .sqrt (inner_dim ),
726779 kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), (None , "embed" )),
727780 )
728781 self .audio_scale_shift_table = nnx .Param (
729- jax .random .normal (param_rng , (2 , audio_inner_dim ), dtype = self .weights_dtype ) / jnp .sqrt (audio_inner_dim ),
782+ jax .random .normal (param_rng , (6 , audio_inner_dim ), dtype = self .weights_dtype ) / jnp .sqrt (audio_inner_dim ),
730783 kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), (None , "embed" )),
731784 )
732785
@@ -899,6 +952,8 @@ def __call__(
899952 audio_encoder_hidden_states : jax .Array ,
900953 timestep : jax .Array ,
901954 audio_timestep : Optional [jax .Array ] = None ,
955+ sigma : Optional [jax .Array ] = None ,
956+ audio_sigma : Optional [jax .Array ] = None ,
902957 encoder_attention_mask : Optional [jax .Array ] = None ,
903958 audio_encoder_attention_mask : Optional [jax .Array ] = None ,
904959 num_frames : Optional [int ] = None ,
@@ -909,7 +964,10 @@ def __call__(
909964 video_coords : Optional [jax .Array ] = None ,
910965 audio_coords : Optional [jax .Array ] = None ,
911966 attention_kwargs : Optional [Dict [str , Any ]] = None ,
967+ use_cross_timestep : bool = False ,
968+ modality_mask : Optional [jax .Array ] = None ,
912969 return_dict : bool = True ,
970+ perturbation_mask : Optional [jax .Array ] = None ,
913971 ) -> Any :
914972 # Determine timestep for audio.
915973 audio_timestep = audio_timestep if audio_timestep is not None else timestep
@@ -961,12 +1019,36 @@ def __call__(
9611019 temb_audio = temb_audio .reshape (batch_size , - 1 , temb_audio .shape [- 1 ])
9621020 audio_embedded_timestep = audio_embedded_timestep .reshape (batch_size , - 1 , audio_embedded_timestep .shape [- 1 ])
9631021
1022+ if self .cross_attn_mod and sigma is not None :
1023+ audio_sigma = audio_sigma if audio_sigma is not None else sigma
1024+ temb_prompt , _ = self .prompt_adaln (
1025+ sigma .flatten (),
1026+ hidden_dtype = hidden_states .dtype ,
1027+ )
1028+ temb_prompt_audio , _ = self .audio_prompt_adaln (
1029+ audio_sigma .flatten (),
1030+ hidden_dtype = audio_hidden_states .dtype ,
1031+ )
1032+ temb_prompt = temb_prompt .reshape (batch_size , - 1 , temb_prompt .shape [- 1 ])
1033+ temb_prompt_audio = temb_prompt_audio .reshape (batch_size , - 1 , temb_prompt_audio .shape [- 1 ])
1034+ else :
1035+ temb_prompt = None
1036+ temb_prompt_audio = None
1037+
1038+ if use_cross_timestep :
1039+ assert sigma is not None and audio_sigma is not None , "sigma and audio_sigma must be provided when use_cross_timestep is True"
1040+ video_ca_timestep = audio_sigma .flatten ()
1041+ audio_ca_timestep = sigma .flatten ()
1042+ else :
1043+ video_ca_timestep = timestep .flatten ()
1044+ audio_ca_timestep = audio_timestep .flatten () if audio_timestep is not None else timestep .flatten ()
1045+
9641046 video_cross_attn_scale_shift , _ = self .av_cross_attn_video_scale_shift (
965- timestep . flatten () ,
1047+ video_ca_timestep ,
9661048 hidden_dtype = hidden_states .dtype ,
9671049 )
9681050 video_cross_attn_a2v_gate , _ = self .av_cross_attn_video_a2v_gate (
969- timestep . flatten () * timestep_cross_attn_gate_scale_factor ,
1051+ video_ca_timestep * timestep_cross_attn_gate_scale_factor ,
9701052 hidden_dtype = hidden_states .dtype ,
9711053 )
9721054 video_cross_attn_scale_shift = video_cross_attn_scale_shift .reshape (
@@ -975,11 +1057,11 @@ def __call__(
9751057 video_cross_attn_a2v_gate = video_cross_attn_a2v_gate .reshape (batch_size , - 1 , video_cross_attn_a2v_gate .shape [- 1 ])
9761058
9771059 audio_cross_attn_scale_shift , _ = self .av_cross_attn_audio_scale_shift (
978- audio_timestep . flatten () ,
1060+ audio_ca_timestep ,
9791061 hidden_dtype = audio_hidden_states .dtype ,
9801062 )
9811063 audio_cross_attn_v2a_gate , _ = self .av_cross_attn_audio_v2a_gate (
982- audio_timestep . flatten () * timestep_cross_attn_gate_scale_factor ,
1064+ audio_ca_timestep * timestep_cross_attn_gate_scale_factor ,
9831065 hidden_dtype = audio_hidden_states .dtype ,
9841066 )
9851067 audio_cross_attn_scale_shift = audio_cross_attn_scale_shift .reshape (
@@ -988,14 +1070,26 @@ def __call__(
9881070 audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate .reshape (batch_size , - 1 , audio_cross_attn_v2a_gate .shape [- 1 ])
9891071
9901072 # 4. Prepare prompt embeddings
991- encoder_hidden_states = self .caption_projection (encoder_hidden_states )
992- encoder_hidden_states = encoder_hidden_states .reshape (batch_size , - 1 , hidden_states .shape [- 1 ])
1073+ if self .use_prompt_embeddings and self .caption_projection is not None :
1074+ encoder_hidden_states = self .caption_projection (encoder_hidden_states )
1075+ audio_encoder_hidden_states = self .audio_caption_projection (audio_encoder_hidden_states )
9931076
994- audio_encoder_hidden_states = self .audio_caption_projection (audio_encoder_hidden_states )
995- audio_encoder_hidden_states = audio_encoder_hidden_states .reshape (batch_size , - 1 , audio_hidden_states .shape [- 1 ])
1077+ encoder_hidden_states = encoder_hidden_states .reshape (batch_size , - 1 , hidden_states .shape [- 1 ])
1078+ audio_encoder_hidden_states = audio_encoder_hidden_states .reshape (batch_size , - 1 , audio_hidden_states .shape [- 1 ])
1079+
1080+ # Construct perturbation_mask_per_layer for STG
1081+ if perturbation_mask is None :
1082+ perturbation_mask_per_layer = jnp .ones ((self .num_layers , batch_size , 1 , 1 ), dtype = self .dtype )
1083+ else :
1084+ masks = jnp .ones ((self .num_layers , batch_size , 1 , 1 ), dtype = self .dtype )
1085+ for i in self .spatio_temporal_guidance_blocks :
1086+ if i < self .num_layers :
1087+ masks = masks .at [i ].set (perturbation_mask )
1088+ perturbation_mask_per_layer = masks
9961089
9971090 # 5. Run transformer blocks
998- def scan_fn (carry , block ):
1091+ def scan_fn (carry , block_and_mask ):
1092+ block , mask = block_and_mask
9991093 hidden_states , audio_hidden_states , rngs_carry = carry
10001094 with jax .named_scope ("Transformer Layer" ):
10011095 hidden_states_out , audio_hidden_states_out = block (
@@ -1009,12 +1103,16 @@ def scan_fn(carry, block):
10091103 temb_ca_audio_scale_shift = audio_cross_attn_scale_shift ,
10101104 temb_ca_gate = video_cross_attn_a2v_gate ,
10111105 temb_ca_audio_gate = audio_cross_attn_v2a_gate ,
1106+ temb_prompt = temb_prompt ,
1107+ temb_prompt_audio = temb_prompt_audio ,
10121108 video_rotary_emb = video_rotary_emb ,
10131109 audio_rotary_emb = audio_rotary_emb ,
10141110 ca_video_rotary_emb = video_cross_attn_rotary_emb ,
10151111 ca_audio_rotary_emb = audio_cross_attn_rotary_emb ,
1016- encoder_attention_mask = encoder_attention_mask ,
1017- audio_encoder_attention_mask = audio_encoder_attention_mask ,
1112+ a2v_cross_attention_mask = encoder_attention_mask ,
1113+ v2a_cross_attention_mask = audio_encoder_attention_mask ,
1114+ perturbation_mask = mask ,
1115+ modality_mask = modality_mask ,
10181116 )
10191117 return (
10201118 hidden_states_out .astype (hidden_states .dtype ),
@@ -1034,9 +1132,10 @@ def scan_fn(carry, block):
10341132 in_axes = (nnx .Carry , 0 ),
10351133 out_axes = (nnx .Carry , 0 ),
10361134 transform_metadata = {nnx .PARTITION_NAME : "layers" },
1037- )(carry , self .transformer_blocks )
1135+ )(carry , ( self .transformer_blocks , perturbation_mask_per_layer ) )
10381136 else :
1039- for block in self .transformer_blocks :
1137+ for i , block in enumerate (self .transformer_blocks ):
1138+ mask = perturbation_mask_per_layer [i ] if perturbation_mask_per_layer is not None else None
10401139 hidden_states , audio_hidden_states = block (
10411140 hidden_states = hidden_states ,
10421141 audio_hidden_states = audio_hidden_states ,
@@ -1048,12 +1147,17 @@ def scan_fn(carry, block):
10481147 temb_ca_audio_scale_shift = audio_cross_attn_scale_shift ,
10491148 temb_ca_gate = video_cross_attn_a2v_gate ,
10501149 temb_ca_audio_gate = audio_cross_attn_v2a_gate ,
1150+ temb_prompt = temb_prompt ,
1151+ temb_prompt_audio = temb_prompt_audio ,
10511152 video_rotary_emb = video_rotary_emb ,
10521153 audio_rotary_emb = audio_rotary_emb ,
10531154 ca_video_rotary_emb = video_cross_attn_rotary_emb ,
10541155 ca_audio_rotary_emb = audio_cross_attn_rotary_emb ,
10551156 encoder_attention_mask = encoder_attention_mask ,
10561157 audio_encoder_attention_mask = audio_encoder_attention_mask ,
1158+ a2v_cross_attention_mask = encoder_attention_mask ,
1159+ v2a_cross_attention_mask = audio_encoder_attention_mask ,
1160+ perturbation_mask = mask ,
10571161 )
10581162
10591163 # 6. Output layers
0 commit comments