@@ -261,3 +261,118 @@ def block_scan_fn(carry, block_module):
261261 hidden_states = self .final_norm (hidden_states )
262262
263263 return hidden_states , attention_mask
264+
265+
class LTX2TextConnectors(nnx.Module):
    """Text-conditioning connectors for LTX-2.

    Projects text-encoder hidden states and feeds them through two independent
    ``Embeddings1DConnector`` stacks — one producing video conditioning
    embeddings, one producing audio conditioning embeddings.

    Two projection regimes exist:
      * ``per_modality_projections=True`` (LTX-2.3): per-token RMS norm,
        masking, modality-specific rescaling, then separate video/audio
        input projections.
      * ``per_modality_projections=False`` (LTX-2.0): a single shared
        projection — not implemented yet; calling raises ``NotImplementedError``.
    """

    def __init__(
        self,
        caption_channels: int = 3840,
        text_proj_in_factor: int = 49,
        video_connector_num_attention_heads: int = 30,
        video_connector_attention_head_dim: int = 128,
        video_connector_num_layers: int = 2,
        video_connector_num_learnable_registers: int = 128,
        video_gated_attn: bool = False,
        audio_connector_num_attention_heads: int = 30,
        audio_connector_attention_head_dim: int = 128,
        audio_connector_num_layers: int = 2,
        audio_connector_num_learnable_registers: int = 128,
        audio_gated_attn: bool = False,
        connector_rope_base_seq_len: int = 4096,
        rope_theta: float = 10000.0,
        rope_double_precision: bool = True,
        rope_type: str = "interleaved",
        per_modality_projections: bool = False,
        video_hidden_dim: int = 4096,
        audio_hidden_dim: int = 2048,
        proj_bias: bool = False,
        attention_kernel: str = "flash",
        mesh: jax.sharding.Mesh = None,
        rngs: nnx.Rngs = None,
    ):
        # Flattened text-encoder feature width: the encoder delivers
        # `text_proj_in_factor` stacked groups of `caption_channels` each.
        text_encoder_dim = caption_channels * text_proj_in_factor

        self.per_modality_projections = per_modality_projections
        self.caption_channels = caption_channels
        self.video_hidden_dim = video_hidden_dim
        self.audio_hidden_dim = audio_hidden_dim

        if per_modality_projections:
            # LTX-2.3: separate input projections per modality.
            self.video_text_proj_in = nnx.Linear(
                in_features=text_encoder_dim,
                out_features=video_hidden_dim,
                use_bias=proj_bias,
                rngs=rngs,
            )
            self.audio_text_proj_in = nnx.Linear(
                in_features=text_encoder_dim,
                out_features=audio_hidden_dim,
                use_bias=proj_bias,
                rngs=rngs,
            )
        else:
            # LTX-2.0: one shared projection back down to caption_channels.
            self.text_proj_in = nnx.Linear(
                in_features=text_encoder_dim,
                out_features=caption_channels,
                use_bias=proj_bias,
                rngs=rngs,
            )

        # RoPE/attention settings shared by both connector stacks.
        shared_connector_kwargs = dict(
            theta=rope_theta,
            rope_type=rope_type,
            base_seq_len=connector_rope_base_seq_len,
            double_precision=rope_double_precision,
            attention_kernel=attention_kernel,
            mesh=mesh,
            rngs=rngs,
        )

        self.video_connector = Embeddings1DConnector(
            input_dim=video_hidden_dim if per_modality_projections else caption_channels,
            heads=video_connector_num_attention_heads,
            head_dim=video_connector_attention_head_dim,
            layers=video_connector_num_layers,
            num_learnable_registers=video_connector_num_learnable_registers,
            gated_attn=video_gated_attn,
            **shared_connector_kwargs,
        )

        self.audio_connector = Embeddings1DConnector(
            input_dim=audio_hidden_dim if per_modality_projections else caption_channels,
            heads=audio_connector_num_attention_heads,
            head_dim=audio_connector_attention_head_dim,
            layers=audio_connector_num_layers,
            num_learnable_registers=audio_connector_num_learnable_registers,
            gated_attn=audio_gated_attn,
            **shared_connector_kwargs,
        )

    def __call__(self, text_encoder_hidden_states: Array, attention_mask: Array) -> Tuple[Array, Array, Array]:
        """Produce (video_text_embedding, audio_text_embedding, video_attn_mask).

        Accepts hidden states either flattened as (b, l, caption_channels * factor)
        or pre-grouped as a 4-D array; a 3-D input is reshaped into
        (b, l, caption_channels, factor) before normalization.
        """
        hidden = text_encoder_hidden_states
        if hidden.ndim == 3:
            batch, seq_len, feat_dim = hidden.shape
            group_factor = feat_dim // self.caption_channels
            hidden = hidden.reshape(batch, seq_len, self.caption_channels, group_factor)
        else:
            batch, seq_len = hidden.shape[0], hidden.shape[1]

        if not self.per_modality_projections:
            # LTX-2.0 shared-projection path is still a stub.
            raise NotImplementedError("LTX-2.0 path in LTX2TextConnectors not fully implemented yet.")

        # LTX-2.3: per-token RMS norm over the channel axis (axis=2).
        mean_square = jnp.mean(hidden**2, axis=2, keepdims=True)
        normed = hidden * jax.lax.rsqrt(mean_square + 1e-6)
        normed = normed.reshape(batch, seq_len, -1)

        # Zero out padded token positions before projecting.
        keep = (attention_mask > 0.5).astype(jnp.float32)[..., None]
        normed = normed * keep

        # Rescale so each modality's projection sees norms matched to its width.
        video_input = normed * jnp.sqrt(self.video_hidden_dim / self.caption_channels)
        audio_input = normed * jnp.sqrt(self.audio_hidden_dim / self.caption_channels)

        video_proj = self.video_text_proj_in(video_input)
        audio_proj = self.audio_text_proj_in(audio_input)

        video_text_embedding, video_attn_mask = self.video_connector(video_proj, attention_mask)
        audio_text_embedding, _ = self.audio_connector(audio_proj, attention_mask)

        return video_text_embedding, audio_text_embedding, video_attn_mask
0 commit comments