Skip to content

Commit 5644740

Browse files
committed
add ring attention - inference only.
1 parent ec61456 commit 5644740

3 files changed

Lines changed: 95 additions & 70 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 49 additions & 37 deletions
Original file line number · Diff line number · Diff line change
@@ -112,6 +112,7 @@ def _unflatten_heads(tensor, heads):
112112
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
113113
return tensor
114114

115+
115116
def _reshape_data_for_flash(tensor, heads):
116117
"""
117118
Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
@@ -122,6 +123,7 @@ def _reshape_data_for_flash(tensor, heads):
122123
tensor = _unflatten_heads(tensor, heads)
123124
return tensor
124125

126+
125127
def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
126128
"""
127129
Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
@@ -171,6 +173,7 @@ def _tpu_flash_attention(
171173
axis_names_kv: AxisNames,
172174
flash_block_sizes: BlockSizes,
173175
dtype: jnp.dtype = jnp.float32,
176+
attention_kernel: str = "flash",
174177
) -> jax.Array:
175178
"""TPU Flash Attention"""
176179

@@ -179,7 +182,6 @@ def _tpu_flash_attention(
179182
if key.shape[1] != query.shape[1]:
180183
assert key.shape[1] % 128 == 0
181184
kv_max_block_size = key.shape[1]
182-
#q_max_block_size = kv_max_block_size
183185
else:
184186
kv_max_block_size = q_max_block_size
185187
if flash_block_sizes:
@@ -217,8 +219,14 @@ def wrap_flash_attention(query, key, value):
217219

218220
mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
219221
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
220-
q_segment_ids = jnp.where(jnp.arange(query.shape[2]) < query_seq_len, 1, 0)
221-
kv_segment_ids = jnp.where(jnp.arange(key.shape[2]) < key_seq_len, 1, 0)
222+
223+
q_padded_len = query.shape[2]
224+
q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
225+
q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
226+
227+
kv_padded_len = key.shape[2]
228+
kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
229+
kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
222230
segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
223231

224232
# make_splash_mha is wrapped around shardmap and seq and head is already
@@ -228,51 +236,51 @@ def wrap_flash_attention(query, key, value):
228236
head_shards=1, # the sizes of the axis is sharding over heads
229237
q_seq_shards=1, # the sizes of the axis is sharding over seq_len
230238
block_sizes=block_sizes,
231-
save_residuals=True
239+
save_residuals=True if attention_kernel == "ring" else False,
232240
)
233-
vmapped_splash = jax.vmap(splash_kernel, in_axes=(0,0,0, None))
241+
vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
234242

235-
def ring_scan_body(carry, _):
236-
m, l, o, k_current, v_current = carry
237-
perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
238-
k_next = jax.lax.ppermute(k_current, axis_name='fsdp', perm=perm)
239-
v_next = jax.lax.ppermute(v_current, axis_name='fsdp', perm=perm)
243+
if attention_kernel == "flash":
244+
attention_output = vmapped_splash(query, key, value, segment_ids)
245+
else:
246+
if num_fsdp_shards > 1:
247+
out, (lse,) = vmapped_splash(query, key, value, segment_ids)
248+
m = lse.astype(jnp.float32)
249+
l = jnp.exp(lse - m)
250+
o = out.astype(jnp.float32) * l[..., None]
240251

241-
out_chunk, (lse_chunk,) = vmapped_splash(
242-
query, k_current, v_current, segment_ids
243-
)
252+
perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
244253

245-
m_chunk = lse_chunk.astype(jnp.float32)
246-
m_old = m
247-
m = jnp.maximum(m_old, m_chunk)
248-
249-
exp_m_diff = jnp.exp(m_old - m)
250-
exp_m_chunk_diff = jnp.exp(m_chunk - m)
254+
k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
255+
v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
251256

252-
l = l * exp_m_diff + jnp.exp(lse_chunk - m)
253-
o = o * exp_m_diff[..., None]
254-
o += (exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32))
257+
def ring_scan_body(carry, _):
258+
m, l, o, k_current, v_current = carry
259+
k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
260+
v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
255261

256-
# Return the updated state for the next iteration
257-
return (m, l, o, k_next, v_next), None
262+
out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
258263

259-
lse_shape = query.shape[:-1]
260-
m_init = jnp.full(lse_shape, -jnp.inf, dtype=jnp.float32)
261-
l_init = jnp.zeros(lse_shape, dtype=jnp.float32)
262-
o_init = jnp.zeros_like(query, dtype=jnp.float32)
264+
m_chunk = lse_chunk.astype(jnp.float32)
265+
m_old = m
266+
m = jnp.maximum(m_old, m_chunk)
263267

264-
initial_carry = (m_init, l_init, o_init, key, value)
268+
exp_m_diff = jnp.exp(m_old - m)
269+
exp_m_chunk_diff = jnp.exp(m_chunk - m)
265270

