
Commit 7c08042
integrate torchax custom attention kernel into ulysses
1 parent c98002f

3 files changed: 788 additions & 2 deletions

src/maxdiffusion/models/attention_flax.py: 105 additions & 1 deletion
@@ -31,6 +31,7 @@
 from einops import rearrange
 from .. import common_types, max_logging

+from . import custom_splash_attention as custom_splash
 from . import quantizations
 from .modeling_flax_utils import get_activation

@@ -641,6 +642,94 @@ def wrap_ulysses_attention(query, key, value):
   return x


+def _ulysses_custom_attention(
+    query: jax.Array,
+    key: jax.Array,
+    value: jax.Array,
+    heads: int,
+    mesh: Mesh,
+    axis_names_q: AxisNames,
+    axis_names_kv: AxisNames,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype = jnp.float32,
+    mask_padding_tokens: bool = False,
+    residual_checkpoint_name: str | None = None,
+    attention_mask: jax.Array | None = None,
+) -> jax.Array:
+  """Ulysses sequence-parallel attention with custom fast kernel."""
+  axis_name = "context"
+  num_shards = mesh.shape[axis_name]
+
+  # Reshape to [b, h, s, d] and pad the sequence for even context-axis splitting.
+  query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_shards)
+  key, _ = _reshape_data_for_flash(key, heads, num_shards)
+  value, _ = _reshape_data_for_flash(value, heads, num_shards)
+  num_heads = query.shape[1]
+  if num_heads % num_shards != 0:
+    raise ValueError(
+        "Ulysses attention requires the number of heads to be divisible by the context shard count, "
+        f"got heads={num_heads} and context_shards={num_shards}."
+    )
+
+  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
+  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+
+  @functools.partial(
+      jax.shard_map,
+      mesh=mesh,
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
+      out_specs=q_axis_names,
+      check_vma=False,
+  )
+  def wrap_ulysses_attention(query, key, value):
+    # Trade the head sharding for full-sequence shards: per device,
+    # [b, h, s/n, d] becomes [b, h/n, s, d].
+    query = jax.lax.all_to_all(query, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
+    key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
+    value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
+
+    # Tile-size constants for the custom splash kernel.
+    bq = 2048
+    bkv = 2048
+    bkv_compute = 1024
+    bkv_compute_in = 256
+    heads_per_tile = 1
+
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
+    value, _, _ = _pad_data_for_flash(value, heads, bkv)
+
+    # Scale the *padded* query: 1.44269504 = log2(e), so an exp2-based kernel
+    # reproduces exp(). Scaling before padding would hand the kernel an
+    # unpadded query while key/value are padded.
+    query_scaled = query * 1.44269504
+
+    bsizes = custom_splash._BlockSizes(block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute)
+
+    splash_kernel = custom_splash.make_splash_mha(
+        block_sizes=bsizes,
+        bkv_compute_in=bkv_compute_in,
+        orig_q_seq_len=query_seq_len,
+        orig_kv_seq_len=key_seq_len,
+        heads_per_tile=heads_per_tile,
+    )
+
+    # vmap the per-example kernel over the batch dimension.
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
+    attention_output = vmapped_splash(query_scaled, key, value)
+    attention_output = jnp.swapaxes(attention_output, 2, 3)
+
+    # Drop the sequence and head-dim padding introduced above.
+    attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
+
+    # Reverse exchange: back to [b, h, s/n, d], i.e. sequence-sharded.
+    attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True)
+    return attention_output
+
+  devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1)
+  if not (query.shape[0] / devices_in_batch_sharding).is_integer():
+    max_logging.log(
+        "Warning: the batch dimension should be shardable among the devices on the data and fsdp"
+        f" axes; batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
+    )
+  x = wrap_ulysses_attention(query, key, value)
+  x = x[:, :, :orig_q_seq_len, :]
+  x = _reshape_heads_to_head_dim(x)
+
+  return x
+
+
 def _apply_attention_dot(
     query: Array,
     key: Array,
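
The two all_to_all exchanges in wrap_ulysses_attention are the heart of the Ulysses scheme: each device enters with every head but only a 1/n slice of the sequence, and leaves the first exchange with 1/n of the heads but the full sequence, so the kernel can attend globally; the second exchange inverts the layout. A minimal standalone sketch of that round trip (not part of this commit; the toy mesh, shapes, and the availability of jax.shard_map and jax.make_mesh are assumptions, though the hunk above uses jax.shard_map too):

# Hypothetical demo (not from this commit): round-trip the Ulysses layout
# swap and check that the two exchanges are inverses of each other.
import functools
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((len(jax.devices()),), ("context",))
n = mesh.shape["context"]  # number of context shards (n=1 degenerates to a no-op)
b, h, s, d = 1, 2 * n, 4 * n, 8  # heads and sequence divisible by n

@functools.partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P(None, None, "context", None),),  # enter sequence-sharded
    out_specs=P(None, None, "context", None),    # leave sequence-sharded
)
def roundtrip(q):
  # [b, h, s/n, d] -> [b, h/n, s, d]: trade head shards for the full sequence.
  q = jax.lax.all_to_all(q, axis_name="context", split_axis=1, concat_axis=2, tiled=True)
  # ... full-sequence attention would run here ...
  # [b, h/n, s, d] -> [b, h, s/n, d]: restore the sequence sharding.
  return jax.lax.all_to_all(q, axis_name="context", split_axis=2, concat_axis=1, tiled=True)

x = jnp.arange(b * h * s * d, dtype=jnp.float32).reshape(b, h, s, d)
assert jnp.array_equal(roundtrip(x), x)  # the exchanges invert each other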
@@ -763,7 +852,7 @@ def _apply_attention(
   seq_len_idx = 1
   if query.ndim == 4:
     seq_len_idx = 2
-  if attention_kernel in ["flash", "tokamax_flash", "ulysses"]:
+  if attention_kernel in ["flash", "tokamax_flash", "ulysses", "ulysses_custom"]:
     can_use_flash_attention = (
         query.shape[seq_len_idx] >= flash_min_seq_length
         and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -775,6 +864,21 @@ def _apply_attention(
     return _apply_attention_dot(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
+  elif attention_kernel == "ulysses_custom":
+    return _ulysses_custom_attention(
+        query,
+        key * scale,  # fold the softmax scale into the keys
+        value,
+        heads,
+        mesh,
+        axis_names_q,
+        axis_names_kv,
+        flash_block_sizes,
+        dtype,
+        mask_padding_tokens=mask_padding_tokens,
+        residual_checkpoint_name=residual_checkpoint_name,
+        attention_mask=attention_mask,
+    )
   elif attention_kernel == "ulysses":
     return _ulysses_attention(
         query,
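
Note how the scaling is split across the two hunks: at the call site the softmax scale is folded into the keys (key * scale), while inside _ulysses_custom_attention the query is multiplied by 1.44269504, which is log2(e). Kernels of this style typically exponentiate with the cheaper exp2, and pre-scaling by log2(e) makes exp2 reproduce exp exactly. A quick self-contained check of that identity (toy shapes and the scale value are illustrative, not from the commit):

# Hypothetical check (not from this commit): exp2 over log2(e)-scaled logits
# equals exp, so q * log2(e) and k * scale together give exp(scale * q @ k.T).
import jax.numpy as jnp

LOG2_E = 1.44269504  # log2(e), the constant used in the kernel path
scale = 0.125        # e.g. 1/sqrt(head_dim) for head_dim = 64

q = jnp.linspace(-1.0, 1.0, 8).reshape(2, 4)   # [tokens, head_dim]
k = jnp.linspace(1.0, -1.0, 8).reshape(2, 4)

reference = jnp.exp(scale * (q @ k.T))                 # standard softmax numerator
kernel_style = jnp.exp2((q * LOG2_E) @ (k * scale).T)  # what the scaled inputs yield
assert jnp.allclose(reference, kernel_style, rtol=1e-6)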
