@@ -39,6 +39,8 @@ class LTX2AudioVideoGemmaTextEncoder(nnx.Module, FlaxModelMixin, ConfigMixin):
3939 def __init__ (
4040 self ,
4141 caption_channels : int = 3840 ,
42+ video_caption_channels : Optional [int ] = None ,
43+ audio_caption_channels : Optional [int ] = None ,
4244 text_proj_in_factor : int = 49 ,
4345 video_connector_attention_head_dim : int = 128 ,
4446 video_connector_num_attention_heads : int = 30 ,
@@ -57,63 +59,151 @@ def __init__(
5759 attention_kernel : str = "flash" ,
5860 mesh : jax .sharding .Mesh = None ,
5961 rngs : nnx .Rngs = None ,
62+ per_modality_projections : bool = False ,
63+ proj_bias : bool = False ,
64+ video_gated_attn : bool = False ,
65+ audio_gated_attn : bool = False ,
66+ audio_hidden_dim : Optional [int ] = None ,
67+ video_hidden_dim : Optional [int ] = None ,
6068 ** kwargs ,
6169 ):
62- input_dim = caption_channels * text_proj_in_factor
63-
64- self .feature_extractor = LTX2GemmaFeatureExtractor (
65- input_dim = input_dim ,
66- output_dim = caption_channels ,
67- dtype = dtype ,
68- rngs = rngs ,
69- )
70-
71- # Two independent connectors
72- self .video_embeddings_connector = Embeddings1DConnector (
73- input_dim = caption_channels ,
74- heads = video_connector_num_attention_heads ,
75- head_dim = video_connector_attention_head_dim ,
76- layers = video_connector_num_layers ,
77- num_learnable_registers = video_connector_num_learnable_registers ,
78- rope_type = rope_type ,
79- theta = rope_theta ,
80- base_seq_len = connector_rope_base_seq_len ,
81- double_precision = rope_double_precision ,
82- attention_kernel = attention_kernel ,
83- mesh = mesh ,
84- rngs = rngs ,
85- )
86-
87- self .audio_embeddings_connector = Embeddings1DConnector (
88- input_dim = caption_channels ,
89- heads = audio_connector_num_attention_heads ,
90- head_dim = audio_connector_attention_head_dim ,
91- layers = audio_connector_num_layers ,
92- num_learnable_registers = audio_connector_num_learnable_registers ,
93- rope_type = rope_type ,
94- theta = rope_theta ,
95- base_seq_len = connector_rope_base_seq_len ,
96- double_precision = rope_double_precision ,
97- attention_kernel = attention_kernel ,
98- mesh = mesh ,
99- rngs = rngs ,
100- )
70+ gemma_dim = 3840 if video_caption_channels is not None else caption_channels
71+ input_dim = gemma_dim * text_proj_in_factor
72+
73+ v_dim = video_hidden_dim if video_hidden_dim is not None else (video_caption_channels if video_caption_channels is not None else caption_channels )
74+ a_dim = audio_hidden_dim if audio_hidden_dim is not None else (audio_caption_channels if audio_caption_channels is not None else caption_channels )
75+
76+ self .per_modality_projections = per_modality_projections
77+
78+ if per_modality_projections :
79+ self .video_text_proj_in = nnx .Linear (
80+ in_features = input_dim , out_features = v_dim , use_bias = proj_bias , rngs = rngs
81+ )
82+ self .audio_text_proj_in = nnx .Linear (
83+ in_features = input_dim , out_features = a_dim , use_bias = proj_bias , rngs = rngs
84+ )
85+
86+ self .video_embeddings_connector = Embeddings1DConnector (
87+ input_dim = v_dim ,
88+ heads = video_connector_num_attention_heads ,
89+ head_dim = video_connector_attention_head_dim ,
90+ layers = video_connector_num_layers ,
91+ num_learnable_registers = video_connector_num_learnable_registers ,
92+ rope_type = rope_type ,
93+ theta = rope_theta ,
94+ base_seq_len = connector_rope_base_seq_len ,
95+ double_precision = rope_double_precision ,
96+ attention_kernel = attention_kernel ,
97+ mesh = mesh ,
98+ rngs = rngs ,
99+ gated_attn = video_gated_attn ,
100+ )
101+ self .audio_embeddings_connector = Embeddings1DConnector (
102+ input_dim = a_dim ,
103+ heads = audio_connector_num_attention_heads ,
104+ head_dim = audio_connector_attention_head_dim ,
105+ layers = audio_connector_num_layers ,
106+ num_learnable_registers = audio_connector_num_learnable_registers ,
107+ rope_type = rope_type ,
108+ theta = rope_theta ,
109+ base_seq_len = connector_rope_base_seq_len ,
110+ double_precision = rope_double_precision ,
111+ attention_kernel = attention_kernel ,
112+ mesh = mesh ,
113+ rngs = rngs ,
114+ gated_attn = audio_gated_attn ,
115+ )
116+ else :
117+ self .feature_extractor = LTX2GemmaFeatureExtractor (
118+ input_dim = input_dim ,
119+ output_dim = caption_channels ,
120+ dtype = dtype ,
121+ rngs = rngs ,
122+ per_modality_projections = per_modality_projections ,
123+ use_bias = proj_bias ,
124+ video_output_dim = v_dim ,
125+ audio_output_dim = a_dim ,
126+ )
127+
128+ # Two independent connectors
129+ self .video_embeddings_connector = Embeddings1DConnector (
130+ input_dim = v_dim ,
131+ heads = video_connector_num_attention_heads ,
132+ head_dim = video_connector_attention_head_dim ,
133+ layers = video_connector_num_layers ,
134+ num_learnable_registers = video_connector_num_learnable_registers ,
135+ rope_type = rope_type ,
136+ theta = rope_theta ,
137+ base_seq_len = connector_rope_base_seq_len ,
138+ double_precision = rope_double_precision ,
139+ attention_kernel = attention_kernel ,
140+ mesh = mesh ,
141+ rngs = rngs ,
142+ gated_attn = video_gated_attn ,
143+ )
144+ self .audio_embeddings_connector = Embeddings1DConnector (
145+ input_dim = a_dim ,
146+ heads = audio_connector_num_attention_heads ,
147+ head_dim = audio_connector_attention_head_dim ,
148+ layers = audio_connector_num_layers ,
149+ num_learnable_registers = audio_connector_num_learnable_registers ,
150+ rope_type = rope_type ,
151+ theta = rope_theta ,
152+ base_seq_len = connector_rope_base_seq_len ,
153+ double_precision = rope_double_precision ,
154+ attention_kernel = attention_kernel ,
155+ mesh = mesh ,
156+ rngs = rngs ,
157+ gated_attn = audio_gated_attn ,
158+ )
101159
def __call__(
    self,
    hidden_states: Union[Tuple[Array, ...], List[Array]],
    attention_mask: Array,
) -> Tuple[Array, Array, Array]:
    """Encode Gemma hidden states into video and audio text embeddings.

    Args:
        hidden_states: Per-layer Gemma hidden states — either a tuple/list of
            arrays (stacked here along a new trailing axis) or a pre-stacked
            array. Assumed shape (batch, seq_len, dim[, num_layers]) — TODO
            confirm against the caller.
        attention_mask: Token mask of shape (batch, seq_len); positions with
            value > 0.5 are treated as valid.

    Returns:
        (video_embeds, audio_embeds, new_attention_mask) — the two modality
        embeddings from their respective connectors, plus the attention mask
        produced by the video connector.
    """
    with jax.named_scope("Text Encoder Forward"):
        if self.per_modality_projections:
            # 1. Stack per-layer hidden states into (b, l, d, k) when given
            # as a sequence; otherwise assume they are already stacked.
            if isinstance(hidden_states, (tuple, list)):
                x = jnp.stack(hidden_states, axis=-1)
            else:
                x = hidden_states

            b, l = x.shape[0], x.shape[1]

            # 2. RMS-normalize each token over the feature axis (axis=2),
            # independently per layer slice.
            variance = jnp.mean(x**2, axis=2, keepdims=True)
            norm_text_encoder_hidden_states = x * jax.lax.rsqrt(variance + 1e-6)

            # Flatten (d, k) into a single projection input axis.
            norm_text_encoder_hidden_states = norm_text_encoder_hidden_states.reshape(b, l, -1)

            # Zero out padded positions before the modality projections.
            bool_mask = (attention_mask > 0.5).astype(jnp.float32)[..., None]
            norm_text_encoder_hidden_states = norm_text_encoder_hidden_states * bool_mask

            # 3. Rescale so each modality branch sees activations whose norm
            # matches its connector width relative to the Gemma width.
            # BUGFIX: the original wrote
            #   getattr(self, "caption_channels", getattr(self.config, "caption_channels", 3840))
            # which evaluates `self.config` eagerly (Python evaluates the
            # default argument before the outer lookup), raising
            # AttributeError when `config` is absent even though
            # `caption_channels` exists. Resolve the fallback lazily instead.
            cap_channels = getattr(self, "caption_channels", None)
            if cap_channels is None:
                # getattr(None, ...) on a missing `config` safely yields 3840.
                cap_channels = getattr(getattr(self, "config", None), "caption_channels", 3840)

            video_scale_factor = jnp.sqrt(self.video_embeddings_connector.dim / cap_channels)
            video_norm_text_emb = norm_text_encoder_hidden_states * video_scale_factor
            audio_scale_factor = jnp.sqrt(self.audio_embeddings_connector.dim / cap_channels)
            audio_norm_text_emb = norm_text_encoder_hidden_states * audio_scale_factor

            # 4. Independent per-modality input projections.
            video_text_emb_proj = self.video_text_proj_in(video_norm_text_emb)
            audio_text_emb_proj = self.audio_text_proj_in(audio_norm_text_emb)

            # 5. Parallel connectors; only the video connector's mask is kept.
            video_embeds, new_attention_mask = self.video_embeddings_connector(video_text_emb_proj, attention_mask)
            audio_embeds, _ = self.audio_embeddings_connector(audio_text_emb_proj, attention_mask)
        else:
            # 1. Shared feature extraction across both modalities.
            features = self.feature_extractor(hidden_states, attention_mask)

            # 2. Parallel connection into the two modality branches.
            video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
            audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)

    return video_embeds, audio_embeds, new_attention_mask
0 commit comments