1414limitations under the License.
1515"""
1616
17- from typing import Tuple , Union , List
17+ from typing import Optional , Tuple , Union , List
1818import jax
1919import jax .numpy as jnp
2020from flax import nnx
@@ -39,6 +39,8 @@ class LTX2AudioVideoGemmaTextEncoder(nnx.Module, FlaxModelMixin, ConfigMixin):
3939 def __init__ (
4040 self ,
4141 caption_channels : int = 3840 ,
42+ video_caption_channels : Optional [int ] = None ,
43+ audio_caption_channels : Optional [int ] = None ,
4244 text_proj_in_factor : int = 49 ,
4345 video_connector_attention_head_dim : int = 128 ,
4446 video_connector_num_attention_heads : int = 30 ,
@@ -65,6 +67,9 @@ def __init__(
6567 ):
6668 input_dim = caption_channels * text_proj_in_factor
6769
70+ v_dim = video_caption_channels if video_caption_channels is not None else caption_channels
71+ a_dim = audio_caption_channels if audio_caption_channels is not None else caption_channels
72+
6873 self .per_modality_projections = per_modality_projections
6974
7075 self .feature_extractor = LTX2GemmaFeatureExtractor (
@@ -78,7 +83,7 @@ def __init__(
7883
7984 # Two independent connectors
8085 self .video_embeddings_connector = Embeddings1DConnector (
81- input_dim = caption_channels ,
86+ input_dim = v_dim ,
8287 heads = video_connector_num_attention_heads ,
8388 head_dim = video_connector_attention_head_dim ,
8489 layers = video_connector_num_layers ,
@@ -94,7 +99,7 @@ def __init__(
9499 )
95100
96101 self .audio_embeddings_connector = Embeddings1DConnector (
97- input_dim = caption_channels ,
102+ input_dim = a_dim ,
98103 heads = audio_connector_num_attention_heads ,
99104 head_dim = audio_connector_attention_head_dim ,
100105 layers = audio_connector_num_layers ,
0 commit comments