
Commit f68c7b0

Integrate tokamax ring attention as optional attention kernel for WAN 2.1
1 parent c29fdc4 commit f68c7b0

2 files changed: 72 additions & 40 deletions

src/maxdiffusion/models/attention_flax.py

Lines changed: 68 additions & 36 deletions
@@ -27,6 +27,7 @@
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
+from tokamax._src.ops.experimental.tpu.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
 from einops import rearrange
 from .. import common_types, max_logging
 
@@ -305,62 +306,92 @@ def wrap_flash_attention(query, key, value):
         mask=mask,
         q_seq_shards=1, # the sizes of the axis is sharding over seq_len
         config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-        save_residuals=True if attention_kernel == "ring" else False,
+        save_residuals=True if "ring" in attention_kernel else False,
+      )
+    elif attention_kernel == "tokamax_ring":
+      mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
+      splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
+        mask=mask,
+        is_mqa=False,
+        config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+        save_residuals=True,
+        ring_axis="fsdp",
       )
     else:
       splash_kernel = splash_attention_kernel.make_splash_mha(
        mask=multi_head_mask,
        head_shards=1, # the sizes of the axis is sharding over heads
        q_seq_shards=1, # the sizes of the axis is sharding over seq_len
        block_sizes=block_sizes,
-       save_residuals=True if attention_kernel == "ring" else False,
+       save_residuals=True if "ring" in attention_kernel else False,
        residual_checkpoint_name=residual_checkpoint_name
      )
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
 
-    if not mask_padding_tokens:
-      segment_ids = None
-    if attention_kernel in ["flash", "tokamax_flash"]:
-      attention_output = vmapped_splash(query, key, value, segment_ids)
+    if attention_kernel == "tokamax_ring":
+      # For tokamax_ring, use the kernel directly without vmap
+      # The ring attention kernel handles the ring topology internally
+      if not mask_padding_tokens:
+        segment_ids = None
+      attention_output = splash_kernel(
+        fwd_mask_info=None,
+        dkv_mask_info=None,
+        q=query,
+        k=key,
+        v=value,
+        segment_ids=segment_ids,
+        is_mqa=False,
+        config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+        mask_value=-jnp.inf,
+        mask_function=None,
+        fwd_mask_sparsity=1.0,
+        save_residuals=True,
+      )
     else:
-      if num_fsdp_shards > 1:
-        out, (lse,) = vmapped_splash(query, key, value, segment_ids)
-        m = lse.astype(jnp.float32)
-        l = jnp.exp(lse - m)
-        o = out.astype(jnp.float32) * l[..., None]
+      vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
 
-        perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
+      if not mask_padding_tokens:
+        segment_ids = None
+      if attention_kernel in ["flash", "tokamax_flash"]:
+        attention_output = vmapped_splash(query, key, value, segment_ids)
+      else:
+        if num_fsdp_shards > 1:
+          out, (lse,) = vmapped_splash(query, key, value, segment_ids)
+          m = lse.astype(jnp.float32)
+          l = jnp.exp(lse - m)
+          o = out.astype(jnp.float32) * l[..., None]
 
-        k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
-        v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
+          perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
 
-        def ring_scan_body(carry, _):
-          m, l, o, k_current, v_current = carry
-          k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
-          v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
+          k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
+          v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
 
-          out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
+          def ring_scan_body(carry, _):
+            m, l, o, k_current, v_current = carry
+            k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
+            v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
 
-          m_chunk = lse_chunk.astype(jnp.float32)
-          m_old = m
-          m = jnp.maximum(m_old, m_chunk)
+            out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
 
-          exp_m_diff = jnp.exp(m_old - m)
-          exp_m_chunk_diff = jnp.exp(m_chunk - m)
+            m_chunk = lse_chunk.astype(jnp.float32)
+            m_old = m
+            m = jnp.maximum(m_old, m_chunk)
 
-          l = l * exp_m_diff + jnp.exp(lse_chunk - m)
-          o = o * exp_m_diff[..., None]
-          o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
+            exp_m_diff = jnp.exp(m_old - m)
+            exp_m_chunk_diff = jnp.exp(m_chunk - m)
 
-          # Return the updated state for the next iteration
-          return (m, l, o, k_next, v_next), None
+            l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+            o = o * exp_m_diff[..., None]
+            o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
 
-        initial_carry = (m, l, o, k1, v1)
-        (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
+            # Return the updated state for the next iteration
+            return (m, l, o, k_next, v_next), None
 
-        attention_output = o_final / l_final[..., None]
-      else:
-        raise ValueError("ring attention requires fsdp > 1")
+          initial_carry = (m, l, o, k1, v1)
+          (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
+
+          attention_output = o_final / l_final[..., None]
+        else:
+          raise ValueError("ring attention requires fsdp > 1")
 
     return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
 
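The ring loop above (removed on the left, re-indented under the new else branch on the right) is an online-softmax combine: each pass over a rotated KV shard returns a locally normalized output plus its log-sum-exp, and the carry (m, l, o) rescales and accumulates them so that o / l equals attention over the full key/value sequence. A minimal, self-contained sketch of that combine for two KV chunks; the array shapes and the chunk_attention helper are illustrative and not part of the commit:

import jax.numpy as jnp

def chunk_attention(q, k, v):
  """Softmax attention over one KV chunk; returns (normalized output, log-sum-exp)."""
  scores = q @ k.T                        # [q_len, kv_chunk]
  m = scores.max(axis=-1, keepdims=True)
  p = jnp.exp(scores - m)
  l = p.sum(axis=-1, keepdims=True)
  out = (p / l) @ v                       # normalized within this chunk only
  lse = (m + jnp.log(l)).squeeze(-1)      # per-query log-sum-exp
  return out, lse

q = jnp.ones((4, 8)) * 0.1
k = jnp.arange(6 * 8, dtype=jnp.float32).reshape(6, 8) * 0.01
v = jnp.arange(6 * 8, dtype=jnp.float32).reshape(6, 8)

# Reference: attention over the whole KV sequence at once.
ref, _ = chunk_attention(q, k, v)

# Streamed: two KV chunks merged the same way ring_scan_body updates its carry.
out0, lse0 = chunk_attention(q, k[:3], v[:3])
m, l, o = lse0, jnp.exp(lse0 - lse0), out0 * jnp.exp(lse0 - lse0)[..., None]

out1, lse1 = chunk_attention(q, k[3:], v[3:])
m_old, m = m, jnp.maximum(m, lse1)
l = l * jnp.exp(m_old - m) + jnp.exp(lse1 - m)
o = o * jnp.exp(m_old - m)[..., None] + jnp.exp(lse1 - m)[..., None] * out1

assert jnp.allclose(o / l[..., None], ref, atol=1e-5)

The max tracking (m) only keeps the exponentials in range; the final division by l restores the global softmax normalization.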
@@ -536,7 +567,7 @@ def _apply_attention(
       mask_padding_tokens=mask_padding_tokens,
       residual_checkpoint_name=residual_checkpoint_name,
     )
-  elif attention_kernel == "ring":
+  elif "ring" in attention_kernel:
     return _tpu_flash_attention(
       query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
       mask_padding_tokens=mask_padding_tokens,
@@ -547,6 +578,7 @@ def _apply_attention(
     raise ValueError(f"Unexpected attention kernel {attention_kernel=}.")
 
 
+
 def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
   """Multi-head dot product attention with a limited number of queries."""
   num_kv, num_heads, k_features = key.shape[-3:]
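Both ring paths in this file rely on the mesh's "fsdp" axis forming a ring: the manual path rotates KV shards itself with jax.lax.ppermute, while make_ring_attention is handed the same axis through ring_axis="fsdp". A small sketch of that rotation, assuming the surrounding code runs inside a shard_map over a mesh with an "fsdp" axis (which the use of axis_name="fsdp" implies); the mesh setup and block contents here are illustrative only:

from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

n = jax.local_device_count()
mesh = Mesh(jax.devices()[:n], ("fsdp",))
# Same permutation as the diff: shard j hands its block to shard (j + 1) % n.
perm = [(j, (j + 1) % n) for j in range(n)]

@partial(shard_map, mesh=mesh, in_specs=P("fsdp"), out_specs=P("fsdp"))
def rotate_kv(block):
  # Each device exchanges its KV block with its neighbour on the "fsdp" ring.
  return jax.lax.ppermute(block, axis_name="fsdp", perm=perm)

kv = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)  # one row per shard
print(rotate_kv(kv))  # rows shifted by one position around the ring

Repeating the rotation num_fsdp_shards - 1 times, as the scan above does, lets every query shard see every KV shard exactly once.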

src/maxdiffusion/pyconfig.py

Lines changed: 4 additions & 4 deletions
@@ -195,8 +195,8 @@ def user_init(raw_keys):
 
   raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
   # Verify qkv is sharded across sequence.
-  if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]:
-    max_logging.log(f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set.")
+  if "ring" in raw_keys["attention"] or raw_keys["attention_sharding_uniform"]:
+    max_logging.log(f"Adding sequence sharding to q and kv if not already present because '{raw_keys['attention']}' contains 'ring' or {raw_keys['attention_sharding_uniform']} is set.")
     logical_axis_rules = list(raw_keys["logical_axis_rules"])
     max_logging.log(f"Initial logical axis rules: {logical_axis_rules}")
     new_rules = []
@@ -206,12 +206,12 @@ def user_init(raw_keys):
       logical_axis_rules.append(q_seq_sharding)
     if kv_seq_sharding not in logical_axis_rules:
       logical_axis_rules.append(kv_seq_sharding)
-    if raw_keys["attention"] == "ring":
+    if "ring" in raw_keys["attention"]:
       for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES:
         if ring_attention_axis_rule not in logical_axis_rules:
           max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}")
           new_rules.append(ring_attention_axis_rule)
-    else: # attention =flash but sequence parallel sharding requested for both self and cross attention
+    else: # attention contains 'flash' but sequence parallel sharding requested for both self and cross attention
       for seq_parallel_axis_rule in SEQUENCE_PARALLEL_AXIS_RULES:
         if seq_parallel_axis_rule not in logical_axis_rules:
           max_logging.log(f"Adding sequence parallel attention axis rule {seq_parallel_axis_rule}")
