
Commit 7dfc26e

Add tokamax_ulysses
1 parent ae22683 commit 7dfc26e

3 files changed

Lines changed: 102 additions & 12 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 27 additions & 11 deletions
@@ -507,6 +507,7 @@ def _ulysses_attention(
     mask_padding_tokens: bool = True,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
+    attention_kernel: str = "ulysses",
 ) -> jax.Array:
   """Ulysses sequence-parallel attention.

@@ -530,7 +531,9 @@ def _ulysses_attention(
         "Ulysses attention requires the number of heads to be divisible by the context shard count, "
         f"got heads={num_heads} and context_shards={num_shards}."
     )
-  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")
+
+  inner_kernel = "tokamax_flash" if attention_kernel == "tokamax_ulysses" else "flash"
+  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, inner_kernel)

   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

@@ -597,14 +600,26 @@ def wrap_ulysses_attention(query, key, value):
     if not mask_padding_tokens:
       segment_ids = None

-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=1,
-        q_seq_shards=1,
-        block_sizes=block_sizes,
-        save_residuals=False,
-        residual_checkpoint_name=residual_checkpoint_name,
-    )
+    if attention_kernel == "tokamax_ulysses":
+      mask = tokamax_splash_attention_mask.FullMask(
+          _shape=(query.shape[2], key.shape[2]),
+      )
+      splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
+          mask=mask,
+          q_seq_shards=1,
+          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+          save_residuals=False,
+      )
+    else:
+      splash_kernel = splash_attention_kernel.make_splash_mha(
+          mask=multi_head_mask,
+          head_shards=1,
+          q_seq_shards=1,
+          block_sizes=block_sizes,
+          save_residuals=False,
+          residual_checkpoint_name=residual_checkpoint_name,
+      )
+
     vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
     attention_output = vmapped_splash(query, key, value, segment_ids)
     attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

@@ -747,7 +762,7 @@ def _apply_attention(
   seq_len_idx = 1
   if query.ndim == 4:
     seq_len_idx = 2
-  if attention_kernel in ["flash", "tokamax_flash", "ulysses"]:
+  if attention_kernel in ["flash", "tokamax_flash", "ulysses", "tokamax_ulysses"]:
     can_use_flash_attention = (
         query.shape[seq_len_idx] >= flash_min_seq_length
         and key.shape[seq_len_idx] >= flash_min_seq_length

@@ -759,7 +774,7 @@
     return _apply_attention_dot(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
-  elif attention_kernel == "ulysses":
+  elif attention_kernel in ["ulysses", "tokamax_ulysses"]:
     return _ulysses_attention(
         query,
         key * scale,

@@ -773,6 +788,7 @@
       mask_padding_tokens=mask_padding_tokens,
      residual_checkpoint_name=residual_checkpoint_name,
       attention_mask=attention_mask,
+      attention_kernel=attention_kernel,
     )
   elif attention_kernel in ["flash", "tokamax_flash"]:
     return _tpu_flash_attention(
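
For orientation, the following is a minimal sketch of the Ulysses sequence-parallel scheme that _ulysses_attention implements; it is not the repo's code. Plain softmax attention stands in for the splash kernel, the mesh axis name "seq" is an assumption, and ulysses_attention_sketch is a hypothetical helper. Sequence-sharded q/k/v are resharded to head-sharded with all_to_all, each device runs full-sequence attention over its slice of heads, and a second all_to_all restores sequence sharding, which is why the guard above insists num_heads be divisible by the context shard count.

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P


def ulysses_attention_sketch(q, k, v, mesh):
  """q, k, v: [batch, heads, seq, head_dim], sharded along seq on mesh axis "seq"."""

  def local_attention(q, k, v):
    # Local shapes are [batch, heads, seq/shards, head_dim]. all_to_all trades
    # the head axis for the sequence axis: afterwards each device holds
    # heads/shards heads at full sequence length.
    q = jax.lax.all_to_all(q, "seq", split_axis=1, concat_axis=2, tiled=True)
    k = jax.lax.all_to_all(k, "seq", split_axis=1, concat_axis=2, tiled=True)
    v = jax.lax.all_to_all(v, "seq", split_axis=1, concat_axis=2, tiled=True)
    # Plain softmax attention stands in for the (Pallas or tokamax) splash kernel.
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    out = jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)
    # Inverse all_to_all: back to all heads, sequence-sharded.
    return jax.lax.all_to_all(out, "seq", split_axis=2, concat_axis=1, tiled=True)

  spec = P(None, None, "seq", None)
  return shard_map(local_attention, mesh=mesh, in_specs=(spec, spec, spec), out_specs=spec)(q, k, v)

Under a Mesh of n devices on the "seq" axis, the head axis must split evenly n ways; the two einsums never see a sharded sequence axis, so each shard's local attention is exact.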

src/maxdiffusion/pyconfig.py

Lines changed: 1 addition & 1 deletion
@@ -214,7 +214,7 @@ def user_init(raw_keys):
   # Verify qkv is sharded across sequence.
   attention = raw_keys["attention"]
   uses_ring_attention = "ring" in attention
-  uses_ulysses_attention = attention == "ulysses"
+  uses_ulysses_attention = attention in ["ulysses", "tokamax_ulysses"]
   uses_uniform_sequence_sharding = raw_keys["attention_sharding_uniform"]
   if uses_ring_attention or uses_ulysses_attention or uses_uniform_sequence_sharding:
     max_logging.log(
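
As a condensed sketch of the check this one-line change feeds (the helper name needs_sequence_sharding is illustrative, not the repo's; the logging that follows in user_init is elided): any attention mode that shards qkv along the sequence axis, now including tokamax_ulysses, takes the same verification branch.

def needs_sequence_sharding(raw_keys) -> bool:
  # Condensed from user_init: ring, ulysses-style, or explicitly uniform
  # sequence sharding all require qkv to be sharded across sequence.
  attention = raw_keys["attention"]
  uses_ring_attention = "ring" in attention
  uses_ulysses_attention = attention in ["ulysses", "tokamax_ulysses"]
  return uses_ring_attention or uses_ulysses_attention or raw_keys["attention_sharding_uniform"]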

src/maxdiffusion/tests/attention_test.py

Lines changed: 74 additions & 0 deletions
@@ -442,5 +442,79 @@ def fake_kernel(q, k, v, segment_ids):
     self.assertTrue(jnp.array_equal(output, expected))


+  def test_tokamax_ulysses_attention_matches_tokamax_flash(self):
+    """Tokamax Flash and Tokamax Ulysses should agree when the local splash kernel is shared."""
+    batch = 2
+    length = 6
+    heads = 4
+    head_depth = 3
+    query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth)
+    key = query + 100.0
+    value = query + 200.0
+    mesh = self._ulysses_mesh()
+
+    def fake_make_splash_mha(**unused_kwargs):
+      def fake_kernel(q, k, v, segment_ids):
+        del k, segment_ids
+        return q + jnp.mean(v, axis=1, keepdims=True)
+
+      return fake_kernel
+
+    with mock.patch.object(
+        attention_flax.tokamax_splash_attention_kernel,
+        "make_splash_mha",
+        side_effect=fake_make_splash_mha,
+    ):
+      with mesh, nn_partitioning.axis_rules(self._flash_axis_rules()):
+        flash_output = attention_flax._tpu_flash_attention(
+            query,
+            key,
+            value,
+            heads=heads,
+            mesh=mesh,
+            axis_names_q=(
+                attention_flax.BATCH,
+                attention_flax.SELF_ATTN_HEAD,
+                attention_flax.SELF_ATTN_Q_LENGTH,
+                attention_flax.D_KV,
+            ),
+            axis_names_kv=(
+                attention_flax.BATCH,
+                attention_flax.SELF_ATTN_HEAD,
+                attention_flax.SELF_ATTN_KV_LENGTH,
+                attention_flax.D_KV,
+            ),
+            flash_block_sizes=self._ulysses_block_sizes(),
+            dtype=jnp.float32,
+            attention_kernel="tokamax_flash",
+        )
+
+      with mesh, nn_partitioning.axis_rules(self._ulysses_axis_rules()):
+        ulysses_output = attention_flax._ulysses_attention(
+            query,
+            key,
+            value,
+            heads=heads,
+            mesh=mesh,
+            axis_names_q=(
+                attention_flax.BATCH,
+                attention_flax.SELF_ATTN_HEAD,
+                attention_flax.SELF_ATTN_Q_LENGTH,
+                attention_flax.D_KV,
+            ),
+            axis_names_kv=(
+                attention_flax.BATCH,
+                attention_flax.SELF_ATTN_HEAD,
+                attention_flax.SELF_ATTN_KV_LENGTH,
+                attention_flax.D_KV,
+            ),
+            flash_block_sizes=self._ulysses_block_sizes(),
+            dtype=jnp.float32,
+            attention_kernel="tokamax_ulysses",
+        )
+
+    self.assertEqual(flash_output.shape, ulysses_output.shape)
+    self.assertTrue(jnp.array_equal(flash_output, ulysses_output))
+
 if __name__ == "__main__":
   absltest.main()
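
A note on the test's design: patching tokamax_splash_attention_kernel.make_splash_mha with a deterministic, shape-preserving fake means the test exercises only the resharding and dispatch logic the two paths share, so exact equality via jnp.array_equal is a fair check; kernel numerics are covered elsewhere. Assuming pytest is available, the new case can be run in isolation with: python -m pytest src/maxdiffusion/tests/attention_test.py -k tokamax_ulysses_attention_matches_tokamax_flash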
