Skip to content

Commit c159b84

Browse files
committed
refactor for ltx2.3
1 parent 3a7021c commit c159b84

3 files changed

Lines changed: 169 additions & 41 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_3_utils.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
"model.diffusion_model.": "",
1818
"connectors.": "",
1919
"transformer_1d_blocks": "stacked_blocks",
20-
"text_embedding_projection.audio_aggregate_embed.weight": "feature_extractor.audio_linear.kernel",
21-
"text_embedding_projection.audio_aggregate_embed.bias": "feature_extractor.audio_linear.bias",
22-
"text_embedding_projection.video_aggregate_embed.weight": "feature_extractor.video_linear.kernel",
23-
"text_embedding_projection.video_aggregate_embed.bias": "feature_extractor.video_linear.bias",
20+
"text_embedding_projection.audio_aggregate_embed.weight": "audio_text_proj_in.kernel",
21+
"text_embedding_projection.audio_aggregate_embed.bias": "audio_text_proj_in.bias",
22+
"text_embedding_projection.video_aggregate_embed.weight": "video_text_proj_in.kernel",
23+
"text_embedding_projection.video_aggregate_embed.bias": "video_text_proj_in.bias",
2424
"q_norm": "norm_q",
2525
"k_norm": "norm_k",
2626
"norm_q.weight": "norm_q.scale",
@@ -35,8 +35,12 @@
3535
"ff.net.2.weight": "ff.net_2.kernel",
3636
"ff.net.2.bias": "ff.net_2.bias",
3737
"to_gate_logits.weight": "to_gate_logits.kernel",
38-
"audio_linear.weight": "audio_linear.kernel",
39-
"video_linear.weight": "video_linear.kernel",
38+
"audio_linear.weight": "audio_text_proj_in.kernel",
39+
"audio_linear.bias": "audio_text_proj_in.bias",
40+
"video_linear.weight": "video_text_proj_in.kernel",
41+
"video_linear.bias": "video_text_proj_in.bias",
42+
"video_embeddings_connector": "video_connector",
43+
"audio_embeddings_connector": "audio_connector",
4044
}
4145

