Skip to content

Commit 85df9d2

Browse files
committed
Changes in the text encoder and feature extractor files
1 parent 9b5ed92 commit 85df9d2

2 files changed

Lines changed: 153 additions & 49 deletions

File tree

src/maxdiffusion/models/ltx2/text_encoders/feature_extractor_ltx2.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,25 @@ def __init__(
102102
output_dim: int,
103103
dtype: DType = jnp.float32,
104104
rngs: nnx.Rngs = None,
105+
per_modality_projections: bool = False,
106+
use_bias: bool = False,
107+
video_output_dim: Optional[int] = None,
108+
audio_output_dim: Optional[int] = None,
105109
):
106110
"""
107111
Args:
108112
input_dim: Dimension of flattened hidden states (Gemma dim * Num layers).
109-
output_dim: Target dimension for diffusion conditioning.
113+
output_dim: Target dimension for diffusion conditioning (fallback).
110114
"""
111-
# LTX-2 uses bias=False for the projection
112-
self.linear = nnx.Linear(input_dim, output_dim, use_bias=False, dtype=dtype, rngs=rngs)
115+
self.per_modality_projections = per_modality_projections
116+
117+
if per_modality_projections:
118+
v_dim = video_output_dim if video_output_dim is not None else output_dim
119+
a_dim = audio_output_dim if audio_output_dim is not None else output_dim
120+
self.video_linear = nnx.Linear(input_dim, v_dim, use_bias=use_bias, dtype=dtype, rngs=rngs)
121+
self.audio_linear = nnx.Linear(input_dim, a_dim, use_bias=use_bias, dtype=dtype, rngs=rngs)
122+
else:
123+
self.linear = nnx.Linear(input_dim, output_dim, use_bias=use_bias, dtype=dtype, rngs=rngs)
113124

