
Commit 70ce989

Fixing ring attention
1 parent d128e32 commit 70ce989

1 file changed

Lines changed: 67 additions & 56 deletions

src/maxdiffusion/models/attention_flax.py

@@ -57,10 +57,34 @@
 CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH
 
 
+def _coerce_tokamax_block_sizes(block_sizes):
+  # Tokamax requires fused bwd; convert if needed.
+  if getattr(block_sizes, "use_fused_bwd_kernel", False):
+    return block_sizes
+
+  # Fall back if some fields are missing.
+  bq = block_sizes.block_q
+  bkv = getattr(block_sizes, "block_kv", bq)
+  bkv_compute = getattr(block_sizes, "block_kv_compute", bkv)
+  bq_dkv = getattr(block_sizes, "block_q_dkv", bq)
+  bkv_dkv = getattr(block_sizes, "block_kv_dkv", bkv)
+  bkv_dkv_compute = getattr(block_sizes, "block_kv_dkv_compute", bkv_compute)
+  return splash_attention_kernel.BlockSizes(
+      block_q=bq,
+      block_kv=bkv,
+      block_kv_compute=bkv_compute,
+      block_q_dkv=bq_dkv,
+      block_kv_dkv=bkv_dkv,
+      block_kv_dkv_compute=bkv_dkv_compute,
+      block_q_dq=None,
+      block_kv_dq=None,
+      use_fused_bwd_kernel=True,
+  )
+
+
 def _maybe_aqt_einsum(quant: Quant):
   return jnp.einsum if quant is None else quant.einsum()
 
-
 def _check_attention_inputs(query: Array, key: Array, value: Array) -> None:
   """Check attention inputs."""
 
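A note on the helper added above: when the incoming block-size config does not already request the fused backward kernel, the getattr fallbacks reuse the forward block sizes for any missing backward fields and force use_fused_bwd_kernel=True. A minimal sketch of that behaviour, using types.SimpleNamespace as a hypothetical stand-in for a partially populated config (the real caller passes flash_block_sizes):

from types import SimpleNamespace

# Hypothetical forward-only config: only the forward block sizes are set.
fwd_only = SimpleNamespace(block_q=512, block_kv=1024, block_kv_compute=512)

coerced = _coerce_tokamax_block_sizes(fwd_only)
# The missing backward fields are derived from the forward ones
# (block_q_dkv=512, block_kv_dkv=1024, block_kv_dkv_compute=512),
# and use_fused_bwd_kernel is forced to True for the Tokamax kernels.

If the config already has use_fused_bwd_kernel set, the helper returns it unchanged.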
@@ -231,9 +255,13 @@ def _tpu_flash_attention(
     kv_max_block_size = key.shape[1]
   else:
     kv_max_block_size = q_max_block_size
+
   # ensure that for cross attention we override the block sizes.
   if flash_block_sizes and key.shape[1] == query.shape[1]:
     block_sizes = flash_block_sizes
+    use_tokamax = attention_kernel in ["tokamax_flash", "tokamax_ring"]
+    if use_tokamax:
+      block_sizes = _coerce_tokamax_block_sizes(flash_block_sizes)
   else:
     block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
     block_sizes = splash_attention_kernel.BlockSizes(
@@ -327,71 +355,52 @@ def wrap_flash_attention(query, key, value):
         residual_checkpoint_name=residual_checkpoint_name
     )
 
-    if attention_kernel == "tokamax_ring":
-      # For tokamax_ring, use the kernel directly without vmap
-      # The ring attention kernel handles the ring topology internally
-      if not mask_padding_tokens:
-        segment_ids = None
-      attention_output = splash_kernel(
-          fwd_mask_info=None,
-          dkv_mask_info=None,
-          q=query,
-          k=key,
-          v=value,
-          segment_ids=segment_ids,
-          is_mqa=False,
-          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-          mask_value=-jnp.inf,
-          mask_function=None,
-          fwd_mask_sparsity=1.0,
-          save_residuals=True,
-      )
-    else:
-      vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
 
-      if not mask_padding_tokens:
-        segment_ids = None
-      if attention_kernel in ["flash", "tokamax_flash"]:
-        attention_output = vmapped_splash(query, key, value, segment_ids)
-      else:
-        if num_fsdp_shards > 1:
-          out, (lse,) = vmapped_splash(query, key, value, segment_ids)
-          m = lse.astype(jnp.float32)
-          l = jnp.exp(lse - m)
-          o = out.astype(jnp.float32) * l[..., None]
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
 
-          perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
+    if not mask_padding_tokens:
+      segment_ids = None
+    if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]:
+      attention_output = vmapped_splash(query, key, value, segment_ids)
+    else:
+      if num_fsdp_shards > 1:
+        out, (lse,) = vmapped_splash(query, key, value, segment_ids)
+        m = lse.astype(jnp.float32)
+        l = jnp.exp(lse - m)
+        o = out.astype(jnp.float32) * l[..., None]
 
-          k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
-          v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
+        perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
 
-          def ring_scan_body(carry, _):
-            m, l, o, k_current, v_current = carry
-            k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
-            v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
+        k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
+        v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
 
-            out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
+        def ring_scan_body(carry, _):
+          m, l, o, k_current, v_current = carry
+          k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
+          v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
 
-            m_chunk = lse_chunk.astype(jnp.float32)
-            m_old = m
-            m = jnp.maximum(m_old, m_chunk)
+          out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
 
-            exp_m_diff = jnp.exp(m_old - m)
-            exp_m_chunk_diff = jnp.exp(m_chunk - m)
+          m_chunk = lse_chunk.astype(jnp.float32)
+          m_old = m
+          m = jnp.maximum(m_old, m_chunk)
 
-            l = l * exp_m_diff + jnp.exp(lse_chunk - m)
-            o = o * exp_m_diff[..., None]
-            o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
+          exp_m_diff = jnp.exp(m_old - m)
+          exp_m_chunk_diff = jnp.exp(m_chunk - m)
 
-            # Return the updated state for the next iteration
-            return (m, l, o, k_next, v_next), None
+          l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+          o = o * exp_m_diff[..., None]
+          o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
 
-          initial_carry = (m, l, o, k1, v1)
-          (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
+          # Return the updated state for the next iteration
+          return (m, l, o, k_next, v_next), None
 
-          attention_output = o_final / l_final[..., None]
-        else:
-          raise ValueError("ring attention requires fsdp > 1")
+        initial_carry = (m, l, o, k1, v1)
+        (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
+
+        attention_output = o_final / l_final[..., None]
+      else:
+        raise ValueError("ring attention requires fsdp > 1")
 
     return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
 
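The ring scan above combines per-shard partial attention results with a running max (m), normalizer (l) and weighted output (o), i.e. an online log-sum-exp merge. A small self-contained sketch of that merge rule on toy dense tensors (independent of the splash kernel, the vmap and the fsdp sharding), showing that folding in a second KV chunk with the same update reproduces attention over the full KV set:

import jax
import jax.numpy as jnp

def chunk_attention(q, k, v):
  # Softmax attention over one KV chunk plus its log-sum-exp, standing in for
  # what the vmapped splash kernel returns when residuals are saved.
  s = q @ k.T                            # [q_len, kv_len] logits
  lse = jax.nn.logsumexp(s, axis=-1)     # [q_len]
  out = jax.nn.softmax(s, axis=-1) @ v   # [q_len, d]
  return out, lse

rng_q, rng_k, rng_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(rng_q, (8, 16))
k = jax.random.normal(rng_k, (32, 16))
v = jax.random.normal(rng_v, (32, 16))

ref, _ = chunk_attention(q, k, v)          # reference: attention over all keys

out0, lse0 = chunk_attention(q, k[:16], v[:16])
m = lse0                                   # initial carry, as before the scan
l = jnp.exp(lse0 - m)                      # all ones
o = out0 * l[..., None]

out1, lse1 = chunk_attention(q, k[16:], v[16:])
m_old, m = m, jnp.maximum(m, lse1)         # same update as ring_scan_body
exp_m_diff = jnp.exp(m_old - m)
exp_m_chunk_diff = jnp.exp(lse1 - m)
l = l * exp_m_diff + jnp.exp(lse1 - m)
o = o * exp_m_diff[..., None] + exp_m_chunk_diff[..., None] * out1

print(jnp.allclose(o / l[..., None], ref, atol=1e-5))  # True

The final division o / l[..., None] matches the attention_output computed after the scan, so the full result can be accumulated one KV chunk at a time, which is exactly what the ppermute ring supplies to each shard.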
@@ -571,6 +580,7 @@ def _apply_attention(
     return _tpu_flash_attention(
         query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
         mask_padding_tokens=mask_padding_tokens,
+        residual_checkpoint_name=residual_checkpoint_name,
     )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
@@ -862,7 +872,8 @@ def __init__(
     else:
       axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV)
      axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV)
-
+    if attention_kernel == "tokamax_ring" and not is_self_attention:
+      attention_kernel = "tokamax_flash"  # do not use ring attention for cross attention
     self.attention_op = NNXAttentionOp(
         mesh=mesh,
         attention_kernel=attention_kernel,
