Skip to content

Commit b7831ee

Browse files
author
James
committed
[Text Pipeline] Implement Text Encoders Wrappers
Signed-off-by: James <shyhuanh@google.com>
1 parent 4612225 commit b7831ee

2 files changed

Lines changed: 242 additions & 0 deletions

File tree

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from typing import Optional, Tuple, Union, List
18+
import jax
19+
import jax.numpy as jnp
20+
from flax import nnx
21+
from maxdiffusion import common_types
22+
23+
from .feature_extractor_ltx2 import LTX2GemmaFeatureExtractor
24+
from .embeddings_connector_ltx2 import Embeddings1DConnector
25+
26+
Array = common_types.Array
27+
DType = common_types.DType
28+
29+
class LTX2VideoGemmaTextEncoder(nnx.Module):
30+
"""
31+
Encoder for Video-only tasks.
32+
Pipeline: Gemma Hidden States -> Feature Extractor -> Video Connector -> Output
33+
"""
34+
def __init__(
35+
self,
36+
# Feature Extractor Config
37+
gemma_dim: int = 3840, # Gemma-3-12b
38+
gemma_layers: int = 49, # Gemma-3 has 48 layers + 1 embedding layer output = 49 hidden states
39+
projection_dim: int = 4096, # LTX-2 conditioning dim
40+
# Connector Config
41+
connector_heads: int = 32,
42+
connector_head_dim: int = 128,
43+
connector_layers: int = 2,
44+
num_thinking_tokens: int = 128,
45+
dtype: DType = jnp.float32,
46+
attention_kernel: str = "flash",
47+
rngs: nnx.Rngs = None,
48+
):
49+
input_dim = gemma_dim * gemma_layers
50+
51+
self.feature_extractor = LTX2GemmaFeatureExtractor(
52+
input_dim=input_dim,
53+
output_dim=projection_dim,
54+
dtype=dtype,
55+
rngs=rngs,
56+
)
57+
58+
self.embeddings_connector = Embeddings1DConnector(
59+
input_dim=projection_dim,
60+
heads=connector_heads,
61+
head_dim=connector_head_dim,
62+
layers=connector_layers,
63+
num_learnable_registers=num_thinking_tokens,
64+
rope_type="interleaved",
65+
attention_kernel=attention_kernel,
66+
rngs=rngs,
67+
)
68+
69+
def __call__(
70+
self,
71+
hidden_states: Union[Tuple[Array, ...], List[Array]],
72+
attention_mask: Array,
73+
) -> Array:
74+
"""
75+
Args:
76+
hidden_states: From Gemma output.hidden_states (Tuple of [B, T, D])
77+
attention_mask: [B, T]
78+
"""
79+
# 1. Feature Extraction (Stack -> Norm -> Project)
80+
features = self.feature_extractor(hidden_states, attention_mask)
81+
82+
# 2. Connection (Refine + Thinking Tokens)
83+
video_embeds = self.embeddings_connector(features, attention_mask)
84+
85+
return video_embeds
86+
87+
88+
class LTX2AudioVideoGemmaTextEncoder(nnx.Module):
89+
"""
90+
Encoder for Audio-Video tasks.
91+
Pipeline: Gemma Hidden States -> Feature Extractor -> [Video Connector, Audio Connector]
92+
"""
93+
def __init__(
94+
self,
95+
# Feature Extractor Config (Shared)
96+
gemma_dim: int = 3840, # Gemma-3-12b
97+
gemma_layers: int = 49, # Gemma-3 has 48 layers + 1 embedding layer output = 49 hidden states
98+
projection_dim: int = 4096,
99+
# Connector Config
100+
connector_heads: int = 32,
101+
connector_head_dim: int = 128,
102+
connector_layers: int = 2,
103+
num_thinking_tokens: int = 128,
104+
dtype: DType = jnp.float32,
105+
attention_kernel: str = "flash",
106+
rngs: nnx.Rngs = None,
107+
):
108+
input_dim = gemma_dim * gemma_layers
109+
110+
self.feature_extractor = LTX2GemmaFeatureExtractor(
111+
input_dim=input_dim,
112+
output_dim=projection_dim,
113+
dtype=dtype,
114+
rngs=rngs,
115+
)
116+
117+
# Two independent connectors
118+
self.video_embeddings_connector = Embeddings1DConnector(
119+
input_dim=projection_dim,
120+
heads=connector_heads,
121+
head_dim=connector_head_dim,
122+
layers=connector_layers,
123+
num_learnable_registers=num_thinking_tokens,
124+
rope_type="interleaved",
125+
attention_kernel=attention_kernel,
126+
rngs=rngs,
127+
)
128+
129+
self.audio_embeddings_connector = Embeddings1DConnector(
130+
input_dim=projection_dim,
131+
heads=connector_heads,
132+
head_dim=connector_head_dim,
133+
layers=connector_layers,
134+
num_learnable_registers=num_thinking_tokens,
135+
rope_type="interleaved",
136+
attention_kernel=attention_kernel,
137+
rngs=rngs,
138+
)
139+
140+
def __call__(
141+
self,
142+
hidden_states: Union[Tuple[Array, ...], List[Array]],
143+
attention_mask: Array,
144+
) -> Tuple[Array, Array]:
145+
"""
146+
Returns:
147+
(video_embeds, audio_embeds)
148+
"""
149+
# 1. Shared Feature Extraction
150+
features = self.feature_extractor(hidden_states, attention_mask)
151+
152+
# 2. Parallel Connection
153+
video_embeds = self.video_embeddings_connector(features, attention_mask)
154+
audio_embeds = self.audio_embeddings_connector(features, attention_mask)
155+
156+
return video_embeds, audio_embeds
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import unittest
18+
import jax
19+
import jax.numpy as jnp
20+
import numpy as np
21+
from flax import nnx
22+
from ..models.ltx2.text_encoders.text_encoders_ltx2 import LTX2VideoGemmaTextEncoder, LTX2AudioVideoGemmaTextEncoder
23+
24+
class LTX2TextEncodersTest(unittest.TestCase):
25+
def setUp(self):
26+
self.rng = nnx.Rngs(0)
27+
self.B = 2
28+
self.T = 16
29+
self.gemma_dim = 32
30+
self.gemma_layers = 3
31+
self.proj_dim = 64
32+
33+
# Mock Gemma hidden states
34+
self.hidden_states = [
35+
jnp.array(np.random.randn(self.B, self.T, self.gemma_dim))
36+
for _ in range(self.gemma_layers)
37+
]
38+
39+
self.attention_mask = jnp.ones((self.B, self.T))
40+
41+
def test_video_encoder_forward(self):
42+
encoder = LTX2VideoGemmaTextEncoder(
43+
gemma_dim=self.gemma_dim,
44+
gemma_layers=self.gemma_layers,
45+
projection_dim=self.proj_dim,
46+
connector_heads=4,
47+
connector_head_dim=16,
48+
connector_layers=1,
49+
num_thinking_tokens=8,
50+
attention_kernel="dot_product",
51+
rngs=self.rng
52+
)
53+
54+
output = encoder(tuple(self.hidden_states), self.attention_mask)
55+
56+
# Expected shape: [B, T, proj_dim]
57+
self.assertEqual(output.shape, (self.B, self.T, self.proj_dim))
58+
print("\n[PASS] Video Encoder Forward Pass Verified.")
59+
60+
def test_av_encoder_forward(self):
61+
encoder = LTX2AudioVideoGemmaTextEncoder(
62+
gemma_dim=self.gemma_dim,
63+
gemma_layers=self.gemma_layers,
64+
projection_dim=self.proj_dim,
65+
connector_heads=4,
66+
connector_head_dim=16,
67+
connector_layers=1,
68+
num_thinking_tokens=8,
69+
attention_kernel="dot_product",
70+
rngs=self.rng
71+
)
72+
73+
video_out, audio_out = encoder(tuple(self.hidden_states), self.attention_mask)
74+
75+
# Expected shapes: Both [B, T, proj_dim]
76+
self.assertEqual(video_out.shape, (self.B, self.T, self.proj_dim))
77+
self.assertEqual(audio_out.shape, (self.B, self.T, self.proj_dim))
78+
79+
# Ensure they are different (different random init for connectors)
80+
# Note: In reality they are initialized differently, so outputs should differ
81+
self.assertFalse(jnp.allclose(video_out, audio_out), "Video and Audio outputs should differ due to different connector weights")
82+
83+
print("\n[PASS] Audio-Video Encoder Forward Pass Verified.")
84+
85+
if __name__ == "__main__":
86+
unittest.main()

0 commit comments

Comments
 (0)