Skip to content

Commit 229d4f4

Browse files
committed
ltx2.3 connectors loading
1 parent e48e104 commit 229d4f4

2 files changed

Lines changed: 10 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
limitations under the License.
1515
"""
1616

17-
from typing import Tuple, Union, List
17+
from typing import Optional, Tuple, Union, List
1818
import jax
1919
import jax.numpy as jnp
2020
from flax import nnx
@@ -39,6 +39,8 @@ class LTX2AudioVideoGemmaTextEncoder(nnx.Module, FlaxModelMixin, ConfigMixin):
3939
def __init__(
4040
self,
4141
caption_channels: int = 3840,
42+
video_caption_channels: Optional[int] = None,
43+
audio_caption_channels: Optional[int] = None,
4244
text_proj_in_factor: int = 49,
4345
video_connector_attention_head_dim: int = 128,
4446
video_connector_num_attention_heads: int = 30,
@@ -65,6 +67,9 @@ def __init__(
6567
):
6668
input_dim = caption_channels * text_proj_in_factor
6769

70+
v_dim = video_caption_channels if video_caption_channels is not None else caption_channels
71+
a_dim = audio_caption_channels if audio_caption_channels is not None else caption_channels
72+
6873
self.per_modality_projections = per_modality_projections
6974

7075
self.feature_extractor = LTX2GemmaFeatureExtractor(
@@ -78,7 +83,7 @@ def __init__(
7883

7984
# Two independent connectors
8085
self.video_embeddings_connector = Embeddings1DConnector(
81-
input_dim=caption_channels,
86+
input_dim=v_dim,
8287
heads=video_connector_num_attention_heads,
8388
head_dim=video_connector_attention_head_dim,
8489
layers=video_connector_num_layers,
@@ -94,7 +99,7 @@ def __init__(
9499
)
95100

96101
self.audio_embeddings_connector = Embeddings1DConnector(
97-
input_dim=caption_channels,
102+
input_dim=a_dim,
98103
heads=audio_connector_num_attention_heads,
99104
head_dim=audio_connector_attention_head_dim,
100105
layers=audio_connector_num_layers,

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -329,6 +329,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
329329
"video_connector_num_layers": 8,
330330
"audio_connector_num_layers": 8,
331331
"caption_channels": 2048,
332+
"video_caption_channels": 4096,
333+
"audio_caption_channels": 2048,
332334
"video_connector_num_attention_heads": 32,
333335
"audio_connector_num_attention_heads": 32,
334336
"video_connector_attention_head_dim": 64,

0 commit comments

Comments
 (0)