Skip to content

Commit a41a80d

Browse files
committed
Refactor LTX2 text encoders: replace Video/AV classes with unified EmbeddingsProcessor; move tests to tests/ltx2/
Signed-off-by: James Huang <syhuang1201@gmail.com>
1 parent 02dbc99 commit a41a80d

4 files changed

Lines changed: 37 additions & 104 deletions

File tree

src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py

Lines changed: 27 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -27,85 +27,29 @@
2727
DType = common_types.DType
2828

2929

30-
class LTX2VideoGemmaTextEncoder(nnx.Module):
30+
class LTX2EmbeddingsProcessor(nnx.Module):
3131
"""
32-
Encoder for Video-only tasks.
33-
Pipeline: Gemma Hidden States -> Feature Extractor -> Video Connector -> Output
34-
"""
35-
36-
def __init__(
37-
self,
38-
# Feature Extractor Config
39-
gemma_dim: int = 3840, # Gemma-3-12b
40-
gemma_layers: int = 49, # Gemma-3 has 48 layers + 1 embedding layer output = 49 hidden states
41-
projection_dim: int = 3840, # LTX-2 conditioning dim
42-
# Connector Config
43-
connector_heads: int = 32,
44-
connector_head_dim: int = 128,
45-
connector_layers: int = 2,
46-
num_thinking_tokens: int = 128,
47-
dtype: DType = jnp.float32,
48-
attention_kernel: str = "flash",
49-
mesh: jax.sharding.Mesh = None,
50-
rngs: nnx.Rngs = None,
51-
):
52-
input_dim = gemma_dim * gemma_layers
53-
54-
self.feature_extractor = LTX2GemmaFeatureExtractor(
55-
input_dim=input_dim,
56-
output_dim=projection_dim,
57-
dtype=dtype,
58-
rngs=rngs,
59-
)
60-
61-
self.embeddings_connector = Embeddings1DConnector(
62-
input_dim=projection_dim,
63-
heads=connector_heads,
64-
head_dim=connector_head_dim,
65-
layers=connector_layers,
66-
num_learnable_registers=num_thinking_tokens,
67-
rope_type="interleaved",
68-
attention_kernel=attention_kernel,
69-
mesh=mesh,
70-
rngs=rngs,
71-
)
72-
73-
def __call__(
74-
self,
75-
hidden_states: Union[Tuple[Array, ...], List[Array]],
76-
attention_mask: Array,
77-
) -> Array:
78-
"""
79-
Args:
80-
hidden_states: From Gemma output.hidden_states (Tuple of [B, T, D])
81-
attention_mask: [B, T]
82-
"""
83-
# 1. Feature Extraction (Stack -> Norm -> Project)
84-
features = self.feature_extractor(hidden_states, attention_mask)
32+
Wraps feature extractor + video connector + audio connector.
33+
Mirrors diffusers LTX2TextConnectors.
8534
86-
# 2. Connection (Refine + Thinking Tokens)
87-
video_embeds = self.embeddings_connector(features, attention_mask)
88-
89-
return video_embeds
90-
91-
92-
class LTX2AudioVideoGemmaTextEncoder(nnx.Module):
93-
"""
94-
Encoder for Audio-Video tasks.
9535
Pipeline: Gemma Hidden States -> Feature Extractor -> [Video Connector, Audio Connector]
9636
"""
9737

