Skip to content

Commit a6fdc44

Browse files
committed
removed dead code from embeddings_connector
1 parent 73191ca commit a6fdc44

1 file changed

Lines changed: 0 additions & 114 deletions

File tree

src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py

Lines changed: 0 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -262,117 +262,3 @@ def block_scan_fn(carry, block_module):
262262

263263
return hidden_states, attention_mask
264264

265-
266-
class LTX2TextConnectors(nnx.Module):
  """Text-embedding connectors for the LTX-2 video/audio diffusion model.

  Projects raw text-encoder hidden states into the hidden dimensions expected
  by the video and audio branches, then refines each projection with a
  modality-specific ``Embeddings1DConnector``.

  Two configurations exist:
    * ``per_modality_projections=True`` (LTX-2.3): per-token RMS-norm of the
      text embeddings followed by separate video/audio input projections.
    * ``per_modality_projections=False`` (LTX-2.0): a single shared
      ``text_proj_in`` projection — this path is constructed in ``__init__``
      but the corresponding ``__call__`` branch is not implemented yet.
  """

  def __init__(
      self,
      caption_channels: int = 3840,
      text_proj_in_factor: int = 49,
      video_connector_num_attention_heads: int = 30,
      video_connector_attention_head_dim: int = 128,
      video_connector_num_layers: int = 2,
      video_connector_num_learnable_registers: int = 128,
      video_gated_attn: bool = False,
      audio_connector_num_attention_heads: int = 30,
      audio_connector_attention_head_dim: int = 128,
      audio_connector_num_layers: int = 2,
      audio_connector_num_learnable_registers: int = 128,
      audio_gated_attn: bool = False,
      connector_rope_base_seq_len: int = 4096,
      rope_theta: float = 10000.0,
      rope_double_precision: bool = True,
      rope_type: str = "interleaved",
      per_modality_projections: bool = False,
      video_hidden_dim: int = 4096,
      audio_hidden_dim: int = 2048,
      proj_bias: bool = False,
      attention_kernel: str = "flash",
      mesh: jax.sharding.Mesh = None,
      rngs: nnx.Rngs = None,
  ):
    """Build the input projection(s) and the two modality connectors.

    Args:
      caption_channels: Channel width of a single text-encoder feature group.
      text_proj_in_factor: Number of stacked feature groups per token; the
        flattened text-encoder width is ``caption_channels * text_proj_in_factor``.
      video_connector_* / audio_connector_*: Attention geometry, depth, and
        learnable-register count for each ``Embeddings1DConnector``.
      video_gated_attn / audio_gated_attn: Enable gated attention in the
        respective connector.
      connector_rope_base_seq_len, rope_theta, rope_double_precision,
        rope_type: Rotary-embedding configuration forwarded to the connectors.
      per_modality_projections: Select the LTX-2.3 (True) or LTX-2.0 (False)
        projection layout; see class docstring.
      video_hidden_dim / audio_hidden_dim: Output widths of the per-modality
        projections (LTX-2.3 path).
      proj_bias: Whether the input projection Linear layers use a bias.
      attention_kernel: Attention implementation name (e.g. "flash") passed
        through to the connectors.
      mesh: JAX device mesh forwarded to the connectors for sharding.
      rngs: flax NNX random-number streams used to initialize parameters.
    """
    # Width of the flattened text-encoder output feeding the projections.
    text_encoder_dim = caption_channels * text_proj_in_factor
    self.per_modality_projections = per_modality_projections
    self.caption_channels = caption_channels
    self.video_hidden_dim = video_hidden_dim
    self.audio_hidden_dim = audio_hidden_dim

    if per_modality_projections:
      # LTX-2.3: independent input projections per modality.
      self.video_text_proj_in = nnx.Linear(
          in_features=text_encoder_dim, out_features=video_hidden_dim, use_bias=proj_bias, rngs=rngs
      )
      self.audio_text_proj_in = nnx.Linear(
          in_features=text_encoder_dim, out_features=audio_hidden_dim, use_bias=proj_bias, rngs=rngs
      )
    else:
      # LTX-2.0: one shared projection down to caption_channels.
      self.text_proj_in = nnx.Linear(
          in_features=text_encoder_dim, out_features=caption_channels, use_bias=proj_bias, rngs=rngs
      )

    self.video_connector = Embeddings1DConnector(
        input_dim=video_hidden_dim if per_modality_projections else caption_channels,
        heads=video_connector_num_attention_heads,
        head_dim=video_connector_attention_head_dim,
        layers=video_connector_num_layers,
        theta=rope_theta,
        num_learnable_registers=video_connector_num_learnable_registers,
        rope_type=rope_type,
        base_seq_len=connector_rope_base_seq_len,
        double_precision=rope_double_precision,
        attention_kernel=attention_kernel,
        mesh=mesh,
        rngs=rngs,
        gated_attn=video_gated_attn,
    )

    self.audio_connector = Embeddings1DConnector(
        input_dim=audio_hidden_dim if per_modality_projections else caption_channels,
        heads=audio_connector_num_attention_heads,
        head_dim=audio_connector_attention_head_dim,
        layers=audio_connector_num_layers,
        theta=rope_theta,
        num_learnable_registers=audio_connector_num_learnable_registers,
        rope_type=rope_type,
        base_seq_len=connector_rope_base_seq_len,
        double_precision=rope_double_precision,
        attention_kernel=attention_kernel,
        mesh=mesh,
        rngs=rngs,
        gated_attn=audio_gated_attn,
    )

  def __call__(self, text_encoder_hidden_states: Array, attention_mask: Array) -> Tuple[Array, Array, Array]:
    """Map text-encoder states to video and audio conditioning embeddings.

    Args:
      text_encoder_hidden_states: Text features, either rank-3
        ``(batch, seq_len, caption_channels * factor)`` — reshaped here to
        rank 4 — or already rank-4 ``(batch, seq_len, caption_channels, factor)``.
        (Shape semantics inferred from the reshape below — confirm with callers.)
      attention_mask: Per-token mask; values > 0.5 are treated as valid tokens.

    Returns:
      Tuple of (video_text_embedding, audio_text_embedding, video_attn_mask),
      where the embeddings come from the respective connectors and the mask is
      the one produced by the video connector (the audio connector's mask is
      discarded).

    Raises:
      NotImplementedError: When ``per_modality_projections`` is False
        (the LTX-2.0 path is not implemented yet).
    """
    if text_encoder_hidden_states.ndim == 3:
      b, l, d = text_encoder_hidden_states.shape
      # Split the flat feature axis into (caption_channels, factor) groups.
      text_proj_in_factor = d // self.caption_channels
      text_encoder_hidden_states = text_encoder_hidden_states.reshape(b, l, self.caption_channels, text_proj_in_factor)
    else:
      b, l, _, _ = text_encoder_hidden_states.shape

    if self.per_modality_projections:
      # LTX-2.3
      # per_token_rms_norm over the caption_channels axis (axis=2).
      variance = jnp.mean(text_encoder_hidden_states**2, axis=2, keepdims=True)
      norm_text_encoder_hidden_states = text_encoder_hidden_states * jax.lax.rsqrt(variance + 1e-6)

      # Flatten back to (batch, seq_len, caption_channels * factor).
      norm_text_encoder_hidden_states = norm_text_encoder_hidden_states.reshape(b, l, -1)

      # Zero out padded tokens before projecting.
      bool_mask = (attention_mask > 0.5).astype(jnp.float32)[..., None]
      norm_text_encoder_hidden_states = norm_text_encoder_hidden_states * bool_mask

      # Rescale norms so each modality's input matches its hidden width.
      video_scale_factor = jnp.sqrt(self.video_hidden_dim / self.caption_channels)
      video_norm_text_emb = norm_text_encoder_hidden_states * video_scale_factor
      audio_scale_factor = jnp.sqrt(self.audio_hidden_dim / self.caption_channels)
      audio_norm_text_emb = norm_text_encoder_hidden_states * audio_scale_factor

      video_text_emb_proj = self.video_text_proj_in(video_norm_text_emb)
      audio_text_emb_proj = self.audio_text_proj_in(audio_norm_text_emb)
    else:
      raise NotImplementedError("LTX-2.0 path in LTX2TextConnectors not fully implemented yet.")

    video_text_embedding, video_attn_mask = self.video_connector(video_text_emb_proj, attention_mask)
    audio_text_embedding, _ = self.audio_connector(audio_text_emb_proj, attention_mask)

    return video_text_embedding, audio_text_embedding, video_attn_mask

0 commit comments

Comments
 (0)