@@ -262,117 +262,3 @@ def block_scan_fn(carry, block_module):
262262
263263 return hidden_states , attention_mask
264264
265-
class LTX2TextConnectors(nnx.Module):
    """Bridges packed text-encoder embeddings into the video and audio streams.

    The text encoder emits embeddings whose last dimension packs
    ``caption_channels * text_proj_in_factor`` values per token. This module
    normalizes them, projects them to each modality's hidden size, and runs
    them through two independent :class:`Embeddings1DConnector` stacks — one
    for video, one for audio.

    NOTE(review): only the ``per_modality_projections=True`` (LTX-2.3) path is
    implemented in ``__call__``; the legacy path raises ``NotImplementedError``.
    """

    def __init__(
        self,
        caption_channels: int = 3840,
        text_proj_in_factor: int = 49,
        video_connector_num_attention_heads: int = 30,
        video_connector_attention_head_dim: int = 128,
        video_connector_num_layers: int = 2,
        video_connector_num_learnable_registers: int = 128,
        video_gated_attn: bool = False,
        audio_connector_num_attention_heads: int = 30,
        audio_connector_attention_head_dim: int = 128,
        audio_connector_num_layers: int = 2,
        audio_connector_num_learnable_registers: int = 128,
        audio_gated_attn: bool = False,
        connector_rope_base_seq_len: int = 4096,
        rope_theta: float = 10000.0,
        rope_double_precision: bool = True,
        rope_type: str = "interleaved",
        per_modality_projections: bool = False,
        video_hidden_dim: int = 4096,
        audio_hidden_dim: int = 2048,
        proj_bias: bool = False,
        attention_kernel: str = "flash",
        mesh: jax.sharding.Mesh = None,
        rngs: nnx.Rngs = None,
    ):
        # Width of the flattened per-token text embedding as produced upstream.
        text_encoder_dim = caption_channels * text_proj_in_factor

        self.per_modality_projections = per_modality_projections
        self.caption_channels = caption_channels
        self.video_hidden_dim = video_hidden_dim
        self.audio_hidden_dim = audio_hidden_dim

        # One input projection per modality (LTX-2.3), or a single shared
        # projection down to `caption_channels` otherwise.
        if per_modality_projections:
            self.video_text_proj_in = nnx.Linear(
                in_features=text_encoder_dim,
                out_features=video_hidden_dim,
                use_bias=proj_bias,
                rngs=rngs,
            )
            self.audio_text_proj_in = nnx.Linear(
                in_features=text_encoder_dim,
                out_features=audio_hidden_dim,
                use_bias=proj_bias,
                rngs=rngs,
            )
        else:
            self.text_proj_in = nnx.Linear(
                in_features=text_encoder_dim,
                out_features=caption_channels,
                use_bias=proj_bias,
                rngs=rngs,
            )

        # Video and audio connectors share every RoPE/kernel setting and
        # differ only in their per-modality head/layer/register hyperparams.
        self.video_connector = Embeddings1DConnector(
            input_dim=video_hidden_dim if per_modality_projections else caption_channels,
            heads=video_connector_num_attention_heads,
            head_dim=video_connector_attention_head_dim,
            layers=video_connector_num_layers,
            theta=rope_theta,
            num_learnable_registers=video_connector_num_learnable_registers,
            rope_type=rope_type,
            base_seq_len=connector_rope_base_seq_len,
            double_precision=rope_double_precision,
            attention_kernel=attention_kernel,
            mesh=mesh,
            rngs=rngs,
            gated_attn=video_gated_attn,
        )
        self.audio_connector = Embeddings1DConnector(
            input_dim=audio_hidden_dim if per_modality_projections else caption_channels,
            heads=audio_connector_num_attention_heads,
            head_dim=audio_connector_attention_head_dim,
            layers=audio_connector_num_layers,
            theta=rope_theta,
            num_learnable_registers=audio_connector_num_learnable_registers,
            rope_type=rope_type,
            base_seq_len=connector_rope_base_seq_len,
            double_precision=rope_double_precision,
            attention_kernel=attention_kernel,
            mesh=mesh,
            rngs=rngs,
            gated_attn=audio_gated_attn,
        )

    def __call__(self, text_encoder_hidden_states: Array, attention_mask: Array) -> Tuple[Array, Array, Array]:
        """Produce video and audio text conditioning embeddings.

        Args:
            text_encoder_hidden_states: either ``(b, l, caption_channels * factor)``
                or already unpacked as ``(b, l, caption_channels, factor)``.
            attention_mask: per-token mask; entries > 0.5 are treated as valid.

        Returns:
            ``(video_text_embedding, audio_text_embedding, video_attn_mask)``.

        Raises:
            NotImplementedError: when ``per_modality_projections`` is False.
        """
        # Unpack a flat last dimension into (caption_channels, factor) if needed.
        if text_encoder_hidden_states.ndim == 3:
            batch, seq_len, packed_dim = text_encoder_hidden_states.shape
            factor = packed_dim // self.caption_channels
            text_encoder_hidden_states = text_encoder_hidden_states.reshape(
                batch, seq_len, self.caption_channels, factor
            )
        else:
            batch, seq_len = text_encoder_hidden_states.shape[:2]

        if not self.per_modality_projections:
            raise NotImplementedError("LTX-2.0 path in LTX2TextConnectors not fully implemented yet.")

        # LTX-2.3: per-token RMS norm over the channel axis.
        mean_sq = jnp.mean(text_encoder_hidden_states**2, axis=2, keepdims=True)
        normed = text_encoder_hidden_states * jax.lax.rsqrt(mean_sq + 1e-6)
        normed = normed.reshape(batch, seq_len, -1)

        # Zero out padded tokens before projecting.
        valid = (attention_mask > 0.5).astype(jnp.float32)[..., None]
        normed = normed * valid

        # Rescale so the norm matches each modality's hidden width, then project.
        video_emb_proj = self.video_text_proj_in(
            normed * jnp.sqrt(self.video_hidden_dim / self.caption_channels)
        )
        audio_emb_proj = self.audio_text_proj_in(
            normed * jnp.sqrt(self.audio_hidden_dim / self.caption_channels)
        )

        video_text_embedding, video_attn_mask = self.video_connector(video_emb_proj, attention_mask)
        audio_text_embedding, _ = self.audio_connector(audio_emb_proj, attention_mask)

        return video_text_embedding, audio_text_embedding, video_attn_mask
0 commit comments