114125
def __call__(self, hidden_states: Union[Tuple[Array, ...], Array], attention_mask: Array) -> Array:
115126
"""
@@ -133,4 +144,7 @@ def __call__(self, hidden_states: Union[Tuple[Array, ...], Array], attention_mas
133144
x_norm = _norm_and_concat_padded_batch(x, attention_mask)
134145

135146
# 4. Projection
136-
return self.linear(x_norm)
147+
if self.per_modality_projections:
148+
return self.video_linear(x_norm), self.audio_linear(x_norm)
149+
else:
150+
return self.linear(x_norm)

src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py

Lines changed: 135 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ class LTX2AudioVideoGemmaTextEncoder(nnx.Module, FlaxModelMixin, ConfigMixin):
3939
def __init__(
4040
self,
4141
caption_channels: int = 3840,
42+
video_caption_channels: Optional[int] = None,
43+
audio_caption_channels: Optional[int] = None,
4244
text_proj_in_factor: int = 49,
4345
video_connector_attention_head_dim: int = 128,
4446
video_connector_num_attention_heads: int = 30,
@@ -57,63 +59,151 @@ def __init__(
5759
attention_kernel: str = "flash",
5860
mesh: jax.sharding.Mesh = None,
5961
rngs: nnx.Rngs = None,
62+
per_modality_projections: bool = False,
63+
proj_bias: bool = False,
64+
video_gated_attn: bool = False,
65+
audio_gated_attn: bool = False,
66+
audio_hidden_dim: Optional[int] = None,
67+
video_hidden_dim: Optional[int] = None,
6068
**kwargs,
6169
):
62-
input_dim = caption_channels * text_proj_in_factor
63-
64-
self.feature_extractor = LTX2GemmaFeatureExtractor(
65-
input_dim=input_dim,
66-
output_dim=caption_channels,
67-
dtype=dtype,
68-
rngs=rngs,
69-
)
70-
71-
# Two independent connectors
72-
self.video_embeddings_connector = Embeddings1DConnector(
73-
input_dim=caption_channels,
74-
heads=video_connector_num_attention_heads,
75-
head_dim=video_connector_attention_head_dim,
76-
layers=video_connector_num_layers,
77-
num_learnable_registers=video_connector_num_learnable_registers,
78-
rope_type=rope_type,
79-
theta=rope_theta,
80-
base_seq_len=connector_rope_base_seq_len,
81-
double_precision=rope_double_precision,
82-
attention_kernel=attention_kernel,
83-
mesh=mesh,
84-
rngs=rngs,
85-
)
86-
87-
self.audio_embeddings_connector = Embeddings1DConnector(
88-
input_dim=caption_channels,
89-
heads=audio_connector_num_attention_heads,
90-
head_dim=audio_connector_attention_head_dim,
91-
layers=audio_connector_num_layers,
92-
num_learnable_registers=audio_connector_num_learnable_registers,
93-
rope_type=rope_type,
94-
theta=rope_theta,
95-
base_seq_len=connector_rope_base_seq_len,
96-
double_precision=rope_double_precision,
97-
attention_kernel=attention_kernel,
98-
mesh=mesh,
99-
rngs=rngs,
100-
)
70+
gemma_dim = 3840 if video_caption_channels is not None else caption_channels
71+
input_dim = gemma_dim * text_proj_in_factor
72+
73+
v_dim = video_hidden_dim if video_hidden_dim is not None else (video_caption_channels if video_caption_channels is not None else caption_channels)
74+
a_dim = audio_hidden_dim if audio_hidden_dim is not None else (audio_caption_channels if audio_caption_channels is not None else caption_channels)
75+
76+
self.per_modality_projections = per_modality_projections
77+
78+
if per_modality_projections:
79+
self.video_text_proj_in = nnx.Linear(
80+
in_features=input_dim, out_features=v_dim, use_bias=proj_bias, rngs=rngs
81+
)
82+
self.audio_text_proj_in = nnx.Linear(
83+
in_features=input_dim, out_features=a_dim, use_bias=proj_bias, rngs=rngs
84+
)
85+
86+
self.video_embeddings_connector = Embeddings1DConnector(
87+
input_dim=v_dim,
88+
heads=video_connector_num_attention_heads,
89+
head_dim=video_connector_attention_head_dim,
90+
layers=video_connector_num_layers,
91+
num_learnable_registers=video_connector_num_learnable_registers,
92+
rope_type=rope_type,
93+
theta=rope_theta,
94+
base_seq_len=connector_rope_base_seq_len,
95+
double_precision=rope_double_precision,
96+
attention_kernel=attention_kernel,
97+
mesh=mesh,
98+
rngs=rngs,
99+
gated_attn=video_gated_attn,
100+
)
101+
self.audio_embeddings_connector = Embeddings1DConnector(
102+
input_dim=a_dim,
103+
heads=audio_connector_num_attention_heads,
104+
head_dim=audio_connector_attention_head_dim,
105+
layers=audio_connector_num_layers,
106+
num_learnable_registers=audio_connector_num_learnable_registers,
107+
rope_type=rope_type,
108+
theta=rope_theta,
109+
base_seq_len=connector_rope_base_seq_len,
110+
double_precision=rope_double_precision,
111+
attention_kernel=attention_kernel,
112+
mesh=mesh,
113+
rngs=rngs,
114+
gated_attn=audio_gated_attn,
115+
)
116+
else:
117+
self.feature_extractor = LTX2GemmaFeatureExtractor(
118+
input_dim=input_dim,
119+
output_dim=caption_channels,
120+
dtype=dtype,
121+
rngs=rngs,
122+
per_modality_projections=per_modality_projections,
123+
use_bias=proj_bias,
124+
video_output_dim=v_dim,
125+
audio_output_dim=a_dim,
126+
)
127+
128+
# Two independent connectors
129+
self.video_embeddings_connector = Embeddings1DConnector(
130+
input_dim=v_dim,
131+
heads=video_connector_num_attention_heads,
132+
head_dim=video_connector_attention_head_dim,
133+
layers=video_connector_num_layers,
134+
num_learnable_registers=video_connector_num_learnable_registers,
135+
rope_type=rope_type,
136+
theta=rope_theta,
137+
base_seq_len=connector_rope_base_seq_len,
138+
double_precision=rope_double_precision,
139+
attention_kernel=attention_kernel,
140+
mesh=mesh,
141+
rngs=rngs,
142+
gated_attn=video_gated_attn,
143+
)
144+
self.audio_embeddings_connector = Embeddings1DConnector(
145+
input_dim=a_dim,
146+
heads=audio_connector_num_attention_heads,
147+
head_dim=audio_connector_attention_head_dim,
148+
layers=audio_connector_num_layers,
149+
num_learnable_registers=audio_connector_num_learnable_registers,
150+
rope_type=rope_type,
151+
theta=rope_theta,
152+
base_seq_len=connector_rope_base_seq_len,
153+
double_precision=rope_double_precision,
154+
attention_kernel=attention_kernel,
155+
mesh=mesh,
156+
rngs=rngs,
157+
gated_attn=audio_gated_attn,
158+
)
101159

102160
def __call__(
103161
self,
104162
hidden_states: Union[Tuple[Array, ...], List[Array]],
105163
attention_mask: Array,
106-
) -> Tuple[Array, Array]:
164+
) -> Tuple[Array, Array, Array]:
107165
"""
108166
Returns:
109167
(video_embeds, audio_embeds, new_attention_mask)
110168
"""
111169
with jax.named_scope("Text Encoder Forward"):
112-
# 1. Shared Feature Extraction
113-
features = self.feature_extractor(hidden_states, attention_mask)
170+
if self.per_modality_projections:
171+
# 1. Stack Hidden States if needed
172+
if isinstance(hidden_states, (tuple, list)):
173+
x = jnp.stack(hidden_states, axis=-1)
174+
else:
175+
x = hidden_states
176+
177+
b, l, d, k = x.shape
178+
179+
# 2. Per-token RMS norm
180+
variance = jnp.mean(x**2, axis=2, keepdims=True)
181+
norm_text_encoder_hidden_states = x * jax.lax.rsqrt(variance + 1e-6)
182+
183+
norm_text_encoder_hidden_states = norm_text_encoder_hidden_states.reshape(b, l, -1)
184+
185+
bool_mask = (attention_mask > 0.5).astype(jnp.float32)[..., None]
186+
norm_text_encoder_hidden_states = norm_text_encoder_hidden_states * bool_mask
187+
188+
# 3. Rescale norms
189+
cap_channels = getattr(self, "caption_channels", getattr(self.config, "caption_channels", 3840))
190+
191+
video_scale_factor = jnp.sqrt(self.video_embeddings_connector.dim / cap_channels)
192+
video_norm_text_emb = norm_text_encoder_hidden_states * video_scale_factor
193+
audio_scale_factor = jnp.sqrt(self.audio_embeddings_connector.dim / cap_channels)
194+
audio_norm_text_emb = norm_text_encoder_hidden_states * audio_scale_factor
195+
196+
video_text_emb_proj = self.video_text_proj_in(video_norm_text_emb)
197+
audio_text_emb_proj = self.audio_text_proj_in(audio_norm_text_emb)
198+
199+
video_embeds, new_attention_mask = self.video_embeddings_connector(video_text_emb_proj, attention_mask)
200+
audio_embeds, _ = self.audio_embeddings_connector(audio_text_emb_proj, attention_mask)
201+
else:
202+
# 1. Shared Feature Extraction
203+
features = self.feature_extractor(hidden_states, attention_mask)
114204

115-
# 2. Parallel Connection
116-
video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
117-
audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)
205+
# 2. Parallel Connection
206+
video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
207+
audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)
118208

119209
return video_embeds, audio_embeds, new_attention_mask

0 commit comments

Comments (0)