
Commit 06f64b0

flash block sizes added in config for ltx2

1 parent: 0becce3

2 files changed: 170 additions & 1 deletion

src/maxdiffusion/configs/ltx2_video.yml (7 additions & 1 deletion)
@@ -62,7 +62,13 @@ dcn_fsdp_parallelism: -1
 flash_block_sizes: {
   block_q: 1024,
   block_kv: 1024,
-  block_kv_compute: 1024
+  block_kv_compute: 1024,
+  block_q_dkv: 1024,
+  block_kv_dkv: 1024,
+  block_kv_dkv_compute: 1024,
+  block_q_dq: 1024,
+  block_kv_dq: 1024,
+  use_fused_bwd_kernel: True,
 }
 dcn_context_parallelism: 1
 dcn_tensor_parallelism: 1
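
These keys mirror the fields of the Pallas splash-attention BlockSizes used on TPU; max_utils.get_flash_block_sizes (called by the benchmark script below) turns this config block into that structure. A minimal sketch of such a mapping, assuming the upstream BlockSizes field names and that the standalone dq blocks go unused under the fused backward kernel; this is illustrative, not the actual get_flash_block_sizes implementation:

from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

def block_sizes_from_config(cfg: dict) -> splash_attention_kernel.BlockSizes:
  # Assumption for this sketch: the fused backward kernel folds the dq
  # computation into the dkv pass, so the dq block sizes are left unset
  # when it is enabled.
  fused = cfg.get("use_fused_bwd_kernel", False)
  return splash_attention_kernel.BlockSizes(
      block_q=cfg["block_q"],
      block_kv=cfg["block_kv"],
      block_kv_compute=cfg.get("block_kv_compute"),
      block_q_dkv=cfg.get("block_q_dkv"),
      block_kv_dkv=cfg.get("block_kv_dkv"),
      block_kv_dkv_compute=cfg.get("block_kv_dkv_compute"),
      block_q_dq=None if fused else cfg.get("block_q_dq"),
      block_kv_dq=None if fused else cfg.get("block_kv_dq"),
      use_fused_bwd_kernel=fused,
  )
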
New file: 163 additions & 0 deletions
import sys
import time
from functools import partial

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flax import nnx

from maxdiffusion import pyconfig
from maxdiffusion.utils import logging
from maxdiffusion import max_utils
from maxdiffusion.models.ltx2.transformer_ltx2 import LTX2VideoTransformer3DModel
from maxdiffusion.maxdiffusion_utils import get_precision

logger = logging.get_logger(__name__)


def get_dummy_ltx2_inputs(batch_size, dtype):
  rng = jax.random.key(0)
  # LTX-2 at 121 frames, 512x768 -> a (channels=128, frames=16, height=16, width=24) latent grid.
  latents = jax.random.normal(rng, (batch_size, 128, 16, 16, 24), dtype=dtype)
  audio_latents = None
  timestep = jnp.array(500.0, dtype=jnp.float32)
  # Gemma text encoder: hidden dim 3072, sequence length 128.
  prompt_embeds = jax.random.normal(rng, (batch_size, 128, 3072), dtype=dtype)
  audio_prompt_embeds = None
  encoder_attention_mask = jnp.ones((batch_size, 128), dtype=jnp.int32)
  audio_encoder_attention_mask = None

  return latents, audio_latents, timestep, prompt_embeds, audio_prompt_embeds, encoder_attention_mask, audio_encoder_attention_mask


def calibrate_fbs(config):
  devices_array = max_utils.create_device_mesh(config)
  mesh = Mesh(devices_array, config.mesh_axes)

  rng = jax.random.key(config.seed)
  rngs = nnx.Rngs(rng)

  # 1. Load the pretrained transformer config and apply runtime overrides.
  ltx2_config_dict = LTX2VideoTransformer3DModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer")
  if ltx2_config_dict.get("activation_fn") == "gelu-approximate":
    ltx2_config_dict["activation_fn"] = "gelu"

  ltx2_config_dict["scan_layers"] = getattr(config, "scan_layers", True)
  ltx2_config_dict["mesh"] = mesh
  ltx2_config_dict["dtype"] = config.activations_dtype
  ltx2_config_dict["weights_dtype"] = config.weights_dtype
  ltx2_config_dict["attention_kernel"] = config.attention
  ltx2_config_dict["precision"] = get_precision(config)
  ltx2_config_dict["flash_block_sizes"] = max_utils.get_flash_block_sizes(config)
  ltx2_config_dict["remat_policy"] = config.remat_policy
  ltx2_config_dict["names_which_can_be_saved"] = config.names_which_can_be_saved
  ltx2_config_dict["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded

  print(f"Creating model with flash_block_sizes: {ltx2_config_dict['flash_block_sizes']}")

  with mesh:
    # Standard initialization
    transformer = LTX2VideoTransformer3DModel(**ltx2_config_dict, rngs=rngs)

    # Shard the model
    graphdef, state, rest_of_state = nnx.split(transformer, nnx.Param, ...)

    def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules):
      vs.sharding_rules = logical_axis_rules
      return vs

    p_add_sharding_rule = partial(_add_sharding_rule, logical_axis_rules=config.logical_axis_rules)
    state_sharded = jax.tree.map(p_add_sharding_rule, state, is_leaf=lambda x: isinstance(x, nnx.VariableState))
    pspecs = nnx.get_partition_spec(state_sharded)
    sharded_state = jax.lax.with_sharding_constraint(state_sharded, pspecs)
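    # The block above follows the usual nnx sharding recipe: split the module
    # into a graphdef plus parameter state, attach the config's logical axis
    # rules to each variable, resolve them to concrete PartitionSpecs, and
    # constrain the state onto the mesh.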

    from maxdiffusion.pipelines.ltx2.ltx2_pipeline import transformer_forward_pass

    # Run the same forward pass the pipeline uses, with all parameters passed
    # in explicitly.

    # Batch size handling
    batch_size = config.global_batch_size_to_train_on
    latents, audio_latents, timestep, prompt_embeds, audio_prompt_embeds, encoder_attention_mask, audio_encoder_attention_mask = get_dummy_ltx2_inputs(batch_size, config.activations_dtype)

    data_sharding = NamedSharding(mesh, P())
    if config.global_batch_size_to_train_on // config.per_device_batch_size == 0:
      data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))

    # Duplicate inputs along the batch axis for classifier-free guidance
    # (conditional + unconditional).
    double_latents = jnp.concatenate([latents, latents], axis=0)
    double_prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds], axis=0)
    double_encoder_attention_mask = jnp.concatenate([encoder_attention_mask, encoder_attention_mask], axis=0)

    double_latents = jax.device_put(double_latents, data_sharding)
    timestep = jax.device_put(timestep, data_sharding)
    double_prompt_embeds = jax.device_put(double_prompt_embeds, data_sharding)
    double_encoder_attention_mask = jax.device_put(double_encoder_attention_mask, data_sharding)
print("Compiling transformer forward pass...")
95+
start_compile = time.perf_counter()
96+
97+
# Using 50 runs to ensure XLA completely settles
98+
num_runs = 50
99+
100+

    # Provide exactly what transformer_forward_pass needs
    latent_num_frames = 16
    latent_height = 16
    latent_width = 24
    audio_num_frames = 0
    fps = 24.0
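    # These match the dummy latents above: (batch, channels=128, frames=16,
    # height=16, width=24).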

    _ = transformer_forward_pass(
        graphdef, sharded_state, double_latents,
        None,  # audio_latents
        timestep, double_prompt_embeds,
        None,  # audio_encoder_hidden_states
        double_encoder_attention_mask,
        None,  # audio_encoder_attention_mask
        do_classifier_free_guidance=True,
        guidance_scale=1.5,
        latent_num_frames=latent_num_frames,
        latent_height=latent_height,
        latent_width=latent_width,
        audio_num_frames=audio_num_frames,
        fps=fps,
    )
    # Block on the outputs so the measured time covers the full compile and
    # first execution.
    jtu.tree_map(lambda x: x.block_until_ready() if hasattr(x, "block_until_ready") else x, _)

    compile_time = time.perf_counter() - start_compile
    print(f"Compilation finished. Time: {compile_time:.4f}s")
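
    # JAX dispatches asynchronously, so each benchmark iteration below also
    # blocks until its outputs are ready; otherwise the loop would only
    # measure dispatch time.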

    # Benchmarking
    print(f"Starting Benchmarking ({num_runs} runs)...")
    total_time = 0.0

    for i in range(num_runs):
      start = time.perf_counter()
      _ = transformer_forward_pass(
          graphdef, sharded_state, double_latents,
          None,
          timestep, double_prompt_embeds,
          None,
          double_encoder_attention_mask,
          None,
          do_classifier_free_guidance=True,
          guidance_scale=1.5,
          latent_num_frames=latent_num_frames,
          latent_height=latent_height,
          latent_width=latent_width,
          audio_num_frames=audio_num_frames,
          fps=fps,
      )
      # block until ready
      jtu.tree_map(lambda x: x.block_until_ready() if hasattr(x, "block_until_ready") else x, _)

      step_time = time.perf_counter() - start
      if i > 5:  # Ignore first few runs for warmup
        total_time += step_time
      print(f"[Tuning] Run {i+1}/{num_runs} - E2E Step time: {step_time*1000:.2f} ms")

    print(f"Average pure diffusion cycle (after warmup): {(total_time/(num_runs-6))*1000:.2f} ms")


if __name__ == "__main__":
  config = pyconfig.initialize(sys.argv)
  calibrate_fbs(config)
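
The script is driven by pyconfig, so, assuming it is saved as e.g. calibrate_fbs.py (its path is not shown in this view), it would be launched like other maxdiffusion entry points, with the config file as the first argument: python calibrate_fbs.py src/maxdiffusion/configs/ltx2_video.yml. Different flash block sizes can then be trialed by editing the flash_block_sizes entry above and comparing the reported average step times.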
