2727DType = common_types .DType
2828
2929
30- class LTX2VideoGemmaTextEncoder (nnx .Module ):
30+ class LTX2EmbeddingsProcessor (nnx .Module ):
3131 """
32- Encoder for Video-only tasks.
33- Pipeline: Gemma Hidden States -> Feature Extractor -> Video Connector -> Output
34- """
35-
36- def __init__ (
37- self ,
38- # Feature Extractor Config
39- gemma_dim : int = 3840 , # Gemma-3-12b
40- gemma_layers : int = 49 , # Gemma-3 has 48 layers + 1 embedding layer output = 49 hidden states
41- projection_dim : int = 3840 , # LTX-2 conditioning dim
42- # Connector Config
43- connector_heads : int = 32 ,
44- connector_head_dim : int = 128 ,
45- connector_layers : int = 2 ,
46- num_thinking_tokens : int = 128 ,
47- dtype : DType = jnp .float32 ,
48- attention_kernel : str = "flash" ,
49- mesh : jax .sharding .Mesh = None ,
50- rngs : nnx .Rngs = None ,
51- ):
52- input_dim = gemma_dim * gemma_layers
53-
54- self .feature_extractor = LTX2GemmaFeatureExtractor (
55- input_dim = input_dim ,
56- output_dim = projection_dim ,
57- dtype = dtype ,
58- rngs = rngs ,
59- )
60-
61- self .embeddings_connector = Embeddings1DConnector (
62- input_dim = projection_dim ,
63- heads = connector_heads ,
64- head_dim = connector_head_dim ,
65- layers = connector_layers ,
66- num_learnable_registers = num_thinking_tokens ,
67- rope_type = "interleaved" ,
68- attention_kernel = attention_kernel ,
69- mesh = mesh ,
70- rngs = rngs ,
71- )
72-
73- def __call__ (
74- self ,
75- hidden_states : Union [Tuple [Array , ...], List [Array ]],
76- attention_mask : Array ,
77- ) -> Array :
78- """
79- Args:
80- hidden_states: From Gemma output.hidden_states (Tuple of [B, T, D])
81- attention_mask: [B, T]
82- """
83- # 1. Feature Extraction (Stack -> Norm -> Project)
84- features = self .feature_extractor (hidden_states , attention_mask )
32+ Wraps feature extractor + video connector + audio connector.
33+ Mirrors diffusers LTX2TextConnectors.
8534
86- # 2. Connection (Refine + Thinking Tokens)
87- video_embeds = self .embeddings_connector (features , attention_mask )
88-
89- return video_embeds
90-
91-
92- class LTX2AudioVideoGemmaTextEncoder (nnx .Module ):
93- """
94- Encoder for Audio-Video tasks.
9535 Pipeline: Gemma Hidden States -> Feature Extractor -> [Video Connector, Audio Connector]
9636 """
9737
9838 def __init__ (
9939 self ,
100- # Feature Extractor Config (Shared)
40+ # Feature Extractor Config
10141 gemma_dim : int = 3840 , # Gemma-3-12b
10242 gemma_layers : int = 49 , # Gemma-3 has 48 layers + 1 embedding layer output = 49 hidden states
103- projection_dim : int = 3840 ,
104- # Connector Config
43+ projection_dim : int = 3840 , # LTX-2 conditioning dim
44+ # Video Connector Config
10545 connector_heads : int = 30 ,
10646 connector_head_dim : int = 128 ,
10747 connector_layers : int = 2 ,
10848 num_thinking_tokens : int = 128 ,
49+ # Audio Connector Config (default values match the video connector's)
50+ audio_connector_heads : int = 30 ,
51+ audio_connector_head_dim : int = 128 ,
52+ audio_connector_layers : int = 2 ,
10953 dtype : DType = jnp .float32 ,
11054 attention_kernel : str = "flash" ,
11155 mesh : jax .sharding .Mesh = None ,
@@ -120,8 +64,8 @@ def __init__(
12064 rngs = rngs ,
12165 )
12266
123- # Two independent connectors
124- self .video_embeddings_connector = Embeddings1DConnector (
67+ # Video connector
68+ self .video_connector = Embeddings1DConnector (
12569 input_dim = projection_dim ,
12670 heads = connector_heads ,
12771 head_dim = connector_head_dim ,
@@ -133,11 +77,12 @@ def __init__(
13377 rngs = rngs ,
13478 )
13579
136- self .audio_embeddings_connector = Embeddings1DConnector (
80+ # Audio connector
81+ self .audio_connector = Embeddings1DConnector (
13782 input_dim = projection_dim ,
138- heads = connector_heads ,
139- head_dim = connector_head_dim ,
140- layers = connector_layers ,
83+ heads = audio_connector_heads ,
84+ head_dim = audio_connector_head_dim ,
85+ layers = audio_connector_layers ,
14186 num_learnable_registers = num_thinking_tokens ,
14287 rope_type = "interleaved" ,
14388 attention_kernel = attention_kernel ,
@@ -151,14 +96,20 @@ def __call__(
15196 attention_mask : Array ,
15297 ) -> Tuple [Array , Array ]:
15398 """
99+ Args:
100+ hidden_states: From Gemma output.hidden_states (Tuple of [B, T, D])
101+ attention_mask: [B, T]
102+
154103 Returns:
155104 (video_embeds, audio_embeds)
156105 """
157- # 1. Shared Feature Extraction
106+ # 1. Feature Extraction (Stack -> Norm -> Project)
158107 features = self .feature_extractor (hidden_states , attention_mask )
159108
160- # 2. Parallel Connection
161- video_embeds = self .video_embeddings_connector (features , attention_mask )
162- audio_embeds = self .audio_embeddings_connector (features , attention_mask )
109+ # 2. Video Connector
110+ video_embeds = self .video_connector (features , attention_mask )
111+
112+ # 3. Audio Connector
113+ audio_embeds = self .audio_connector (features , attention_mask )
163114
164115 return video_embeds , audio_embeds
0 commit comments