4246
def load_connectors_weights(

src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,118 @@ def block_scan_fn(carry, block_module):
261261
hidden_states = self.final_norm(hidden_states)
262262

263263
return hidden_states, attention_mask
264+
265+
266+
class LTX2TextConnectors(nnx.Module):
  """Projects stacked text-encoder hidden states into per-modality (video/audio)
  embeddings and runs each through its own 1D attention connector.

  The LTX-2.3 path (``per_modality_projections=True``) RMS-normalizes the text
  features per token, rescales them for each modality's hidden width, applies a
  modality-specific linear projection, and feeds the result to a video and an
  audio ``Embeddings1DConnector``. The legacy LTX-2.0 single-projection path is
  declared (``text_proj_in``) but its forward pass raises ``NotImplementedError``.
  """

  def __init__(
      self,
      caption_channels: int = 3840,
      text_proj_in_factor: int = 49,  # number of stacked encoder feature groups per token
      video_connector_num_attention_heads: int = 30,
      video_connector_attention_head_dim: int = 128,
      video_connector_num_layers: int = 2,
      video_connector_num_learnable_registers: int = 128,
      video_gated_attn: bool = False,
      audio_connector_num_attention_heads: int = 30,
      audio_connector_attention_head_dim: int = 128,
      audio_connector_num_layers: int = 2,
      audio_connector_num_learnable_registers: int = 128,
      audio_gated_attn: bool = False,
      connector_rope_base_seq_len: int = 4096,
      rope_theta: float = 10000.0,
      rope_double_precision: bool = True,
      rope_type: str = "interleaved",
      per_modality_projections: bool = False,  # True selects the LTX-2.3 path
      video_hidden_dim: int = 4096,
      audio_hidden_dim: int = 2048,
      proj_bias: bool = False,
      attention_kernel: str = "flash",
      mesh: jax.sharding.Mesh = None,
      rngs: nnx.Rngs = None,
  ):
    # Flattened width of the text-encoder features: channels * stacked factor.
    text_encoder_dim = caption_channels * text_proj_in_factor
    self.per_modality_projections = per_modality_projections
    self.caption_channels = caption_channels
    self.video_hidden_dim = video_hidden_dim
    self.audio_hidden_dim = audio_hidden_dim

    if per_modality_projections:
      # LTX-2.3: separate input projections per modality.
      self.video_text_proj_in = nnx.Linear(
          in_features=text_encoder_dim, out_features=video_hidden_dim, use_bias=proj_bias, rngs=rngs
      )
      self.audio_text_proj_in = nnx.Linear(
          in_features=text_encoder_dim, out_features=audio_hidden_dim, use_bias=proj_bias, rngs=rngs
      )
    else:
      # LTX-2.0: single shared projection back to caption_channels.
      # NOTE(review): this parameter is created but __call__ raises for this path.
      self.text_proj_in = nnx.Linear(
          in_features=text_encoder_dim, out_features=caption_channels, use_bias=proj_bias, rngs=rngs
      )

    self.video_connector = Embeddings1DConnector(
        input_dim=video_hidden_dim if per_modality_projections else caption_channels,
        heads=video_connector_num_attention_heads,
        head_dim=video_connector_attention_head_dim,
        layers=video_connector_num_layers,
        theta=rope_theta,
        num_learnable_registers=video_connector_num_learnable_registers,
        rope_type=rope_type,
        base_seq_len=connector_rope_base_seq_len,
        double_precision=rope_double_precision,
        attention_kernel=attention_kernel,
        mesh=mesh,
        rngs=rngs,
        gated_attn=video_gated_attn,
    )

    self.audio_connector = Embeddings1DConnector(
        input_dim=audio_hidden_dim if per_modality_projections else caption_channels,
        heads=audio_connector_num_attention_heads,
        head_dim=audio_connector_attention_head_dim,
        layers=audio_connector_num_layers,
        theta=rope_theta,
        num_learnable_registers=audio_connector_num_learnable_registers,
        rope_type=rope_type,
        base_seq_len=connector_rope_base_seq_len,
        double_precision=rope_double_precision,
        attention_kernel=attention_kernel,
        mesh=mesh,
        rngs=rngs,
        gated_attn=audio_gated_attn,
    )

  def __call__(self, text_encoder_hidden_states: Array, attention_mask: Array) -> Tuple[Array, Array, Array]:
    """Map text-encoder states to (video_text_embedding, audio_text_embedding, video_attn_mask).

    Args:
      text_encoder_hidden_states: either ``(b, l, caption_channels * factor)`` (3D,
        reshaped here to 4D) or already ``(b, l, caption_channels, factor)``.
      attention_mask: per-token mask; values > 0.5 are treated as valid tokens.

    Raises:
      NotImplementedError: when ``per_modality_projections`` is False (LTX-2.0 path).
    """
    if text_encoder_hidden_states.ndim == 3:
      # Unstack the flattened last dim into (caption_channels, factor) groups.
      b, l, d = text_encoder_hidden_states.shape
      text_proj_in_factor = d // self.caption_channels
      text_encoder_hidden_states = text_encoder_hidden_states.reshape(b, l, self.caption_channels, text_proj_in_factor)
    else:
      # Already (b, l, caption_channels, factor) — assumes 4D input; TODO confirm callers.
      b, l, _, _ = text_encoder_hidden_states.shape

    if self.per_modality_projections:
      # LTX-2.3
      # per_token_rms_norm: normalize over the caption_channels axis (axis=2)
      # independently for each of the stacked feature groups.
      variance = jnp.mean(text_encoder_hidden_states**2, axis=2, keepdims=True)
      norm_text_encoder_hidden_states = text_encoder_hidden_states * jax.lax.rsqrt(variance + 1e-6)

      # Re-flatten to (b, l, caption_channels * factor) for the linear projections.
      norm_text_encoder_hidden_states = norm_text_encoder_hidden_states.reshape(b, l, -1)

      # Zero out padded tokens (mask > 0.5 means "keep").
      bool_mask = (attention_mask > 0.5).astype(jnp.float32)[..., None]
      norm_text_encoder_hidden_states = norm_text_encoder_hidden_states * bool_mask

      # Rescale norms so each modality's projected features have magnitude
      # consistent with its hidden width relative to caption_channels.
      video_scale_factor = jnp.sqrt(self.video_hidden_dim / self.caption_channels)
      video_norm_text_emb = norm_text_encoder_hidden_states * video_scale_factor
      audio_scale_factor = jnp.sqrt(self.audio_hidden_dim / self.caption_channels)
      audio_norm_text_emb = norm_text_encoder_hidden_states * audio_scale_factor

      video_text_emb_proj = self.video_text_proj_in(video_norm_text_emb)
      audio_text_emb_proj = self.audio_text_proj_in(audio_norm_text_emb)
    else:
      raise NotImplementedError("LTX-2.0 path in LTX2TextConnectors not fully implemented yet.")

    # Each connector returns (embedding, mask); only the video mask is propagated.
    video_text_embedding, video_attn_mask = self.video_connector(video_text_emb_proj, attention_mask)
    audio_text_embedding, _ = self.audio_connector(audio_text_emb_proj, attention_mask)

    return video_text_embedding, audio_text_embedding, video_attn_mask

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,7 @@ def __init__(
615615
flash_min_seq_length: int = 4096,
616616
gated_attn: bool = False,
617617
cross_attn_mod: bool = False,
618+
use_prompt_embeddings: bool = True,
618619
**kwargs,
619620
):
620621
self.in_channels = in_channels
@@ -649,6 +650,7 @@ def __init__(
649650
self.attention_out_bias = attention_out_bias
650651
self.rope_theta = rope_theta
651652
self.rope_double_precision = rope_double_precision
653+
self.use_prompt_embeddings = use_prompt_embeddings
652654
self.causal_offset = causal_offset
653655
self.timestep_scale_multiplier = timestep_scale_multiplier
654656
self.cross_attn_timestep_scale_multiplier = cross_attn_timestep_scale_multiplier
@@ -694,38 +696,42 @@ def __init__(
694696
)
695697

696698
# 2. Prompt embeddings
697-
if self.cross_attn_mod:
698-
self.caption_projection = NNXCombinedTimestepTextProjEmbeddings(
699-
rngs=rngs,
700-
in_features=self.caption_channels,
701-
hidden_size=self.cross_attention_dim,
702-
embedding_dim=self.cross_attention_dim,
703-
dtype=self.dtype,
704-
weights_dtype=self.weights_dtype,
705-
)
706-
self.audio_caption_projection = NNXCombinedTimestepTextProjEmbeddings(
707-
rngs=rngs,
708-
in_features=self.audio_caption_channels,
709-
hidden_size=self.audio_cross_attention_dim,
710-
embedding_dim=self.audio_cross_attention_dim,
711-
dtype=self.dtype,
712-
weights_dtype=self.weights_dtype,
713-
)
699+
if self.use_prompt_embeddings:
700+
if self.cross_attn_mod:
701+
self.caption_projection = NNXCombinedTimestepTextProjEmbeddings(
702+
rngs=rngs,
703+
in_features=self.caption_channels,
704+
hidden_size=self.cross_attention_dim,
705+
embedding_dim=self.cross_attention_dim,
706+
dtype=self.dtype,
707+
weights_dtype=self.weights_dtype,
708+
)
709+
self.audio_caption_projection = NNXCombinedTimestepTextProjEmbeddings(
710+
rngs=rngs,
711+
in_features=self.audio_caption_channels,
712+
hidden_size=self.audio_cross_attention_dim,
713+
embedding_dim=self.audio_cross_attention_dim,
714+
dtype=self.dtype,
715+
weights_dtype=self.weights_dtype,
716+
)
717+
else:
718+
self.caption_projection = NNXPixArtAlphaTextProjection(
719+
rngs=rngs,
720+
in_features=self.caption_channels,
721+
hidden_size=inner_dim,
722+
dtype=self.dtype,
723+
weights_dtype=self.weights_dtype,
724+
)
725+
self.audio_caption_projection = NNXPixArtAlphaTextProjection(
726+
rngs=rngs,
727+
in_features=self.audio_caption_channels,
728+
hidden_size=audio_inner_dim,
729+
dtype=self.dtype,
730+
weights_dtype=self.weights_dtype,
731+
)
714732
else:
715-
self.caption_projection = NNXPixArtAlphaTextProjection(
716-
rngs=rngs,
717-
in_features=self.caption_channels,
718-
hidden_size=inner_dim,
719-
dtype=self.dtype,
720-
weights_dtype=self.weights_dtype,
721-
)
722-
self.audio_caption_projection = NNXPixArtAlphaTextProjection(
723-
rngs=rngs,
724-
in_features=self.audio_caption_channels,
725-
hidden_size=audio_inner_dim,
726-
dtype=self.dtype,
727-
weights_dtype=self.weights_dtype,
728-
)
733+
self.caption_projection = None
734+
self.audio_caption_projection = None
729735
# 3. Timestep Modulation Params and Embedding
730736
num_mod_params = 9 if self.cross_attn_mod else 6
731737
self.time_embed = LTX2AdaLayerNormSingle(
@@ -1050,11 +1056,14 @@ def __call__(
10501056
audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.reshape(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
10511057

10521058
# 4. Prepare prompt embeddings
1053-
encoder_hidden_states = self.caption_projection(encoder_hidden_states, timestep)
1054-
encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1])
1059+
if self.use_prompt_embeddings:
1060+
encoder_hidden_states = self.caption_projection(encoder_hidden_states, timestep)
1061+
encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, hidden_states.shape[-1])
10551062

1056-
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states, audio_timestep if audio_timestep is not None else timestep)
1057-
audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1])
1063+
audio_encoder_hidden_states = self.audio_caption_projection(
1064+
audio_encoder_hidden_states, audio_timestep if audio_timestep is not None else timestep
1065+
)
1066+
audio_encoder_hidden_states = audio_encoder_hidden_states.reshape(batch_size, -1, audio_hidden_states.shape[-1])
10581067

10591068
# 5. Run transformer blocks
10601069
def scan_fn(carry, block):

0 commit comments

Comments
 (0)