        self.per_modality_projections = per_modality_projections

        if per_modality_projections:
            # Per-modality path: a dedicated linear projection maps the raw
            # text-encoder features to each modality's connector width
            # (v_dim for video, a_dim for audio), then each modality gets its
            # own Embeddings1DConnector.
            self.video_text_proj_in = nnx.Linear(
                in_features=input_dim, out_features=v_dim, use_bias=proj_bias, rngs=rngs
            )
            self.audio_text_proj_in = nnx.Linear(
                in_features=input_dim, out_features=a_dim, use_bias=proj_bias, rngs=rngs
            )

            # NOTE(review): this branch names the connectors `video_connector` /
            # `audio_connector`, while the else-branch uses
            # `video_embeddings_connector` / `audio_embeddings_connector`.
            # __call__ reads the matching names per branch, so this works, but
            # the asymmetry is easy to trip over — consider unifying.
            self.video_connector = Embeddings1DConnector(
                input_dim=v_dim,
                heads=video_connector_num_attention_heads,
                head_dim=video_connector_attention_head_dim,
                layers=video_connector_num_layers,
                num_learnable_registers=video_connector_num_learnable_registers,
                rope_type=rope_type,
                theta=rope_theta,
                base_seq_len=connector_rope_base_seq_len,
                double_precision=rope_double_precision,
                attention_kernel=attention_kernel,
                mesh=mesh,
                rngs=rngs,
                gated_attn=video_gated_attn,
            )
            self.audio_connector = Embeddings1DConnector(
                input_dim=a_dim,
                heads=audio_connector_num_attention_heads,
                head_dim=audio_connector_attention_head_dim,
                layers=audio_connector_num_layers,
                num_learnable_registers=audio_connector_num_learnable_registers,
                rope_type=rope_type,
                theta=rope_theta,
                base_seq_len=connector_rope_base_seq_len,
                double_precision=rope_double_precision,
                attention_kernel=attention_kernel,
                mesh=mesh,
                rngs=rngs,
                gated_attn=audio_gated_attn,
            )
        else:
            # Shared path: one feature extractor produces a single feature
            # stream that both connectors consume.
            self.feature_extractor = LTX2GemmaFeatureExtractor(
                input_dim=input_dim,
                output_dim=caption_channels,
                dtype=dtype,
                rngs=rngs,
                per_modality_projections=per_modality_projections,
                use_bias=proj_bias,
                video_output_dim=v_dim,
                audio_output_dim=a_dim,
            )

            # Two independent connectors
            self.video_embeddings_connector = Embeddings1DConnector(
                input_dim=v_dim,
                heads=video_connector_num_attention_heads,
                head_dim=video_connector_attention_head_dim,
                layers=video_connector_num_layers,
                num_learnable_registers=video_connector_num_learnable_registers,
                rope_type=rope_type,
                theta=rope_theta,
                base_seq_len=connector_rope_base_seq_len,
                double_precision=rope_double_precision,
                attention_kernel=attention_kernel,
                mesh=mesh,
                rngs=rngs,
                gated_attn=video_gated_attn,
            )
            self.audio_embeddings_connector = Embeddings1DConnector(
                input_dim=a_dim,
                heads=audio_connector_num_attention_heads,
                head_dim=audio_connector_attention_head_dim,
                layers=audio_connector_num_layers,
                num_learnable_registers=audio_connector_num_learnable_registers,
                rope_type=rope_type,
                theta=rope_theta,
                base_seq_len=connector_rope_base_seq_len,
                double_precision=rope_double_precision,
                attention_kernel=attention_kernel,
                mesh=mesh,
                rngs=rngs,
                gated_attn=audio_gated_attn,
            )
    def __call__(
        self,
        hidden_states: Union[Tuple[Array, ...], List[Array]],
        attention_mask: Array,
    ) -> Tuple[Array, Array, Array]:
        """
        Returns:
            (video_embeds, audio_embeds, new_attention_mask)
        """
        with jax.named_scope("Text Encoder Forward"):
            if self.per_modality_projections:
                # 1. Stack Hidden States if needed
                # A tuple/list of per-layer states (b, l, d) is stacked along a
                # trailing layer axis k, giving (b, l, d, k).
                if isinstance(hidden_states, (tuple, list)):
                    x = jnp.stack(hidden_states, axis=-1)
                else:
                    # assumes a pre-stacked (b, l, d, k) array — TODO confirm
                    x = hidden_states

                b, l, d, k = x.shape

                # 2. Per-token RMS norm
                # Normalization is over axis 2 (the feature axis d), computed
                # independently for each of the k stacked layers; 1e-6 guards
                # against division by zero.
                variance = jnp.mean(x**2, axis=2, keepdims=True)
                norm_text_encoder_hidden_states = x * jax.lax.rsqrt(variance + 1e-6)

                # Flatten the (d, k) tail into a single feature axis of size d*k.
                norm_text_encoder_hidden_states = norm_text_encoder_hidden_states.reshape(b, l, -1)

                # Zero out padded token positions (mask treated as boolean via
                # a 0.5 threshold, broadcast over the feature axis).
                bool_mask = (attention_mask > 0.5).astype(jnp.float32)[..., None]
                norm_text_encoder_hidden_states = norm_text_encoder_hidden_states * bool_mask

                # 3. Rescale norms
                # Using self.caption_channels if available, or fallback to config or 3840
                cap_channels = getattr(self, "caption_channels", getattr(self.config, "caption_channels", 3840))

                # Scale each modality's features by sqrt(connector_dim / caption_channels)
                # before projecting into that modality's connector width.
                video_scale_factor = jnp.sqrt(self.video_connector.dim / cap_channels)
                video_norm_text_emb = norm_text_encoder_hidden_states * video_scale_factor
                audio_scale_factor = jnp.sqrt(self.audio_connector.dim / cap_channels)
                audio_norm_text_emb = norm_text_encoder_hidden_states * audio_scale_factor

                video_text_emb_proj = self.video_text_proj_in(video_norm_text_emb)
                audio_text_emb_proj = self.audio_text_proj_in(audio_norm_text_emb)

                # Only the video connector's refreshed mask is kept; the audio
                # connector's mask is discarded here.
                video_embeds, new_attention_mask = self.video_connector(video_text_emb_proj, attention_mask)
                audio_embeds, _ = self.audio_connector(audio_text_emb_proj, attention_mask)
            else:
                # 1. Shared Feature Extraction
                features = self.feature_extractor(hidden_states, attention_mask)

                # 2. Parallel Connection
                video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
                audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)
                # NOTE(review): the method's return statement lies outside this
                # view; per the docstring and annotation it presumably returns
                # (video_embeds, audio_embeds, new_attention_mask) — confirm.
0 commit comments