Skip to content

Commit ebd4527

Browse files
committed
shard out_kernel as (heads, None) for ltx2
1 parent e7aceb3 commit ebd4527

2 files changed

Lines changed: 240 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,9 +363,9 @@ def __init__(
363363
qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
364364

365365
# Out kernel: [in_features (heads), out_features (embed)]
366-
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed"))
366+
out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", None))
367367
# Out bias: [out_features (embed)]
368-
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
368+
out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), (None,))
369369

370370
# Norm scales
371371
norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
"""
2+
Script to tune flash block sizes for LTX2 model in MaxDiffusion.
3+
"""
4+
import os
5+
import time
6+
import jax
7+
import jax.numpy as jnp
8+
import flax
9+
from flax import nnx
10+
from jax.sharding import Mesh
11+
from flax.linen import partitioning as nn_partitioning
12+
import flax.linen as nn
13+
from maxdiffusion import pyconfig
14+
from maxdiffusion.max_utils import create_device_mesh, device_put_replicated
15+
from maxdiffusion.models.ltx2.transformer_ltx2 import LTX2VideoTransformer3DModel
16+
from maxdiffusion.common_types import BlockSizes
17+
18+
# Select the Shardy partitioner for every computation compiled by this script.
jax.config.update("jax_use_shardy_partitioner", True)

# Best-effort override of a flax flag: a LookupError raised by the update is
# deliberately ignored so the script keeps going.
# NOTE(review): presumably the flag is missing in some flax versions - confirm.
try:
    flax.config.update("flax_always_shard_variable", False)
except LookupError:
    pass
23+
24+
def create_model(config, mesh, block_sizes):
    """Build an LTX2 video transformer whose parameters are zero placeholders.

    The model structure is traced abstractly with ``nnx.eval_shape``; every
    ``nnx.Param`` is then replaced by a zero array placed according to the
    sharding derived from ``config.logical_axis_rules``. Remaining abstract
    state is filled with zeros (or dummy PRNG keys), so the returned model is
    suitable for timing runs only, not for producing meaningful outputs.

    Args:
        config: Configuration object; only ``logical_axis_rules`` is read here.
        mesh: Device mesh handed to the model and used to resolve shardings.
        block_sizes: Flash-attention block sizes forwarded to the model as
            ``flash_block_sizes``.

    Returns:
        The merged ``LTX2VideoTransformer3DModel`` instance.
    """
    # Architecture hyper-parameters of the benchmarked model.
    model_kwargs = dict(
        in_channels=128,
        out_channels=128,
        patch_size=1,
        patch_size_t=1,
        num_attention_heads=32,
        attention_head_dim=128,
        cross_attention_dim=4096,
        caption_channels=3840,
        audio_in_channels=128,
        audio_out_channels=128,
        audio_num_attention_heads=32,
        audio_attention_head_dim=64,
        audio_cross_attention_dim=2048,
        num_layers=48,  # Full model
        mesh=mesh,
        attention_kernel="flash",
        flash_block_sizes=block_sizes,
        flash_min_seq_length=4096,
        dtype=jnp.bfloat16,
        weights_dtype=jnp.bfloat16,
    )

    def build(rngs):
        return LTX2VideoTransformer3DModel(rngs=rngs, **model_kwargs)

    # Abstract construction only: no parameter is allocated at this point.
    # The fixed seed feeds the abstract trace; real weights are zeros below.
    seeded_rngs = nnx.Rngs(jax.random.key(42))
    abstract_model = nnx.eval_shape(build, rngs=seeded_rngs)
    graphdef, abstract_params, abstract_rest = nnx.split(abstract_model, nnx.Param, ...)

    # Translate the logical partition specs of the parameters into mesh
    # shardings, keyed by the flat parameter path.
    partition_specs = nnx.get_partition_spec(abstract_params)
    mesh_shardings = nn.logical_to_mesh_sharding(partition_specs, mesh, config.logical_axis_rules)
    sharding_by_path = dict(nnx.to_flat_state(mesh_shardings))

    # Swap each abstract parameter for a zero array with its target sharding.
    params_by_path = dict(nnx.to_flat_state(abstract_params))
    for path, abstract_var in params_by_path.items():
        target_sharding = sharding_by_path[path].value
        zeros = jnp.zeros(abstract_var.shape, dtype=abstract_var.dtype)
        params_by_path[path].value = device_put_replicated(zeros, target_sharding)

    concrete_params = nnx.from_flat_state(params_by_path)

    def materialize(leaf):
        """Turn an abstract leaf into a concrete placeholder; pass others through."""
        if not isinstance(leaf, jax.ShapeDtypeStruct):
            return leaf
        if not jax.dtypes.issubdtype(leaf.dtype, jax.dtypes.prng_key):
            return jnp.zeros(leaf.shape, dtype=leaf.dtype)
        # PRNG-key leaves cannot be zero-filled; use a dummy key (or a split
        # of it when the leaf holds several keys).
        placeholder_key = jax.random.key(0)
        if leaf.shape == ():
            return placeholder_key
        return jax.random.split(placeholder_key, leaf.shape[0])

    concrete_rest = jax.tree_util.tree_map(materialize, abstract_rest)

    return nnx.merge(graphdef, concrete_params, concrete_rest)
