Commit daf4a31

Integrate torchax custom attention kernel into ulysses
1 parent c98002f commit daf4a31

3 files changed: 800 additions & 54 deletions

File tree

src/maxdiffusion/models/attention_flax.py (101 additions & 53 deletions)
@@ -31,6 +31,7 @@
 from einops import rearrange
 from .. import common_types, max_logging

+from . import custom_splash_attention as custom_splash
 from . import quantizations
 from .modeling_flax_utils import get_activation

@@ -521,6 +522,7 @@ def _ulysses_attention(
     mask_padding_tokens: bool = True,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
+    use_custom_kernel: bool = False,
 ) -> jax.Array:
   """Ulysses sequence-parallel attention.

@@ -544,7 +546,9 @@ def _ulysses_attention(
         "Ulysses attention requires the number of heads to be divisible by the context shard count, "
         f"got heads={num_heads} and context_shards={num_shards}."
     )
-  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")
+
+  if not use_custom_kernel:
+    block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")

   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
@@ -563,65 +567,93 @@ def wrap_ulysses_attention(query, key, value):
     key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
     value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)

-    # Run the same local splash kernel as standard TPU flash attention, but now
-    # on full-sequence / fewer-heads tensors produced by the all-to-all above.
-    uses_fused_kernel = block_sizes.use_fused_bwd_kernel
-    block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
-    block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv)
-    if uses_fused_kernel:
-      block_q_sizes += (block_sizes.block_q_dkv,)
-      block_kv_sizes += (block_sizes.block_kv_dkv,)
-    else:
-      block_q_sizes += (block_sizes.block_q_dq,)
-      block_kv_sizes += (block_sizes.block_kv_dq,)
+    if use_custom_kernel:
+      bq = 4864
+      bkv = 1024
+      bkv_compute = 1024
+      bkv_compute_in = 1024
+      heads_per_tile = 1

-    block_q = max(*block_q_sizes)
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
-    block_kv = max(*block_kv_sizes)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
-    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+      query_scaled = query * 1.44269504

-    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+      query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
+      key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
+      value, _, _ = _pad_data_for_flash(value, heads, bkv)

-    q_padded_len = query.shape[2]
-    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
-    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+      bsizes = custom_splash._BlockSizes(block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute)

-    kv_padded_len = key.shape[2]
-    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
-    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+      splash_kernel = custom_splash.make_splash_mha(
+          block_sizes=bsizes,
+          bkv_compute_in=bkv_compute_in,
+          orig_q_seq_len=query_seq_len,
+          orig_kv_seq_len=key_seq_len,
+          heads_per_tile=heads_per_tile,
+      )

-    # Reuse the standard flash-attention masking convention by zeroing invalid
-    # KV positions in the segment ids passed down to splash.
-    if attention_mask is not None:
-      mask_len = min(key_seq_len, attention_mask.shape[1])
-      kv_mask_for_batch = attention_mask[0, :mask_len]
-      if key_seq_len > mask_len:
-        extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
-        kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)
-      if kv_padded_len > key_seq_len:
-        padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
-        kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)
+      vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
+      attention_output = vmapped_splash(query_scaled, key, value)
+      attention_output = jnp.swapaxes(attention_output, 2, 3)
+      attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
+    else:
+      # Run the same local splash kernel as standard TPU flash attention, but now
+      # on full-sequence / fewer-heads tensors produced by the all-to-all above.
+      uses_fused_kernel = block_sizes.use_fused_bwd_kernel
+      block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
+      block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv)
+      if uses_fused_kernel:
+        block_q_sizes += (block_sizes.block_q_dkv,)
+        block_kv_sizes += (block_sizes.block_kv_dkv,)
       else:
-        kv_mask_padded = kv_mask_for_batch
-      kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
+        block_q_sizes += (block_sizes.block_q_dq,)
+        block_kv_sizes += (block_sizes.block_kv_dq,)
+
+      block_q = max(*block_q_sizes)
+      query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
+      block_kv = max(*block_kv_sizes)
+      key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+      value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+
+      mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+      multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+      q_padded_len = query.shape[2]
+      q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+      q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+      kv_padded_len = key.shape[2]
+      kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+      kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+
+      # Reuse the standard flash-attention masking convention by zeroing invalid
+      # KV positions in the segment ids passed down to splash.
+      if attention_mask is not None:
+        mask_len = min(key_seq_len, attention_mask.shape[1])
+        kv_mask_for_batch = attention_mask[0, :mask_len]
+        if key_seq_len > mask_len:
+          extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
+          kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)
+        if kv_padded_len > key_seq_len:
+          padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
+          kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)
+        else:
+          kv_mask_padded = kv_mask_for_batch
+        kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)

-    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
-    if not mask_padding_tokens:
-      segment_ids = None
+      segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+      if not mask_padding_tokens:
+        segment_ids = None

-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=1,
-        q_seq_shards=1,
-        block_sizes=block_sizes,
-        save_residuals=False,
-        residual_checkpoint_name=residual_checkpoint_name,
-    )
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
-    attention_output = vmapped_splash(query, key, value, segment_ids)
-    attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
+      splash_kernel = splash_attention_kernel.make_splash_mha(
+          mask=multi_head_mask,
+          head_shards=1,
+          q_seq_shards=1,
+          block_sizes=block_sizes,
+          save_residuals=False,
+          residual_checkpoint_name=residual_checkpoint_name,
+      )
+      vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+      attention_output = vmapped_splash(query, key, value, segment_ids)
+      attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

     # Restore the original layout expected by the rest of the model:
     # head-sharded / full-sequence -> sequence-sharded / full-heads.
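
Note on the custom-kernel path above: query_scaled = query * 1.44269504 multiplies the query by log2(e), the usual trick for kernels that evaluate their softmax with base-2 exponentials (exp2) rather than exp. Assuming that is what custom_splash.make_splash_mha does internally, the exp2-based softmax on the pre-scaled logits matches the standard softmax exactly. A minimal plain-JAX sketch of that identity (illustrative only, not code from this commit):

import jax.numpy as jnp

LOG2_E = 1.44269504  # log2(e), the constant used to pre-scale the query above

def softmax_exp(logits):
  # Standard softmax with the natural exponential.
  z = logits - logits.max(axis=-1, keepdims=True)
  e = jnp.exp(z)
  return e / e.sum(axis=-1, keepdims=True)

def softmax_exp2(scaled_logits):
  # Same softmax, computed with exp2 on logits that were pre-scaled by log2(e).
  z = scaled_logits - scaled_logits.max(axis=-1, keepdims=True)
  e = jnp.exp2(z)
  return e / e.sum(axis=-1, keepdims=True)

q = jnp.array([[0.3, -1.2, 2.0]])
k = jnp.eye(3)
# Pre-scaling q by log2(e) before the matmul makes the exp2 softmax match the exp softmax.
assert jnp.allclose(softmax_exp(q @ k.T), softmax_exp2((q * LOG2_E) @ k.T))
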
@@ -763,7 +795,7 @@ def _apply_attention(
   seq_len_idx = 1
   if query.ndim == 4:
     seq_len_idx = 2
-  if attention_kernel in ["flash", "tokamax_flash", "ulysses"]:
+  if attention_kernel in ["flash", "tokamax_flash", "ulysses", "ulysses_custom"]:
     can_use_flash_attention = (
         query.shape[seq_len_idx] >= flash_min_seq_length
         and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -775,6 +807,22 @@ def _apply_attention(
     return _apply_attention_dot(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
+  elif attention_kernel == "ulysses_custom":
+    return _ulysses_attention(
+        query,
+        key * scale,
+        value,
+        heads,
+        mesh,
+        axis_names_q,
+        axis_names_kv,
+        flash_block_sizes,
+        dtype,
+        mask_padding_tokens=mask_padding_tokens,
+        residual_checkpoint_name=residual_checkpoint_name,
+        attention_mask=attention_mask,
+        use_custom_kernel=True,
+    )
   elif attention_kernel == "ulysses":
     return _ulysses_attention(
         query,
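
For context on the new dispatch branch: selecting attention_kernel == "ulysses_custom" routes through _ulysses_attention with use_custom_kernel=True, and the key is pre-multiplied by scale so the scaling lands on the QK^T logits. A single-device reference for the semantics the custom path is expected to reproduce, in plain JAX with a [batch, heads, seq, head_dim] layout and no masking (an illustrative sketch under those assumptions, not code from this commit):

import jax
import jax.numpy as jnp

def reference_attention(query, key, value, scale):
  # The dispatch above passes `key * scale` into _ulysses_attention;
  # scaling the key is equivalent to scaling the QK^T logits here.
  logits = jnp.einsum("bhqd,bhkd->bhqk", query, key * scale)
  weights = jax.nn.softmax(logits, axis=-1)
  return jnp.einsum("bhqk,bhkd->bhqd", weights, value)

In configs that already select "ulysses", switching the kernel name to "ulysses_custom" would pick up this path via the membership check added in _apply_attention.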
