Skip to content

Commit 5b055ea

Browse files
committed
Revert "Integrate tokamax ring attention as optional attention kernel for WAN 2.1"
This reverts commit f68c7b0.
1 parent d128e32 commit 5b055ea

2 files changed

Lines changed: 40 additions & 72 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 36 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
2828
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
2929
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
30-
from tokamax._src.ops.experimental.tpu.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
3130
from einops import rearrange
3231
from .. import common_types, max_logging
3332

@@ -306,92 +305,62 @@ def wrap_flash_attention(query, key, value):
306305
mask=mask,
307306
q_seq_shards=1, # the sizes of the axis is sharding over seq_len
308307
config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
309-
save_residuals=True if "ring" in attention_kernel else False,
310-
)
311-
elif attention_kernel == "tokamax_ring":
312-
mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
313-
splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
314-
mask=mask,
315-
is_mqa=False,
316-
config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
317-
save_residuals=True,
318-
ring_axis="fsdp",
308+
save_residuals=True if attention_kernel == "ring" else False,
319309
)
320310
else:
321311
splash_kernel = splash_attention_kernel.make_splash_mha(
322312
mask=multi_head_mask,
323313
head_shards=1, # the sizes of the axis is sharding over heads
324314
q_seq_shards=1, # the sizes of the axis is sharding over seq_len
325315
block_sizes=block_sizes,
326-
save_residuals=True if "ring" in attention_kernel else False,
316+
save_residuals=True if attention_kernel == "ring" else False,
327317
residual_checkpoint_name=residual_checkpoint_name
328318
)
319+
vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
329320

330-
if attention_kernel == "tokamax_ring":
331-
# For tokamax_ring, use the kernel directly without vmap
332-
# The ring attention kernel handles the ring topology internally
333-
if not mask_padding_tokens:
334-
segment_ids = None
335-
attention_output = splash_kernel(
336-
fwd_mask_info=None,
337-
dkv_mask_info=None,
338-
q=query,
339-
k=key,
340-
v=value,
341-
segment_ids=segment_ids,
342-
is_mqa=False,
343-
config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
344-
mask_value=-jnp.inf,
345-
mask_function=None,
346-
fwd_mask_sparsity=1.0,
347-
save_residuals=True,
348-
)
321+
if not mask_padding_tokens:
322+
segment_ids = None
323+
if attention_kernel in ["flash", "tokamax_flash"]:
324+
attention_output = vmapped_splash(query, key, value, segment_ids)
349325
else:
350-
vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
351-
352-
if not mask_padding_tokens:
353-
segment_ids = None
354-
if attention_kernel in ["flash", "tokamax_flash"]:
355-
attention_output = vmapped_splash(query, key, value, segment_ids)
356-
else:
357-
if num_fsdp_shards > 1:
358-
out, (lse,) = vmapped_splash(query, key, value, segment_ids)
359-
m = lse.astype(jnp.float32)
360-
l = jnp.exp(lse - m)
361-
o = out.astype(jnp.float32) * l[..., None]
326+
if num_fsdp_shards > 1:
327+
out, (lse,) = vmapped_splash(query, key, value, segment_ids)
328+
m = lse.astype(jnp.float32)
329+
l = jnp.exp(lse - m)
330+
o = out.astype(jnp.float32) * l[..., None]
362331

363-
perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
332+
perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
364333

365-
k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
366-
v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
334+
k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
335+
v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
367336

368-
def ring_scan_body(carry, _):
369-
m, l, o, k_current, v_current = carry
370-
k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
371-
v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
337+
def ring_scan_body(carry, _):
338+
m, l, o, k_current, v_current = carry
339+
k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
340+
v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
372341

373-
out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
342+
out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
374343

375-
m_chunk = lse_chunk.astype(jnp.float32)
376-
m_old = m
377-
m = jnp.maximum(m_old, m_chunk)
344+
m_chunk = lse_chunk.astype(jnp.float32)
345+
m_old = m
346+
m = jnp.maximum(m_old, m_chunk)
378347

379-
exp_m_diff = jnp.exp(m_old - m)
380-
exp_m_chunk_diff = jnp.exp(m_chunk - m)
348+
exp_m_diff = jnp.exp(m_old - m)
349+
exp_m_chunk_diff = jnp.exp(m_chunk - m)
381350

382-
l = l * exp_m_diff + jnp.exp(lse_chunk - m)
383-
o = o * exp_m_diff[..., None]
384-
o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
351+
l = l * exp_m_diff + jnp.exp(lse_chunk - m)
352+
o = o * exp_m_diff[..., None]
353+
o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
385354

386-
# Return the updated state for the next iteration
387-
return (m, l, o, k_next, v_next), None
355+
# Return the updated state for the next iteration
356+
return (m, l, o, k_next, v_next), None
388357

389-
initial_carry = (m, l, o, k1, v1)
390-
(m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
358+
initial_carry = (m, l, o, k1, v1)
359+
(m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
391360

392-
attention_output = o_final / l_final[..., None]
393-
else:
394-
raise ValueError("ring attention requires fsdp > 1")
361+
attention_output = o_final / l_final[..., None]
362+
else:
363+
raise ValueError("ring attention requires fsdp > 1")
395364

396365
return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
397366

@@ -567,7 +536,7 @@ def _apply_attention(
567536
mask_padding_tokens=mask_padding_tokens,
568537
residual_checkpoint_name=residual_checkpoint_name,
569538
)
570-
elif "ring" in attention_kernel:
539+
elif attention_kernel == "ring":
571540
return _tpu_flash_attention(
572541
query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
573542
mask_padding_tokens=mask_padding_tokens,
@@ -578,7 +547,6 @@ def _apply_attention(
578547
raise ValueError(f"Unexpected attention kernel {attention_kernel=}.")
579548

580549

581-
582550
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
583551
"""Multi-head dot product attention with a limited number of queries."""
584552
num_kv, num_heads, k_features = key.shape[-3:]

src/maxdiffusion/pyconfig.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,8 @@ def user_init(raw_keys):
195195

196196
raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
197197
# Verify qkv is sharded across sequence.
198-
if "ring" in raw_keys["attention"] or raw_keys["attention_sharding_uniform"]:
199-
max_logging.log(f"Adding sequence sharding to q and kv if not already present because '{raw_keys['attention']}' contains 'ring' or {raw_keys['attention_sharding_uniform']} is set.")
198+
if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]:
199+
max_logging.log(f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set.")
200200
logical_axis_rules = list(raw_keys["logical_axis_rules"])
201201
max_logging.log(f"Initial logical axis rules: {logical_axis_rules}")
202202
new_rules = []
@@ -206,12 +206,12 @@ def user_init(raw_keys):
206206
logical_axis_rules.append(q_seq_sharding)
207207
if kv_seq_sharding not in logical_axis_rules:
208208
logical_axis_rules.append(kv_seq_sharding)
209-
if "ring" in raw_keys["attention"]:
209+
if raw_keys["attention"] == "ring":
210210
for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES:
211211
if ring_attention_axis_rule not in logical_axis_rules:
212212
max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}")
213213
new_rules.append(ring_attention_axis_rule)
214-
else: # attention contains 'flash' but sequence parallel sharding requested for both self and cross attention
214+
else: # attention =flash but sequence parallel sharding requested for both self and cross attention
215215
for seq_parallel_axis_rule in SEQUENCE_PARALLEL_AXIS_RULES:
216216
if seq_parallel_axis_rule not in logical_axis_rules:
217217
max_logging.log(f"Adding sequence parallel attention axis rule {seq_parallel_axis_rule}")

0 commit comments

Comments (0)