Commit 56c76b8

resolving comments, tests pending

1 parent 7c08042 · commit 56c76b8

2 files changed: 101 additions & 141 deletions

src/maxdiffusion/models/attention_flax.py
Lines changed: 85 additions & 141 deletions

@@ -522,6 +522,7 @@ def _ulysses_attention(
     mask_padding_tokens: bool = True,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
+    use_custom_kernel: bool = False,
 ) -> jax.Array:
   """Ulysses sequence-parallel attention.

@@ -545,7 +546,9 @@ def _ulysses_attention(
         "Ulysses attention requires the number of heads to be divisible by the context shard count, "
         f"got heads={num_heads} and context_shards={num_shards}."
     )
-  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")
+
+  if not use_custom_kernel:
+    block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")

   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
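
Note: both branches in the next hunk pad query/key/value along the sequence axis up to a multiple of the chosen block size (via _pad_data_for_flash) before invoking a kernel, then slice the output back to the original length. A minimal self-contained sketch of that pad-to-block idea, using a hypothetical pad_to_block helper rather than the repo function:

import jax.numpy as jnp

# Hypothetical pad_to_block (not the repo's _pad_data_for_flash): pad the
# sequence axis of a [batch, heads, seq, head_dim] tensor up to a multiple of
# the kernel block size, returning the original length so the attention output
# can be sliced back afterwards.
def pad_to_block(x, block):
  seq_len = x.shape[2]
  padded_len = -(-seq_len // block) * block  # ceil to a block multiple
  return jnp.pad(x, ((0, 0), (0, 0), (0, padded_len - seq_len), (0, 0))), seq_len

q = jnp.ones((1, 8, 3000, 64))
q_padded, orig_len = pad_to_block(q, 1024)
print(q_padded.shape, orig_len)  # (1, 8, 3072, 64) 3000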
@@ -564,65 +567,93 @@ def wrap_ulysses_attention(query, key, value):
     key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
     value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)

-    # Run the same local splash kernel as standard TPU flash attention, but now
-    # on full-sequence / fewer-heads tensors produced by the all-to-all above.
-    uses_fused_kernel = block_sizes.use_fused_bwd_kernel
-    block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
-    block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv)
-    if uses_fused_kernel:
-      block_q_sizes += (block_sizes.block_q_dkv,)
-      block_kv_sizes += (block_sizes.block_kv_dkv,)
-    else:
-      block_q_sizes += (block_sizes.block_q_dq,)
-      block_kv_sizes += (block_sizes.block_kv_dq,)
+    if use_custom_kernel:
+      bq = 4864
+      bkv = 1024
+      bkv_compute = 1024
+      bkv_compute_in = 1024
+      heads_per_tile = 1

-    block_q = max(*block_q_sizes)
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
-    block_kv = max(*block_kv_sizes)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
-    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+      query_scaled = query * 1.44269504

-    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+      query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
+      key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
+      value, _, _ = _pad_data_for_flash(value, heads, bkv)

-    q_padded_len = query.shape[2]
-    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
-    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+      bsizes = custom_splash._BlockSizes(block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute)

-    kv_padded_len = key.shape[2]
-    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
-    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+      splash_kernel = custom_splash.make_splash_mha(
+          block_sizes=bsizes,
+          bkv_compute_in=bkv_compute_in,
+          orig_q_seq_len=query_seq_len,
+          orig_kv_seq_len=key_seq_len,
+          heads_per_tile=heads_per_tile,
+      )

-    # Reuse the standard flash-attention masking convention by zeroing invalid
-    # KV positions in the segment ids passed down to splash.
-    if attention_mask is not None:
-      mask_len = min(key_seq_len, attention_mask.shape[1])
-      kv_mask_for_batch = attention_mask[0, :mask_len]
-      if key_seq_len > mask_len:
-        extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
-        kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)
-      if kv_padded_len > key_seq_len:
-        padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
-        kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)
+      vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
+      attention_output = vmapped_splash(query_scaled, key, value)
+      attention_output = jnp.swapaxes(attention_output, 2, 3)
+      attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
+    else:
+      # Run the same local splash kernel as standard TPU flash attention, but now
+      # on full-sequence / fewer-heads tensors produced by the all-to-all above.
+      uses_fused_kernel = block_sizes.use_fused_bwd_kernel
+      block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
+      block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv)
+      if uses_fused_kernel:
+        block_q_sizes += (block_sizes.block_q_dkv,)
+        block_kv_sizes += (block_sizes.block_kv_dkv,)
       else:
-        kv_mask_padded = kv_mask_for_batch
-      kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
+        block_q_sizes += (block_sizes.block_q_dq,)
+        block_kv_sizes += (block_sizes.block_kv_dq,)
+
+      block_q = max(*block_q_sizes)
+      query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
+      block_kv = max(*block_kv_sizes)
+      key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+      value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+
+      mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+      multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+      q_padded_len = query.shape[2]
+      q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+      q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+      kv_padded_len = key.shape[2]
+      kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+      kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+
+      # Reuse the standard flash-attention masking convention by zeroing invalid
+      # KV positions in the segment ids passed down to splash.
+      if attention_mask is not None:
+        mask_len = min(key_seq_len, attention_mask.shape[1])
+        kv_mask_for_batch = attention_mask[0, :mask_len]
+        if key_seq_len > mask_len:
+          extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
+          kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)
+        if kv_padded_len > key_seq_len:
+          padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
+          kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)
+        else:
+          kv_mask_padded = kv_mask_for_batch
+        kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)

-    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
-    if not mask_padding_tokens:
-      segment_ids = None
+      segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+      if not mask_padding_tokens:
+        segment_ids = None

-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=1,
-        q_seq_shards=1,
-        block_sizes=block_sizes,
-        save_residuals=False,
-        residual_checkpoint_name=residual_checkpoint_name,
-    )
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
-    attention_output = vmapped_splash(query, key, value, segment_ids)
-    attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
+      splash_kernel = splash_attention_kernel.make_splash_mha(
+          mask=multi_head_mask,
+          head_shards=1,
+          q_seq_shards=1,
+          block_sizes=block_sizes,
+          save_residuals=False,
+          residual_checkpoint_name=residual_checkpoint_name,
+      )
+      vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+      attention_output = vmapped_splash(query, key, value, segment_ids)
+      attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

     # Restore the original layout expected by the rest of the model:
     # head-sharded / full-sequence -> sequence-sharded / full-heads.
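
Note on the new custom-kernel branch above: queries are pre-scaled by 1.44269504, which is log2(e). A plausible reading (an assumption, not stated in this diff) is that the custom Pallas kernel evaluates softmax with base-2 exponentials, and the pre-scale makes that match the usual natural-exp softmax. A small runnable check of the underlying identity exp(x) == 2 ** (x * log2(e)):

import jax.numpy as jnp

# Check that scaling logits by log2(e) and exponentiating with base 2
# reproduces the natural exponential exactly (up to float tolerance).
x = jnp.linspace(-3.0, 3.0, 7)
print(bool(jnp.allclose(jnp.exp(x), jnp.exp2(x * 1.44269504), rtol=1e-5)))  # True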
@@ -642,94 +673,6 @@ def wrap_ulysses_attention(query, key, value):
   return x


-def _ulysses_custom_attention(
-    query: jax.Array,
-    key: jax.Array,
-    value: jax.Array,
-    heads: int,
-    mesh: Mesh,
-    axis_names_q: AxisNames,
-    axis_names_kv: AxisNames,
-    flash_block_sizes: BlockSizes,
-    dtype: jnp.dtype = jnp.float32,
-    mask_padding_tokens: bool = False,
-    residual_checkpoint_name: str | None = None,
-    attention_mask: jax.Array = None,
-) -> jax.Array:
-  """Ulysses sequence-parallel attention with custom fast kernel."""
-  axis_name = "context"
-  num_shards = mesh.shape[axis_name]
-
-  # Reshape to [b, h, s, d] and pad sequence for even context-axis splitting.
-  query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_shards)
-  key, _ = _reshape_data_for_flash(key, heads, num_shards)
-  value, _ = _reshape_data_for_flash(value, heads, num_shards)
-  num_heads = query.shape[1]
-  if num_heads % num_shards != 0:
-    raise ValueError(
-        "Ulysses attention requires the number of heads to be divisible by the context shard count, "
-        f"got heads={num_heads} and context_shards={num_shards}."
-    )
-
-  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
-
-  @functools.partial(
-      jax.shard_map,
-      mesh=mesh,
-      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
-      out_specs=q_axis_names,
-      check_vma=False,
-  )
-  def wrap_ulysses_attention(query, key, value):
-    query = jax.lax.all_to_all(query, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
-    key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
-    value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
-
-    bq = 2048
-    bkv = 2048
-    bkv_compute = 1024
-    bkv_compute_in = 256
-    heads_per_tile = 1
-
-    query_scaled = query * 1.44269504
-
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
-    value, _, _ = _pad_data_for_flash(value, heads, bkv)
-
-    bsizes = custom_splash._BlockSizes(block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute)
-
-    splash_kernel = custom_splash.make_splash_mha(
-        block_sizes=bsizes,
-        bkv_compute_in=bkv_compute_in,
-        orig_q_seq_len=query_seq_len,
-        orig_kv_seq_len=key_seq_len,
-        heads_per_tile=heads_per_tile,
-    )
-
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
-    attention_output = vmapped_splash(query_scaled, key, value)
-    attention_output = jnp.swapaxes(attention_output, 2, 3)
-
-    attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
-
-    attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True)
-    return attention_output
-
-  devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1)
-  if not (query.shape[0] / devices_in_batch_sharding).is_integer():
-    max_logging.log(
-        "Warning, batch dimension should be shardable among the devices in data and fsdp"
-        f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
-    )
-  x = wrap_ulysses_attention(query, key, value)
-  x = x[:, :, :orig_q_seq_len, :]
-  x = _reshape_heads_to_head_dim(x)
-
-  return x
-
-
 def _apply_attention_dot(
     query: Array,
     key: Array,
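
Both the removed function and the new branch kept in _ulysses_attention wrap the per-example kernel with jax.vmap(..., in_axes=(0, 0, 0)), so it runs once per batch element on [heads, seq, head_dim] inputs. A minimal sketch of that vmap-over-batch pattern, with a plain softmax attention standing in for the Pallas splash kernel (illustration only, not the repo's implementation):

import jax
import jax.numpy as jnp

# Operates on a single example of shape [heads, seq, head_dim].
def single_example_attention(q, k, v):
  logits = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(q.shape[-1])
  return jnp.einsum("hqk,hkd->hqd", jax.nn.softmax(logits, axis=-1), v)

# vmap maps the per-example function over the leading batch axis.
batched_attention = jax.vmap(single_example_attention, in_axes=(0, 0, 0))

q = k = v = jnp.ones((2, 4, 128, 64))  # [batch, heads, seq, head_dim]
print(batched_attention(q, k, v).shape)  # (2, 4, 128, 64)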
@@ -865,7 +808,7 @@ def _apply_attention(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
   elif attention_kernel == "ulysses_custom":
-    return _ulysses_custom_attention(
+    return _ulysses_attention(
         query,
         key * scale,
         value,
@@ -878,6 +821,7 @@ def _apply_attention(
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name=residual_checkpoint_name,
         attention_mask=attention_mask,
+        use_custom_kernel=True,
     )
   elif attention_kernel == "ulysses":
     return _ulysses_attention(
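
Sketch of the resulting dispatch: after this change the "ulysses_custom" kernel name no longer has its own entry point and instead flows through _ulysses_attention with use_custom_kernel=True, while "ulysses" keeps the default. The stub below mirrors only that flag selection (hypothetical helper; the real _apply_attention passes many more arguments):

# Hypothetical helper mirroring the kernel-name -> flag routing in this commit.
def select_ulysses_kwargs(attention_kernel):
  if attention_kernel not in ("ulysses", "ulysses_custom"):
    raise ValueError(f"unsupported attention_kernel: {attention_kernel}")
  return {"use_custom_kernel": attention_kernel == "ulysses_custom"}

assert select_ulysses_kwargs("ulysses_custom") == {"use_custom_kernel": True}
assert select_ulysses_kwargs("ulysses") == {"use_custom_kernel": False}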

src/maxdiffusion/models/custom_splash_attention.py
Lines changed: 16 additions & 0 deletions

@@ -1,3 +1,19 @@
+"""
+Copyright 2026 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """Custom Pallas flash attention kernel for TPU."""

 import functools