83+
84+
def run_tuning(block_q=None, block_kv_compute=None, block_kv=None):
    """Grid-search flash-attention block sizes for the LTX2 transformer.

    For every candidate ``(block_q, block_kv_compute, block_kv)`` triple in
    which ``block_kv`` is a multiple of ``block_kv_compute``, a zero-weight
    model is built, one forward pass is compiled, and several further passes
    are timed. Each successful measurement is appended to
    ``flash_attention_tuning_results.csv`` (relative to the current working
    directory) and the fastest combination is printed at the end.
    Combinations that raise are reported with a traceback and skipped.

    Args:
        block_q: If given, the only ``block_q`` value tried.
        block_kv_compute: If given, the only ``block_kv_compute`` value tried.
        block_kv: If given, the only ``block_kv`` value tried.

    Returns:
        None. Results are printed and written to the CSV file.
    """
    import gc
    import traceback

    # Timed forward passes per combination, and how many of the first ones are
    # excluded from the reported average.
    num_steps = 12
    warmup_steps = 2

    # Initialize config
    script_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(script_dir, "configs", "ltx2_video.yml")

    print(f"Loading config from: {config_path}")
    pyconfig.initialize([
        None,
        config_path,
        "per_device_batch_size=0.125",
        "ici_data_parallelism=2",
        "ici_context_parallelism=4",
        "ici_tensor_parallelism=1",
        "ici_fsdp_parallelism=1",
        "attention=flash"
    ], unittest=True)
    config = pyconfig.config

    # Create mesh
    devices_array = create_device_mesh(config)
    mesh = Mesh(devices_array, config.mesh_axes)
    print(f"Mesh created: {mesh}")

    # Define search space for elaborate grid search (multiples of 256)
    block_q_options = [512, 1024, 1536, 2048]
    block_kv_compute_options = [512, 1024, 1536, 2048]
    block_kv_options = [1024, 2048, 3072, 4096]

    # An explicit argument pins the corresponding axis of the grid.
    if block_q is not None:
        block_q_options = [block_q]
    if block_kv_compute is not None:
        block_kv_compute_options = [block_kv_compute]
    if block_kv is not None:
        block_kv_options = [block_kv]

    best_time = float('inf')
    best_comb = None

    # Dummy inputs
    # User runs with per_device_batch_size = 0.125, which gives global_batch_size = 1.
    # But CFG (Classifier-Free Guidance) doubles the batch size to 2.
    # So we use global_batch_size = 2 here to match the actual tensor shape.
    global_batch_size = 2
    seq_len = 6144  # Equals num_frames * height * width (6 * 32 * 32) passed below
    audio_seq_len = 126

    hidden_states = jnp.zeros((global_batch_size, seq_len, 128), dtype=jnp.bfloat16)
    audio_hidden_states = jnp.zeros((global_batch_size, audio_seq_len, 128), dtype=jnp.bfloat16)
    timestep = jnp.ones((global_batch_size,), dtype=jnp.bfloat16)
    encoder_hidden_states = jnp.zeros((global_batch_size, 128, 3840), dtype=jnp.bfloat16)
    audio_encoder_hidden_states = jnp.zeros((global_batch_size, 128, 3840), dtype=jnp.bfloat16)

    for bq in block_q_options:
        for bkv_c in block_kv_compute_options:
            for bkv in block_kv_options:
                # Enforce that block_kv must be a multiple of block_kv_compute
                if bkv % bkv_c != 0:
                    continue

                print(f"\nTrying combination: block_q={bq}, block_kv_compute={bkv_c}, block_kv={bkv}")

                block_sizes = BlockSizes(
                    block_q=bq,
                    block_kv_compute=bkv_c,
                    block_kv=bkv,
                    block_q_dkv=bq,
                    block_kv_dkv=bkv,
                    block_kv_dkv_compute=bkv_c,
                    block_q_dq=None,
                    block_kv_dq=None,
                    use_fused_bwd_kernel=True
                )

                try:
                    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
                        model = create_model(config, mesh, block_sizes)

                        graphdef, state = nnx.split(model)

                        # `timestep` and `audio_seq_len` are closed over rather
                        # than passed as arguments.
                        @nnx.jit
                        def step_fn(graphdef, state, hidden_states, audio_hidden_states, encoder_hidden_states, audio_encoder_hidden_states):
                            model_local = nnx.merge(graphdef, state)
                            return model_local(
                                hidden_states=hidden_states,
                                audio_hidden_states=audio_hidden_states,
                                encoder_hidden_states=encoder_hidden_states,
                                audio_encoder_hidden_states=audio_encoder_hidden_states,
                                timestep=timestep,
                                num_frames=6,
                                height=32,
                                width=32,
                                audio_num_frames=audio_seq_len,
                                return_dict=True,
                            )

                        # Warmup / Compilation
                        print(" Compiling...")
                        res = step_fn(graphdef, state, hidden_states, audio_hidden_states, encoder_hidden_states, audio_encoder_hidden_states)
                        jax.block_until_ready(res)

                        # Timed steps. perf_counter is monotonic and high
                        # resolution, unlike the wall clock, so system clock
                        # adjustments cannot distort a measurement.
                        times = []
                        for i in range(num_steps):
                            start = time.perf_counter()
                            res = step_fn(graphdef, state, hidden_states, audio_hidden_states, encoder_hidden_states, audio_encoder_hidden_states)
                            jax.block_until_ready(res)
                            elapsed = time.perf_counter() - start
                            times.append(elapsed)
                            print(f" Step {i}: {elapsed:.4f}s")

                    # Average over the steps that remain after the warmup ones;
                    # the divisor follows the slice instead of being hard-coded.
                    timed = times[warmup_steps:]
                    avg_time = sum(timed) / len(timed)
                    print(f" Average time (last {len(timed)} steps): {avg_time:.4f}s")

                    # Append to a results file to track across processes
                    results_file = "flash_attention_tuning_results.csv"
                    file_exists = os.path.exists(results_file)
                    with open(results_file, "a") as f:
                        if not file_exists:
                            f.write("block_q,block_kv_compute,block_kv,average_time\n")
                        f.write(f"{bq},{bkv_c},{bkv},{avg_time:.4f}\n")

                    if avg_time < best_time:
                        best_time = avg_time
                        best_comb = (bq, bkv_c, bkv)

                except Exception as e:
                    # Deliberate best-effort: a failing combination is reported
                    # and the search moves on to the next one.
                    print(f" invalid combination. Error: {e}")
                    traceback.print_exc()
                finally:
                    # Clear memory to avoid OOM between iterations. Every local
                    # that can reference device buffers is released, not just
                    # the model: the split state, the last result and the
                    # jitted function would otherwise keep them alive while the
                    # next model is being allocated.
                    model = graphdef = state = step_fn = res = None
                    gc.collect()
                    jax.clear_caches()

    print(f"\n{'='*40}")
    if best_comb is not None:
        print(f"Best combination: block_q={best_comb[0]}, block_kv_compute={best_comb[1]}, block_kv={best_comb[2]}")
        print(f"Best average time: {best_time:.4f}s")
    else:
        print("No valid combination found.")
    print(f"{'='*40}")
229+
230+
if __name__ == "__main__":
    import argparse

    # Each flag is an optional integer; leaving it unset (None) lets
    # run_tuning sweep its built-in candidates for that dimension.
    cli = argparse.ArgumentParser()
    for option in ("--block_q", "--block_kv_compute", "--block_kv"):
        cli.add_argument(option, type=int, default=None)
    overrides = cli.parse_args()

    run_tuning(
        block_q=overrides.block_q,
        block_kv_compute=overrides.block_kv_compute,
        block_kv=overrides.block_kv,
    )

0 commit comments

Comments
 (0)