src/maxdiffusion/models/attention_flax.py: 154 changes (101 additions, 53 deletions)
@@ -31,6 +31,7 @@
 from einops import rearrange
 from .. import common_types, max_logging

+from . import custom_splash_attention as custom_splash
 from . import quantizations
 from .modeling_flax_utils import get_activation
@@ -521,6 +522,7 @@ def _ulysses_attention(
     mask_padding_tokens: bool = True,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
+    use_custom_kernel: bool = False,
 ) -> jax.Array:
   """Ulysses sequence-parallel attention.

@@ -544,7 +546,9 @@
         "Ulysses attention requires the number of heads to be divisible by the context shard count, "
         f"got heads={num_heads} and context_shards={num_shards}."
     )
-  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")
+
+  if not use_custom_kernel:
+    block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")

   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
@@ -563,65 +567,93 @@ def wrap_ulysses_attention(query, key, value):
     key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
     value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)

-    # Run the same local splash kernel as standard TPU flash attention, but now
-    # on full-sequence / fewer-heads tensors produced by the all-to-all above.
-    uses_fused_kernel = block_sizes.use_fused_bwd_kernel
-    block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
-    block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv)
-    if uses_fused_kernel:
-      block_q_sizes += (block_sizes.block_q_dkv,)
-      block_kv_sizes += (block_sizes.block_kv_dkv,)
-    else:
-      block_q_sizes += (block_sizes.block_q_dq,)
-      block_kv_sizes += (block_sizes.block_kv_dq,)
+    if use_custom_kernel:
+      bq = 4864
+      bkv = 1024
+      bkv_compute = 1024
+      bkv_compute_in = 1024
+      heads_per_tile = 1
Review comment (Collaborator) on lines +571 to +575:
Is it possible to modify the flash_block_sizes config so that users can set these values?
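A minimal sketch of how those values could be surfaced through configuration (the CustomSplashBlockSizes container and the fallback helper are hypothetical, not part of this PR):

    from dataclasses import dataclass

    @dataclass
    class CustomSplashBlockSizes:
      """Hypothetical container for the custom-kernel tile sizes."""
      bq: int = 4864
      bkv: int = 1024
      bkv_compute: int = 1024
      bkv_compute_in: int = 1024
      heads_per_tile: int = 1

    def resolve_custom_block_sizes(flash_block_sizes=None) -> CustomSplashBlockSizes:
      # Fall back to the hard-coded defaults when the user supplies nothing,
      # mirroring how the standard path defers to _select_flash_block_sizes.
      if flash_block_sizes is None:
        return CustomSplashBlockSizes()
      return CustomSplashBlockSizes(
          bq=flash_block_sizes.block_q,
          bkv=flash_block_sizes.block_kv,
          bkv_compute=flash_block_sizes.block_kv_compute,
          bkv_compute_in=flash_block_sizes.block_kv_compute,
      )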
-    block_q = max(*block_q_sizes)
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
-    block_kv = max(*block_kv_sizes)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
-    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+      query_scaled = query * 1.44269504
Review comment (Collaborator):
Can we scale it inside the custom kernel? This scale is only applicable if base2_exp is used.
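For context, 1.44269504 is log2(e): kernels that evaluate softmax with a base-2 exponential pre-multiply the logits by log2(e), since exp(x) = 2**(x * log2(e)). As the comment points out, the scale is only correct when the kernel actually takes the base-2 path. A standalone check of the identity (plain JAX, independent of this PR):

    import jax.numpy as jnp

    LOG2_E = 1.44269504  # log2(e), the constant used above

    x = jnp.linspace(-4.0, 4.0, 9)
    # exp(x) computed via the base-2 exponential a kernel would use.
    assert jnp.allclose(jnp.exp(x), jnp.exp2(x * LOG2_E), rtol=1e-6)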

-    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+      query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
+      key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
+      value, _, _ = _pad_data_for_flash(value, heads, bkv)

-    q_padded_len = query.shape[2]
-    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
-    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+      bsizes = custom_splash._BlockSizes(block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute)

-    kv_padded_len = key.shape[2]
-    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
-    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+      splash_kernel = custom_splash.make_splash_mha(
+          block_sizes=bsizes,
+          bkv_compute_in=bkv_compute_in,
+          orig_q_seq_len=query_seq_len,
+          orig_kv_seq_len=key_seq_len,
+          heads_per_tile=heads_per_tile,
+      )

-    # Reuse the standard flash-attention masking convention by zeroing invalid
-    # KV positions in the segment ids passed down to splash.
-    if attention_mask is not None:
-      mask_len = min(key_seq_len, attention_mask.shape[1])
-      kv_mask_for_batch = attention_mask[0, :mask_len]
-      if key_seq_len > mask_len:
-        extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
-        kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)
-      if kv_padded_len > key_seq_len:
-        padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
-        kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)
-      else:
-        kv_mask_padded = kv_mask_for_batch
-      kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
+      vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
+      attention_output = vmapped_splash(query_scaled, key, value)
+      attention_output = jnp.swapaxes(attention_output, 2, 3)
+      attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
+    else:
+      # Run the same local splash kernel as standard TPU flash attention, but now
+      # on full-sequence / fewer-heads tensors produced by the all-to-all above.
+      uses_fused_kernel = block_sizes.use_fused_bwd_kernel
+      block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
+      block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv)
+      if uses_fused_kernel:
+        block_q_sizes += (block_sizes.block_q_dkv,)
+        block_kv_sizes += (block_sizes.block_kv_dkv,)
+      else:
+        block_q_sizes += (block_sizes.block_q_dq,)
+        block_kv_sizes += (block_sizes.block_kv_dq,)
+
+      block_q = max(*block_q_sizes)
+      query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
+      block_kv = max(*block_kv_sizes)
+      key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+      value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+
+      mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+      multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+      q_padded_len = query.shape[2]
+      q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+      q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+      kv_padded_len = key.shape[2]
+      kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+      kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+
+      # Reuse the standard flash-attention masking convention by zeroing invalid
+      # KV positions in the segment ids passed down to splash.
+      if attention_mask is not None:
+        mask_len = min(key_seq_len, attention_mask.shape[1])
+        kv_mask_for_batch = attention_mask[0, :mask_len]
+        if key_seq_len > mask_len:
+          extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
+          kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)
+        if kv_padded_len > key_seq_len:
+          padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
+          kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)
+        else:
+          kv_mask_padded = kv_mask_for_batch
+        kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)

-    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
-    if not mask_padding_tokens:
-      segment_ids = None
+      segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+      if not mask_padding_tokens:
+        segment_ids = None

-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=1,
-        q_seq_shards=1,
-        block_sizes=block_sizes,
-        save_residuals=False,
-        residual_checkpoint_name=residual_checkpoint_name,
-    )
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
-    attention_output = vmapped_splash(query, key, value, segment_ids)
-    attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
+      splash_kernel = splash_attention_kernel.make_splash_mha(
+          mask=multi_head_mask,
+          head_shards=1,
+          q_seq_shards=1,
+          block_sizes=block_sizes,
+          save_residuals=False,
+          residual_checkpoint_name=residual_checkpoint_name,
+      )
+      vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+      attention_output = vmapped_splash(query, key, value, segment_ids)
+      attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

     # Restore the original layout expected by the rest of the model:
     # head-sharded / full-sequence -> sequence-sharded / full-heads.
@@ -763,7 +795,7 @@ def _apply_attention(
   seq_len_idx = 1
   if query.ndim == 4:
     seq_len_idx = 2
-  if attention_kernel in ["flash", "tokamax_flash", "ulysses"]:
+  if attention_kernel in ["flash", "tokamax_flash", "ulysses", "ulysses_custom"]:
     can_use_flash_attention = (
         query.shape[seq_len_idx] >= flash_min_seq_length
         and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -775,6 +807,22 @@
     return _apply_attention_dot(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
+  elif attention_kernel == "ulysses_custom":
+    return _ulysses_attention(
+        query,
+        key * scale,
+        value,
+        heads,
+        mesh,
+        axis_names_q,
+        axis_names_kv,
+        flash_block_sizes,
+        dtype,
+        mask_padding_tokens=mask_padding_tokens,
+        residual_checkpoint_name=residual_checkpoint_name,
+        attention_mask=attention_mask,
+        use_custom_kernel=True,
+    )
   elif attention_kernel == "ulysses":
     return _ulysses_attention(
         query,
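One detail worth flagging in the new dispatch arm: the softmax scale is folded into the keys (key * scale) before _ulysses_attention is called, which is equivalent to scaling the attention logits, since q @ (k * s).T == s * (q @ k.T). A standalone check of that equivalence (illustrative only, not code from this PR):

    import jax
    import jax.numpy as jnp

    kq, kk = jax.random.split(jax.random.PRNGKey(0))
    q = jax.random.normal(kq, (8, 64))  # (tokens, head_dim)
    k = jax.random.normal(kk, (8, 64))
    scale = 64 ** -0.5

    logits_scaled_keys = q @ (k * scale).T  # what the ulysses_custom arm does
    logits_scaled_dots = scale * (q @ k.T)  # the textbook formulation
    assert jnp.allclose(logits_scaled_keys, logits_scaled_dots, atol=1e-5)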