Commit 10451b0
Refine ulysses attention and tests
1 parent bf5bc4e
2 files changed: 351 additions & 59 deletions

src/maxdiffusion/models/attention_flax.py
65 additions & 59 deletions

@@ -227,6 +227,38 @@ def convert_to_tokamax_splash_config(
   )


+def _resolve_tpu_attention_block_sizes(
+    query_seq_len: int,
+    kv_seq_len: int,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype,
+    attention_kernel: str = "flash",
+) -> BlockSizes:
+  """Resolve TPU splash attention block sizes for self- and cross-attention."""
+  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  is_cross_attention = kv_seq_len != query_seq_len
+  if is_cross_attention:
+    kv_max_block_size = ((kv_seq_len + 127) // 128) * 128
+  else:
+    kv_max_block_size = q_max_block_size
+
+  if flash_block_sizes and not is_cross_attention:
+    return flash_block_sizes
+
+  block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
+  return splash_attention_kernel.BlockSizes(
+      block_q=block_size_q,
+      block_kv_compute=min(kv_max_block_size, kv_seq_len),
+      block_kv=min(kv_max_block_size, kv_seq_len),
+      block_q_dkv=block_size_q,
+      block_kv_dkv=min(kv_max_block_size, kv_seq_len),
+      block_kv_dkv_compute=min(kv_max_block_size, query_seq_len),
+      block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
+      block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query_seq_len),
+      use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
+  )
+
+
 def _tpu_flash_attention(
     query: jax.Array,
     key: jax.Array,
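
A note on the rounding in the new helper: for cross-attention, the KV block cap is the KV length rounded up to the next multiple of 128 (TPU tile granularity), and each concrete block size is then clamped back down to the true sequence length. A plain-Python sketch of that arithmetic, using a hypothetical 77-token text sequence (the values are illustrative, not from the commit):

# Illustrative only: the rounding/clamping used by _resolve_tpu_attention_block_sizes.
def _round_up_to_128(n: int) -> int:
  return ((n + 127) // 128) * 128

kv_seq_len = 77                                    # hypothetical text-encoder length
kv_max_block_size = _round_up_to_128(kv_seq_len)   # 128
block_kv = min(kv_max_block_size, kv_seq_len)      # 77: never exceeds the real length
assert (kv_max_block_size, block_kv) == (128, 77)
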
@@ -244,32 +276,17 @@ def _tpu_flash_attention(
 ) -> jax.Array:
   """TPU Flash Attention"""

-  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
-  # This is the case for cross-attn.
-  if key.shape[1] != query.shape[1]:
-    kv_max_block_size = ((key.shape[1] + 127) // 128) * 128
-  else:
-    kv_max_block_size = q_max_block_size
-  # ensure that for cross attention we override the block sizes.
-  if flash_block_sizes and key.shape[1] == query.shape[1]:
-    block_sizes = flash_block_sizes
-  else:
-    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
-    block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=block_size_q,
-        block_kv_compute=min(kv_max_block_size, key.shape[2]),
-        block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=block_size_q,
-        block_kv_dkv=min(kv_max_block_size, key.shape[2]),
-        block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
-        block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
-        use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
-    )
   num_context_shards = mesh.shape["context"]
   query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
   key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
   value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
+  block_sizes = _resolve_tpu_attention_block_sizes(
+      query_seq_len=query.shape[2],
+      kv_seq_len=key.shape[2],
+      flash_block_sizes=flash_block_sizes,
+      dtype=dtype,
+      attention_kernel=attention_kernel,
+  )

   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
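
At the call site, block sizes are now resolved after `_reshape_data_for_flash`, so both sequence lengths come from axis 2 of the reshaped [batch, heads, seq, head_dim] tensors. A small sketch with hypothetical shapes showing how the cross-attention branch is selected:

import jax.numpy as jnp

# Hypothetical shapes after _reshape_data_for_flash: [batch, heads, seq, head_dim].
query = jnp.zeros((1, 16, 4096, 64), jnp.bfloat16)  # latent/image tokens
key = jnp.zeros((1, 16, 77, 64), jnp.bfloat16)      # text tokens
is_cross_attention = key.shape[2] != query.shape[2]
# True here, so any user-supplied flash_block_sizes would be overridden.
assert is_cross_attention
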
@@ -425,6 +442,7 @@ def ring_scan_body(carry, _):
 # Ulysses sequence-parallel attention
 # ---------------------------------------------------------------------------

+
 def _ulysses_attention(
     query: jax.Array,
     key: jax.Array,
@@ -456,53 +474,41 @@ def _ulysses_attention(
   query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_shards)
   key, _ = _reshape_data_for_flash(key, heads, num_shards)
   value, _ = _reshape_data_for_flash(value, heads, num_shards)
+  num_heads = query.shape[1]
+  # Ulysses only redistributes existing heads across the context mesh; unlike
+  # the earlier draft, we fail fast instead of padding synthetic heads.
+  if num_heads % num_shards != 0:
+    raise ValueError(
+        "Ulysses attention requires the number of heads to be divisible by the context shard count, "
+        f"got heads={num_heads} and context_shards={num_shards}."
+    )

   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

-  # Pre-compute block sizes outside shard_map (uses global shapes).
-  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
-  if key.shape[2] != query.shape[2]:
-    kv_max_block_size = ((key.shape[2] + 127) // 128) * 128
-  else:
-    kv_max_block_size = q_max_block_size
-  if flash_block_sizes and key.shape[2] == query.shape[2]:
-    block_sizes = flash_block_sizes
-  else:
-    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
-    block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=block_size_q,
-        block_kv_compute=min(kv_max_block_size, key.shape[2]),
-        block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=block_size_q,
-        block_kv_dkv=min(kv_max_block_size, key.shape[2]),
-        block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=block_size_q,
-        block_kv_dq=min(kv_max_block_size, query.shape[2]),
-        use_fused_bwd_kernel=False,
-    )
+  block_sizes = _resolve_tpu_attention_block_sizes(
+      query_seq_len=query.shape[2],
+      kv_seq_len=key.shape[2],
+      flash_block_sizes=flash_block_sizes,
+      dtype=dtype,
+  )

   @functools.partial(
-      shard_map.shard_map,
+      jax.shard_map,
       mesh=mesh,
       in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
       out_specs=q_axis_names,
-      check_rep=False,
+      check_vma=False,
   )
   def wrap_ulysses_attention(query, key, value):
-    # --- Step 1: all-to-all sequence-sharded -> head-sharded ---
-    original_q_heads = query.shape[1]
-    head_pad = (-original_q_heads) % num_shards
-    if head_pad:
-      query = jnp.pad(query, ((0, 0), (0, head_pad), (0, 0), (0, 0)))
-      key = jnp.pad(key, ((0, 0), (0, head_pad), (0, 0), (0, 0)))
-      value = jnp.pad(value, ((0, 0), (0, head_pad), (0, 0), (0, 0)))
-
+    # Swap sharding modes: each device gives up a slice of sequence and gathers
+    # a slice of heads, so the local splash kernel sees the full sequence.
     query = jax.lax.all_to_all(query, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
     key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
     value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)

-    # --- Step 2: local flash attention (full sequence, subset of heads) ---
+    # Run the same local splash kernel as standard TPU flash attention, but now
+    # on full-sequence / fewer-heads tensors produced by the all-to-all above.
     uses_fused_kernel = block_sizes.use_fused_bwd_kernel
     block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
     block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv)
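
The decorator edits track the JAX API migration: `shard_map.shard_map` is now exported as top-level `jax.shard_map`, and its replication check is spelled `check_vma` rather than `check_rep`. A minimal standalone sketch of the new spelling, assuming a recent JAX release (the function name and 1-D mesh here are illustrative, not the commit's code):

import functools

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

# Hypothetical 1-D mesh over all available devices.
mesh = jax.make_mesh((jax.device_count(),), ("context",))

@functools.partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=P("context"),
    out_specs=P("context"),
    check_vma=False,  # spelled check_rep in older JAX releases
)
def scale_local_shard(x):
  # The body runs once per device on its local slice of x.
  return x * 2.0

out = scale_local_shard(jnp.arange(4.0 * jax.device_count()))
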
@@ -530,6 +536,8 @@ def wrap_ulysses_attention(query, key, value):
     kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
     kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)

+    # Reuse the standard flash-attention masking convention by zeroing invalid
+    # KV positions in the segment ids passed down to splash.
     if attention_mask is not None:
       mask_len = min(key_seq_len, attention_mask.shape[1])
       kv_mask_for_batch = attention_mask[0, :mask_len]
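
The segment-id construction above can be reproduced standalone: positions past the true KV length fall into segment 0, marking them as a different segment from the live tokens so the kernel ignores them. The lengths below are hypothetical:

import jax
import jax.numpy as jnp

key_seq_len, kv_padded_len = 77, 128  # hypothetical true vs. block-padded KV length
kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
assert int(kv_segment_ids.sum()) == 77  # only real tokens stay active
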
@@ -559,11 +567,9 @@ def wrap_ulysses_attention(query, key, value):
     attention_output = vmapped_splash(query, key, value, segment_ids)
     attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

-    # --- Step 3: all-to-all head-sharded -> sequence-sharded ---
-    attention_output = jax.lax.all_to_all(
-        attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True
-    )
-    attention_output = attention_output[:, :original_q_heads, :, :]
+    # Restore the original layout expected by the rest of the model:
+    # head-sharded / full-sequence -> sequence-sharded / full-heads.
+    attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True)
     return attention_output

   devices_in_data_context = mesh.shape["data"] * num_shards
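
The removed slice `attention_output[:, :original_q_heads, :, :]` is no longer needed: with the head-padding path gone and heads required to divide evenly across shards, the two all-to-alls are exact inverses. Illustrative per-device shape bookkeeping (hypothetical sizes):

# Per-device shapes around the Ulysses all-to-all pair (hypothetical sizes).
batch, heads, seq, head_dim, shards = 2, 8, 4096, 128, 4
assert heads % shards == 0  # the new fail-fast precondition

local_in = (batch, heads, seq // shards, head_dim)   # (2, 8, 1024, 128) seq-sharded
local_mid = (batch, heads // shards, seq, head_dim)  # (2, 2, 4096, 128) head-sharded
local_out = local_in                                 # second all-to-all restores the layout
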