9838
def __init__(
9939
self,
100-
# Feature Extractor Config (Shared)
40+
# Feature Extractor Config
10141
gemma_dim: int = 3840, # Gemma-3-12b
10242
gemma_layers: int = 49, # Gemma-3 has 48 layers + 1 embedding layer output = 49 hidden states
103-
projection_dim: int = 3840,
104-
# Connector Config
43+
projection_dim: int = 3840, # LTX-2 conditioning dim
44+
# Video Connector Config
10545
connector_heads: int = 30,
10646
connector_head_dim: int = 128,
10747
connector_layers: int = 2,
10848
num_thinking_tokens: int = 128,
49+
# Audio Connector Config (independent defaults that equal the video defaults; they do NOT follow overridden connector_* values)
50+
audio_connector_heads: int = 30,
51+
audio_connector_head_dim: int = 128,
52+
audio_connector_layers: int = 2,
10953
dtype: DType = jnp.float32,
11054
attention_kernel: str = "flash",
11155
mesh: jax.sharding.Mesh = None,
@@ -120,8 +64,8 @@ def __init__(
12064
rngs=rngs,
12165
)
12266

123-
# Two independent connectors
124-
self.video_embeddings_connector = Embeddings1DConnector(
67+
# Video connector
68+
self.video_connector = Embeddings1DConnector(
12569
input_dim=projection_dim,
12670
heads=connector_heads,
12771
head_dim=connector_head_dim,
@@ -133,11 +77,12 @@ def __init__(
13377
rngs=rngs,
13478
)
13579

136-
self.audio_embeddings_connector = Embeddings1DConnector(
80+
# Audio connector
81+
self.audio_connector = Embeddings1DConnector(
13782
input_dim=projection_dim,
138-
heads=connector_heads,
139-
head_dim=connector_head_dim,
140-
layers=connector_layers,
83+
heads=audio_connector_heads,
84+
head_dim=audio_connector_head_dim,
85+
layers=audio_connector_layers,
14186
num_learnable_registers=num_thinking_tokens,
14287
rope_type="interleaved",
14388
attention_kernel=attention_kernel,
@@ -151,14 +96,20 @@ def __call__(
15196
attention_mask: Array,
15297
) -> Tuple[Array, Array]:
15398
"""
99+
Args:
100+
hidden_states: From Gemma output.hidden_states (Tuple of [B, T, D])
101+
attention_mask: [B, T]
102+
154103
Returns:
155104
(video_embeds, audio_embeds)
156105
"""
157-
# 1. Shared Feature Extraction
106+
# 1. Feature Extraction (Stack -> Norm -> Project)
158107
features = self.feature_extractor(hidden_states, attention_mask)
159108

160-
# 2. Parallel Connection
161-
video_embeds = self.video_embeddings_connector(features, attention_mask)
162-
audio_embeds = self.audio_embeddings_connector(features, attention_mask)
109+
# 2. Video Connector
110+
video_embeds = self.video_connector(features, attention_mask)
111+
112+
# 3. Audio Connector
113+
audio_embeds = self.audio_connector(features, attention_mask)
163114

164115
return video_embeds, audio_embeds

src/maxdiffusion/tests/test_embeddings_connector_ltx2.py renamed to src/maxdiffusion/tests/ltx2/test_embeddings_connector_ltx2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import jax.numpy as jnp
1919
import numpy as np
2020
from flax import nnx
21-
from ..models.ltx2.text_encoders.embeddings_connector_ltx2 import Embeddings1DConnector
21+
from maxdiffusion.models.ltx2.text_encoders.embeddings_connector_ltx2 import Embeddings1DConnector
2222

2323

2424
class Embeddings1DConnectorTest(unittest.TestCase):

src/maxdiffusion/tests/test_feature_extractor_ltx2.py renamed to src/maxdiffusion/tests/ltx2/test_feature_extractor_ltx2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import jax.numpy as jnp
2121
from flax import nnx
2222

23-
from ..models.ltx2.text_encoders.feature_extractor_ltx2 import LTX2GemmaFeatureExtractor, _norm_and_concat_padded_batch
23+
from maxdiffusion.models.ltx2.text_encoders.feature_extractor_ltx2 import LTX2GemmaFeatureExtractor, _norm_and_concat_padded_batch
2424

2525

2626
# ==========================================

src/maxdiffusion/tests/test_text_encoders_ltx2.py renamed to src/maxdiffusion/tests/ltx2/test_text_encoders_ltx2.py

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import jax.numpy as jnp
1919
import numpy as np
2020
from flax import nnx
21-
from ..models.ltx2.text_encoders.text_encoders_ltx2 import LTX2VideoGemmaTextEncoder, LTX2AudioVideoGemmaTextEncoder
21+
from maxdiffusion.models.ltx2.text_encoders.text_encoders_ltx2 import LTX2EmbeddingsProcessor
2222

2323

2424
class LTX2TextEncodersTest(unittest.TestCase):
@@ -36,53 +36,35 @@ def setUp(self):
3636

3737
self.attention_mask = jnp.ones((self.B, self.T))
3838

39-
def test_video_encoder_forward(self):
40-
encoder = LTX2VideoGemmaTextEncoder(
39+
def test_embeddings_processor_forward(self):
40+
processor = LTX2EmbeddingsProcessor(
4141
gemma_dim=self.gemma_dim,
4242
gemma_layers=self.gemma_layers,
4343
projection_dim=self.proj_dim,
4444
connector_heads=4,
4545
connector_head_dim=16,
4646
connector_layers=1,
4747
num_thinking_tokens=8,
48+
audio_connector_heads=4,
49+
audio_connector_head_dim=16,
50+
audio_connector_layers=1,
4851
attention_kernel="dot_product",
4952
mesh=None,
5053
rngs=self.rng,
5154
)
5255

53-
output = encoder(tuple(self.hidden_states), self.attention_mask)
54-
55-
# Expected shape: [B, T, proj_dim]
56-
self.assertEqual(output.shape, (self.B, self.T, self.proj_dim))
57-
print("\n[PASS] Video Encoder Forward Pass Verified.")
58-
59-
def test_av_encoder_forward(self):
60-
encoder = LTX2AudioVideoGemmaTextEncoder(
61-
gemma_dim=self.gemma_dim,
62-
gemma_layers=self.gemma_layers,
63-
projection_dim=self.proj_dim,
64-
connector_heads=4,
65-
connector_head_dim=16,
66-
connector_layers=1,
67-
num_thinking_tokens=8,
68-
attention_kernel="dot_product",
69-
mesh=None,
70-
rngs=self.rng,
71-
)
72-
73-
video_out, audio_out = encoder(tuple(self.hidden_states), self.attention_mask)
56+
video_out, audio_out = processor(tuple(self.hidden_states), self.attention_mask)
7457

7558
# Expected shapes: Both [B, T, proj_dim]
7659
self.assertEqual(video_out.shape, (self.B, self.T, self.proj_dim))
7760
self.assertEqual(audio_out.shape, (self.B, self.T, self.proj_dim))
7861

7962
# Ensure they are different (different random init for connectors)
80-
# Note: In reality they are initialized differently, so outputs should differ
8163
self.assertFalse(
8264
jnp.allclose(video_out, audio_out), "Video and Audio outputs should differ due to different connector weights"
8365
)
8466

85-
print("\n[PASS] Audio-Video Encoder Forward Pass Verified.")
67+
print("\n[PASS] Embeddings Processor Forward Pass Verified.")
8668

8769

8870
if __name__ == "__main__":

0 commit comments

Comments
 (0)