Skip to content

Commit fcf64e3

Browse files
committed
Add LTX2.3 Model support
1 parent c5bb862 commit fcf64e3

9 files changed

Lines changed: 853 additions & 13 deletions

File tree

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#hardware
2+
hardware: 'tpu'
3+
skip_jax_distributed_system: False
4+
attention: 'flash'
5+
a2v_attention_kernel: 'flash'
6+
v2a_attention_kernel: 'dot_product'
7+
attention_sharding_uniform: True
8+
precision: 'bf16'
9+
scan_layers: True
10+
names_which_can_be_saved: []
11+
names_which_can_be_offloaded: []
12+
remat_policy: "NONE"
13+
14+
jax_cache_dir: ''
15+
weights_dtype: 'bfloat16'
16+
activations_dtype: 'bfloat16'
17+
18+
run_name: 'ltx2_inference'
19+
output_dir: ''
20+
config_path: ''
21+
save_config_to_gcs: False
22+
23+
#Checkpoints
24+
max_sequence_length: 1024
25+
sampler: "from_checkpoint"
26+
27+
# Generation parameters
28+
global_batch_size_to_train_on: 1
29+
num_inference_steps: 40
30+
guidance_scale: 3.0
31+
fps: 24
32+
pipeline_type: multi-scale
33+
prompt: "A man in a brightly lit room talks on a vintage telephone. In a low, heavy voice, he says, 'I understand. I won't call again. Goodbye.' He hangs up the receiver and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is brightly lit by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a dramatic movie."
34+
negative_prompt: "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
35+
height: 512
36+
width: 768
37+
decode_timestep: 0.05
38+
decode_noise_scale: 0.025
39+
num_frames: 121
40+
quantization: "int8"
41+
seed: 10
42+
#parallelism
43+
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
44+
logical_axis_rules: [
45+
['batch', ['data', 'fsdp']],
46+
['activation_batch', ['data', 'fsdp']],
47+
['activation_self_attn_heads', ['context', 'tensor']],
48+
['activation_cross_attn_q_length', ['context', 'tensor']],
49+
['activation_length', 'context'],
50+
['activation_heads', 'tensor'],
51+
['mlp','tensor'],
52+
['embed', ['context', 'fsdp']],
53+
['heads', 'tensor'],
54+
['norm', 'tensor'],
55+
['conv_batch', ['data', 'context', 'fsdp']],
56+
['out_channels', 'tensor'],
57+
['conv_out', 'context'],
58+
]
59+
data_sharding: ['data', 'fsdp', 'context', 'tensor']
60+
61+
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
62+
dcn_fsdp_parallelism: -1
63+
64+
flash_block_sizes: {
65+
block_q: 2048,
66+
block_kv: 2048,
67+
block_kv_compute: 1024,
68+
block_q_dkv: 2048,
69+
block_kv_dkv: 2048,
70+
block_kv_dkv_compute: 2048,
71+
use_fused_bwd_kernel: True,
72+
}
73+
flash_min_seq_length: 4096
74+
dcn_context_parallelism: 1
75+
dcn_tensor_parallelism: 1
76+
ici_data_parallelism: 1
77+
ici_fsdp_parallelism: 1
78+
ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
79+
ici_tensor_parallelism: 1
80+
enable_profiler: False
81+
82+
replicate_vae: False
83+
84+
allow_split_physical_axes: False
85+
learning_rate_schedule_steps: -1
86+
max_train_steps: 500
87+
pretrained_model_name_or_path: 'Lightricks/LTX-2.3'
88+
model_name: "ltx2.3"
89+
model_type: "T2V"
90+
unet_checkpoint: ''
91+
checkpoint_dir: ""
92+
dataset_name: ''
93+
train_split: 'train'
94+
dataset_type: 'tfrecord'
95+
cache_latents_text_encoder_outputs: True
96+
per_device_batch_size: 1.0
97+
compile_topology_num_slices: -1
98+
quantization_local_shard_count: -1
99+
use_qwix_quantization: False
100+
weight_quantization_calibration_method: "absmax"
101+
act_quantization_calibration_method: "absmax"
102+
bwd_quantization_calibration_method: "absmax"
103+
qwix_module_path: ".*"
104+
jit_initializers: True
105+
enable_single_replica_ckpt_restoring: False
106+
seed: 0  # NOTE(review): duplicate key — `seed: 10` is also defined earlier in this config; most YAML parsers silently keep the last occurrence (0), making the earlier value dead. Remove one of the two.
107+
audio_format: "s16"

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ def __init__(
349349
rope_type: str = "interleaved",
350350
flash_block_sizes: BlockSizes = None,
351351
flash_min_seq_length: int = 4096,
352+
gated_attn: bool = False,
352353
):
353354
self.heads = heads
354355
self.rope_type = rope_type
@@ -426,6 +427,19 @@ def __init__(
426427
else:
427428
self.dropout_layer = None
428429

430+
if gated_attn:
431+
self.to_gate_logits = nnx.Linear(
432+
query_dim,
433+
heads,
434+
use_bias=True,
435+
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
436+
bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",)),
437+
rngs=rngs,
438+
dtype=dtype,
439+
)
440+
else:
441+
self.to_gate_logits = None
442+
429443
self.attention_op = NNXAttentionOp(
430444
mesh=mesh,
431445
attention_kernel=attention_kernel,
@@ -489,6 +503,14 @@ def __call__(
489503
# NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
490504
attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
491505

506+
if getattr(self, "to_gate_logits", None) is not None:
507+
gate_logits = self.to_gate_logits(hidden_states)
508+
b, s, _ = attn_output.shape
509+
attn_output = attn_output.reshape(b, s, self.heads, self.dim_head)
510+
gates = 2.0 * jax.nn.sigmoid(gate_logits)
511+
attn_output = attn_output * jnp.expand_dims(gates, axis=-1)
512+
attn_output = attn_output.reshape(b, s, -1)
513+
492514
# 7. Output Projection
493515
hidden_states = self.to_out(attn_output)
494516

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@ def rename_for_ltx2_transformer(key):
5757
if "to_out_0" in key:
5858
key = key.replace("to_out_0", "to_out")
5959

60+
# LTX-2.3 specific mappings
61+
if "prompt_adaln" in key:
62+
key = key.replace("prompt_adaln", "caption_projection")
63+
if "audio_prompt_adaln" in key:
64+
key = key.replace("audio_prompt_adaln", "audio_caption_projection")
65+
if "video_text_proj_in" in key:
66+
key = key.replace("video_text_proj_in", "feature_extractor.video_linear")
67+
if "audio_text_proj_in" in key:
68+
key = key.replace("audio_text_proj_in", "feature_extractor.audio_linear")
69+
6070
return key
6171

6272

@@ -269,6 +279,11 @@ def rename_for_ltx2_vocoder(key):
269279
key = key.replace("ups.", "upsamplers.")
270280
key = key.replace("resblocks", "resnets")
271281
key = key.replace("conv_post", "conv_out")
282+
283+
# LTX-2.3 specific mappings for Vocoder
284+
if "downsample" in key and "lowpass" not in key:
285+
key = key.replace("downsample", "downsample.lowpass")
286+
272287
return key
273288

274289

src/maxdiffusion/models/ltx2/text_encoders/embeddings_connector_ltx2.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
attention_kernel: str = "flash",
3838
mesh: jax.sharding.Mesh = None,
3939
rngs: nnx.Rngs = None,
40+
gated_attn: bool = False,
4041
):
4142
self.attn1 = LTX2Attention(
4243
query_dim=dim,
@@ -48,6 +49,7 @@ def __init__(
4849
attention_kernel=attention_kernel,
4950
mesh=mesh,
5051
rngs=rngs,
52+
gated_attn=gated_attn,
5153
)
5254
self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim, activation_fn="gelu_tanh")
5355
self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)
@@ -92,6 +94,7 @@ def __init__(
9294
attention_kernel: str = "flash",
9395
mesh: jax.sharding.Mesh = None,
9496
rngs: nnx.Rngs = None,
97+
gated_attn: bool = False,
9598
):
9699
self.dim = input_dim
97100
self.heads = heads
@@ -117,6 +120,7 @@ def create_block(rngs):
117120
attention_kernel=attention_kernel,
118121
mesh=mesh,
119122
rngs=rngs,
123+
gated_attn=gated_attn,
120124
)
121125