266-
(m_final, l_final, o_final, _, _), _ = jax.lax.scan(
267-
ring_scan_body,
268-
initial_carry,
269-
None,
270-
length=num_fsdp_shards
271-
)
271+
l = l * exp_m_diff + jnp.exp(lse_chunk - m)
272+
o = o * exp_m_diff[..., None]
273+
o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
274+
275+
# Return the updated state for the next iteration
276+
return (m, l, o, k_next, v_next), None
277+
278+
initial_carry = (m, l, o, k1, v1)
279+
(m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
272280

273-
attention_output = o_final / l_final[..., None]
281+
attention_output = o_final / l_final[..., None]
274282

275-
return attention_output[:,:,:query_seq_len,:kv_size].astype(query.dtype)
283+
return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
276284

277285
devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
278286
# This warning might show up when doing model eval for example, when calculating model flops
@@ -433,6 +441,10 @@ def _apply_attention(
433441
return _tpu_flash_attention(
434442
query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
435443
)
444+
elif attention_kernel == "ring":
445+
return _tpu_flash_attention(
446+
query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel
447+
)
436448
elif attention_kernel == "cudnn_flash_te":
437449
return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
438450
else:

src/maxdiffusion/pyconfig.py

Lines changed: 12 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -27,6 +27,7 @@
2727
from . import max_logging
2828
from . import max_utils
2929
from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
30+
from maxdiffusion.common_types import LENGTH, KV_LENGTH
3031

3132

3233
def string_to_bool(s: str) -> bool:
@@ -175,6 +176,17 @@ def user_init(raw_keys):
175176
max_utils.write_config_raw_keys_for_gcs(raw_keys)
176177

177178
raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
179+
# Verify qkv is sharded across sequence.
180+
if raw_keys["attention"] == "ring":
181+
logical_axis_rules = list(raw_keys["logical_axis_rules"])
182+
q_seq_sharding = (LENGTH, "fsdp")
183+
kv_seq_sharding = (KV_LENGTH, "fsdp")
184+
if q_seq_sharding not in logical_axis_rules:
185+
logical_axis_rules.append(q_seq_sharding)
186+
if kv_seq_sharding not in logical_axis_rules:
187+
logical_axis_rules.append(kv_seq_sharding)
188+
raw_keys["logical_axis_rules"] = tuple(logical_axis_rules)
189+
178190
raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"])
179191

180192
if raw_keys["learning_rate_schedule_steps"] == -1:

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 34 additions & 33 deletions
Original file line number · Diff line number · Diff line change
@@ -255,7 +255,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
255255
eval_data_iterator = self.load_dataset(mesh, is_training=False)
256256
eval_rng = jax.random.key(self.config.seed + step)
257257
eval_metrics = []
258-
# Loop indefinitely until the iterator is exhausted
258+
# Loop indefinitely until the iterator is exhausted
259259
while True:
260260
try:
261261
with mesh:
@@ -329,6 +329,7 @@ def loss_fn(params):
329329
metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
330330
return new_state, scheduler_state, metrics, new_rng
331331

332+
332333
def eval_step(state, data, rng, scheduler_state, scheduler, config):
333334
"""
334335
Computes the evaluation loss for a single batch without updating model weights.
@@ -338,44 +339,44 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
338339
# This ensures the batch size is consistent, though it might be redundant
339340
# if the evaluation dataloader is already configured correctly.
340341
for k, v in data.items():
341-
data[k] = v[: config.global_batch_size_to_train_on, :]
342+
data[k] = v[: config.global_batch_size_to_train_on, :]
342343

343344
# The loss function logic is identical to training. We are evaluating the model's
344345
# ability to perform its core training objective (e.g., denoising).
345346
def loss_fn(params):
346-
# Reconstruct the model from its definition and parameters
347-
model = nnx.merge(state.graphdef, params, state.rest_of_state)
348-
349-
# Prepare inputs
350-
latents = data["latents"].astype(config.weights_dtype)
351-
encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
352-
bsz = latents.shape[0]
353-
354-
# Sample random timesteps and noise, just as in a training step
355-
timesteps = jax.random.randint(
356-
timestep_rng,
357-
(bsz,),
358-
0,
359-
scheduler.config.num_train_timesteps,
360-
)
361-
noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
362-
noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
363-
364-
# Get the model's prediction
365-
model_pred = model(
366-
hidden_states=noisy_latents,
367-
timestep=timesteps,
368-
encoder_hidden_states=encoder_hidden_states,
369-
)
347+
# Reconstruct the model from its definition and parameters
348+
model = nnx.merge(state.graphdef, params, state.rest_of_state)
349+
350+
# Prepare inputs
351+
latents = data["latents"].astype(config.weights_dtype)
352+
encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
353+
bsz = latents.shape[0]
354+
355+
# Sample random timesteps and noise, just as in a training step
356+
timesteps = jax.random.randint(
357+
timestep_rng,
358+
(bsz,),
359+
0,
360+
scheduler.config.num_train_timesteps,
361+
)
362+
noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
363+
noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
370364

371-
# Calculate the loss against the target
372-
training_target = scheduler.training_target(latents, noise, timesteps)
373-
training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
374-
loss = (training_target - model_pred) ** 2
375-
loss = loss * training_weight
376-
loss = jnp.mean(loss)
365+
# Get the model's prediction
366+
model_pred = model(
367+
hidden_states=noisy_latents,
368+
timestep=timesteps,
369+
encoder_hidden_states=encoder_hidden_states,
370+
)
377371

378-
return loss
372+
# Calculate the loss against the target
373+
training_target = scheduler.training_target(latents, noise, timesteps)
374+
training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
375+
loss = (training_target - model_pred) ** 2
376+
loss = loss * training_weight
377+
loss = jnp.mean(loss)
378+
379+
return loss
379380

380381
# --- Key Difference from train_step ---
381382
# Directly compute the loss without calculating gradients.

0 commit comments

Comments (0)