Skip to content

Commit d4950fb

Browse files
committed
ltx2.3 config file and params addition
1 parent c98002f commit d4950fb

4 files changed

Lines changed: 128 additions & 1 deletion

File tree

src/maxdiffusion/common_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
WAN2_1 = "wan2.1"
5353
WAN2_2 = "wan2.2"
5454
LTX2_VIDEO = "ltx2_video"
55+
LTX2_3 = "ltx2.3"
5556

5657
WAN_MODEL = WAN2_1
5758

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#hardware
2+
hardware: 'tpu'
3+
skip_jax_distributed_system: False
4+
attention: 'flash'
5+
a2v_attention_kernel: 'dot_product'
6+
v2a_attention_kernel: 'dot_product'
7+
attention_sharding_uniform: True
8+
precision: 'bf16'
9+
scan_layers: True
10+
names_which_can_be_saved: []
11+
names_which_can_be_offloaded: []
12+
remat_policy: "NONE"
13+
14+
jax_cache_dir: ''
15+
weights_dtype: 'bfloat16'
16+
activations_dtype: 'bfloat16'
17+
18+
run_name: 'ltx2_inference'
19+
output_dir: ''
20+
config_path: ''
21+
save_config_to_gcs: False
22+
23+
#Checkpoints
24+
max_sequence_length: 1024
25+
sampler: "from_checkpoint"
26+
27+
# Generation parameters (aligned with Diffusers LTX-2.3 docs: use_cross_timestep, modality + audio CFG)
28+
global_batch_size_to_train_on: 1
29+
num_inference_steps: 30
30+
guidance_scale: 3.0
31+
guidance_rescale: 0.7
32+
audio_guidance_scale: 7.0
33+
audio_guidance_rescale: 0.7
34+
stg_scale: 1.0
35+
audio_stg_scale: 1.0
36+
modality_scale: 1.0
37+
audio_modality_scale: 1.0
38+
use_cross_timestep: true
39+
spatio_temporal_guidance_blocks: [28]
40+
fps: 24
41+
pipeline_type: multi-scale
42+
prompt: "A man in a brightly lit room talks on a vintage telephone. In a low, heavy voice, he says, 'I understand. I won't call again. Goodbye.' He hangs up the receiver and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is brightly lit by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a dramatic movie."
43+
negative_prompt: "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
44+
height: 512
45+
width: 768
46+
decode_timestep: 0.05
47+
decode_noise_scale: 0.025
48+
noise_scale: 0.0
49+
num_frames: 121
50+
quantization: "int8"
51+
#parallelism
52+
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
53+
logical_axis_rules: [
54+
['batch', ['data', 'fsdp']],
55+
['activation_batch', ['data', 'fsdp']],
56+
['activation_self_attn_heads', ['context', 'tensor']],
57+
['activation_cross_attn_q_length', ['context', 'tensor']],
58+
['activation_length', 'context'],
59+
['activation_heads', 'tensor'],
60+
['mlp','tensor'],
61+
['embed', ['context', 'fsdp']],
62+
['heads', 'tensor'],
63+
['norm', 'tensor'],
64+
['conv_batch', ['data', 'context', 'fsdp']],
65+
['out_channels', 'tensor'],
66+
['conv_out', 'context'],
67+
]
68+
data_sharding: ['data', 'fsdp', 'context', 'tensor']
69+
70+
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
71+
dcn_fsdp_parallelism: -1
72+
73+
flash_block_sizes: {
74+
block_q: 2048,
75+
block_kv: 2048,
76+
block_kv_compute: 1024,
77+
block_q_dkv: 2048,
78+
block_kv_dkv: 2048,
79+
block_kv_dkv_compute: 2048,
80+
use_fused_bwd_kernel: True,
81+
}
82+
flash_min_seq_length: 4096
83+
dcn_context_parallelism: 1
84+
dcn_tensor_parallelism: 1
85+
ici_data_parallelism: 1
86+
ici_fsdp_parallelism: 1
87+
ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
88+
ici_tensor_parallelism: 1
89+
enable_profiler: False
90+
91+
replicate_vae: False
92+
93+
allow_split_physical_axes: False
94+
learning_rate_schedule_steps: -1
95+
max_train_steps: 500
96+
pretrained_model_name_or_path: 'dg845/LTX-2.3-Diffusers'
97+
model_name: "ltx2.3"
98+
model_type: "T2V"
99+
unet_checkpoint: ''
100+
checkpoint_dir: ""
101+
dataset_name: ''
102+
train_split: 'train'
103+
dataset_type: 'tfrecord'
104+
cache_latents_text_encoder_outputs: True
105+
per_device_batch_size: 1.0
106+
compile_topology_num_slices: -1
107+
quantization_local_shard_count: -1
108+
use_qwix_quantization: False
109+
weight_quantization_calibration_method: "absmax"
110+
act_quantization_calibration_method: "absmax"
111+
bwd_quantization_calibration_method: "absmax"
112+
qwix_module_path: ".*"
113+
jit_initializers: True
114+
enable_single_replica_ckpt_restoring: False
115+
seed: 10
116+
audio_format: "s16"

src/maxdiffusion/generate_ltx2.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,15 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
9898
decode_timestep=getattr(config, "decode_timestep", 0.0),
9999
decode_noise_scale=getattr(config, "decode_noise_scale", None),
100100
max_sequence_length=getattr(config, "max_sequence_length", 1024),
101+
guidance_rescale=getattr(config, "guidance_rescale", 0.0),
102+
audio_guidance_scale=getattr(config, "audio_guidance_scale", None),
103+
audio_guidance_rescale=getattr(config, "audio_guidance_rescale", None),
104+
stg_scale=getattr(config, "stg_scale", 0.0),
105+
audio_stg_scale=getattr(config, "audio_stg_scale", None),
106+
modality_scale=getattr(config, "modality_scale", 1.0),
107+
audio_modality_scale=getattr(config, "audio_modality_scale", None),
108+
use_cross_timestep=getattr(config, "use_cross_timestep", None),
109+
noise_scale=getattr(config, "noise_scale", 1.0),
101110
dtype=jnp.bfloat16 if getattr(config, "activations_dtype", "bfloat16") == "bfloat16" else jnp.float32,
102111
output_type=getattr(config, "upsampler_output_type", "pil"),
103112
)

src/maxdiffusion/pyconfig.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,13 @@
3434
WAN2_1,
3535
WAN2_2,
3636
LTX2_VIDEO,
37+
LTX2_3,
3738
RING_ATTENTION_AXIS_RULES,
3839
SEQUENCE_PARALLEL_AXIS_RULES,
3940
ULYSSES_ATTENTION_AXIS_RULES,
4041
)
4142

42-
_ALLOWED_MODEL_NAMES = {WAN2_1, WAN2_2, LTX2_VIDEO}
43+
_ALLOWED_MODEL_NAMES = {WAN2_1, WAN2_2, LTX2_VIDEO, LTX2_3}
4344
_ALLOWED_TRAINING_MODEL_NAMES = {WAN2_1}
4445

4546

0 commit comments

Comments
 (0)