122126
# Call the vmapped constructor

src/maxdiffusion/models/ltx2/text_encoders/feature_extractor_ltx2.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,21 @@ def __init__(
102102
output_dim: int,
103103
dtype: DType = jnp.float32,
104104
rngs: nnx.Rngs = None,
105+
per_modality_projections: bool = False,
106+
use_bias: bool = False,
105107
):
106108
"""
107109
Args:
108110
input_dim: Dimension of flattened hidden states (Gemma dim * Num layers).
109111
output_dim: Target dimension for diffusion conditioning.
110112
"""
111-
# LTX-2 uses bias=False for the projection
112-
self.linear = nnx.Linear(input_dim, output_dim, use_bias=False, dtype=dtype, rngs=rngs)
113+
self.per_modality_projections = per_modality_projections
114+
115+
if per_modality_projections:
116+
self.video_linear = nnx.Linear(input_dim, output_dim, use_bias=use_bias, dtype=dtype, rngs=rngs)
117+
self.audio_linear = nnx.Linear(input_dim, output_dim, use_bias=use_bias, dtype=dtype, rngs=rngs)
118+
else:
119+
self.linear = nnx.Linear(input_dim, output_dim, use_bias=use_bias, dtype=dtype, rngs=rngs)
113120

114121
def __call__(self, hidden_states: Union[Tuple[Array, ...], Array], attention_mask: Array) -> Array:
115122
"""
@@ -133,4 +140,7 @@ def __call__(self, hidden_states: Union[Tuple[Array, ...], Array], attention_mas
133140
x_norm = _norm_and_concat_padded_batch(x, attention_mask)
134141

135142
# 4. Projection
136-
return self.linear(x_norm)
143+
if self.per_modality_projections:
144+
return self.video_linear(x_norm), self.audio_linear(x_norm)
145+
else:
146+
return self.linear(x_norm)

src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,23 @@ def __init__(
5757
attention_kernel: str = "flash",
5858
mesh: jax.sharding.Mesh = None,
5959
rngs: nnx.Rngs = None,
60+
per_modality_projections: bool = False,
61+
proj_bias: bool = False,
62+
video_gated_attn: bool = False,
63+
audio_gated_attn: bool = False,
6064
**kwargs,
6165
):
6266
input_dim = caption_channels * text_proj_in_factor
6367

68+
self.per_modality_projections = per_modality_projections
69+
6470
self.feature_extractor = LTX2GemmaFeatureExtractor(
6571
input_dim=input_dim,
6672
output_dim=caption_channels,
6773
dtype=dtype,
6874
rngs=rngs,
75+
per_modality_projections=per_modality_projections,
76+
use_bias=proj_bias,
6977
)
7078

7179
# Two independent connectors
@@ -82,6 +90,7 @@ def __init__(
8290
attention_kernel=attention_kernel,
8391
mesh=mesh,
8492
rngs=rngs,
93+
gated_attn=video_gated_attn,
8594
)
8695

8796
self.audio_embeddings_connector = Embeddings1DConnector(
@@ -97,6 +106,7 @@ def __init__(
97106
attention_kernel=attention_kernel,
98107
mesh=mesh,
99108
rngs=rngs,
109+
gated_attn=audio_gated_attn,
100110
)
101111

102112
def __call__(
@@ -113,7 +123,12 @@ def __call__(
113123
features = self.feature_extractor(hidden_states, attention_mask)
114124

115125
# 2. Parallel Connection
116-
video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
117-
audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)
126+
if self.per_modality_projections:
127+
video_features, audio_features = features
128+
video_embeds, new_attention_mask = self.video_embeddings_connector(video_features, attention_mask)
129+
audio_embeds, _ = self.audio_embeddings_connector(audio_features, attention_mask)
130+
else:
131+
video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
132+
audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)
118133

119134
return video_embeds, audio_embeds, new_attention_mask

0 commit comments

Comments
 (0)