Commit f227fa2

Merge pull request #3087 from AI-Hypercomputer:v32_flash_integration
PiperOrigin-RevId: 868271717
2 parents bc0eca3 + b9457e4 commit f227fa2

Showing 5 changed files with 113 additions and 33 deletions.


src/MaxText/layers/attention_mla.py

Lines changed: 0 additions & 2 deletions
@@ -1025,8 +1025,6 @@ def __call__(
           inputs_positions=inputs_positions,
           attention_mask=attention_mask,
       )
-      if index_mask is not None:
-        index_mask = index_mask[:, None, None, :, :]  # [b, 1, 1, q_len, kv_len]
 
       if self.config.attention == "paged" and model_mode != MODEL_MODE_TRAIN:
         unnormalized_out, _, exp_sum = self.ds_paged_attention_op(

src/MaxText/layers/attention_op.py

Lines changed: 34 additions & 6 deletions
@@ -881,7 +881,9 @@ def apply_attention(
           Use `dot_product` instead."""
       )
       return (
-          self.tpu_flash_attention(query, key, value, decoder_segment_ids, self.attn_logits_soft_cap, sinks),
+          self.tpu_flash_attention(
+              query, key, value, decoder_segment_ids, self.attn_logits_soft_cap, sinks, index_mask
+          ),
           None,
           None,
       )
@@ -1038,6 +1040,7 @@ def tpu_flash_attention(
       decoder_segment_ids: Array | None,
       attn_logits_soft_cap: float | None = None,
       sinks: Array | None = None,
+      index_mask: Array | None = None,
   ) -> Array:
     """TPU Flash Attention."""
 
@@ -1063,10 +1066,12 @@ def tpu_flash_attention(
       axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel_ep)
       axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q_ep)
       axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv_ep)
+      index_mask_axis_names = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH, KV_LENGTH))
     else:
       axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel)
       axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q)
       axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv)
+      index_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH))
 
     global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv
     global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel
@@ -1253,10 +1258,12 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
             axis_names_kv,
             segment_axis_names_q,
             segment_axis_names_kv,
+            None,  # no sharding for config
            segment_axis_names_splash_kernel,
             None,  # no sharding for cp_size
             None,  # no sharding for load_balanced_context_parallel
             sink_axis_names,  # sharding align with query heads
+            index_mask_axis_names,
         ),
         out_specs=axis_names_q,
         check_vma=False,
@@ -1267,10 +1274,12 @@ def wrap_flash_attention(
         value,
         decoder_segment_ids_q,
         decoder_segment_ids_kv,
+        sa_config,
         splash_kernel,
         cp_size,
         load_balanced_context_parallel,
         sinks,
+        index_mask,
     ):
       # If load_balanced_context_parallel is enabled, reorder the key and value tensors
       # to ensure that they are contiguous in memory.
@@ -1296,10 +1305,25 @@ def wrap_flash_attention(
         decoder_segment_ids_tuple = None
 
       if self.config.use_tokamax_splash:
-        kernel = partial(splash_kernel, max_logit_value=max_logit_value)
-        attention_output = jax.vmap(lambda q, k, v, d, s: kernel(q, k, v, d, sinks=s), in_axes=(0, 0, 0, 0, None))(
-            query, key, value, decoder_segment_ids_tuple, sinks
-        )
+        if self.config.use_sparse_indexer and index_mask is not None:
+          # Construct the splash kernel call with dynamic mask
+          def dynamic_mask_splash_kernel(q, k, v, segment, sinks, index_mask):
+            splash_kernel = tokamax_splash_kernel.make_dynamic_splash_mha(
+                mask=index_mask,
+                config=sa_config,
+            )
+            kernel = partial(splash_kernel, max_logit_value=max_logit_value)
+            return kernel(q, k, v, segment, sinks=sinks)
+
+          # Iterate over batch dimension for (query, key, value, segment, sinks, mask)
+          attn_fn = jax.vmap(dynamic_mask_splash_kernel, (0, 0, 0, 0, None, 0))
+          index_mask = jnp.isclose(index_mask, 0.0)
+          attention_output = attn_fn(query, key, value, decoder_segment_ids_tuple, sinks, index_mask)
+        else:
+          kernel = partial(splash_kernel, max_logit_value=max_logit_value)
+          attention_output = jax.vmap(lambda q, k, v, d, s: kernel(q, k, v, d, sinks=s), in_axes=(0, 0, 0, 0, None))(
+              query, key, value, decoder_segment_ids_tuple, sinks
+          )
       elif self.config.use_jax_splash:
         materialized_mask = jnp.asarray(mask[:, :])
         attention_output = jax_flash_attention.flash_attention_block_masked(
@@ -1337,17 +1361,20 @@ def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
     decoder_segment_ids_q = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q)
     decoder_segment_ids_kv = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv)
     sinks = _maybe_shard_with_pspec(sinks, sink_axis_names)
+    index_mask = _maybe_shard_with_pspec(index_mask, index_mask_axis_names)
 
     x = wrap_flash_attention(
         query,
         key,
         value,
         decoder_segment_ids_q,
         decoder_segment_ids_kv,
+        sa_config,
         None if self.config.use_jax_splash else splash_kernel,
         cp_size,
         load_balanced_context_parallel,
         sinks,
+        index_mask,
     )
 
     x = jnp.transpose(x, axes=(0, 2, 1, 3))
@@ -1639,8 +1666,9 @@ def apply_attention_dot(
     # Apply index mask, deepseek sparse attention
     # index mask contains 0.0 for kept tokens and large negative for masked tokens.
     if index_mask is not None:
+      # index_mask: from [b, q_len, kv_len] to [b, 1, 1, q_len, kv_len]
+      index_mask = index_mask[:, None, None, :, :]
       # attn_weights: [b, n_kv, n_q // n_kv, q_len, kv_len]
-      # index_mask: [b, 1, 1, q_len, kv_len]
       attn_weights = apply_mask_to_logits(attn_weights, index_mask)
 
     if self.is_partition_in_decode(q_seq_len):
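
For context on the dynamic-mask path added above: the additive index mask (0.0 keeps a position, a large negative value drops it) is converted to a boolean mask via jnp.isclose(index_mask, 0.0), and a per-example kernel closure is mapped over the batch with jax.vmap. Below is a minimal stand-alone sketch of that pattern; toy_masked_attention is a hypothetical stand-in for the tokamax splash kernel, not the kernel the diff calls.

import jax
import jax.numpy as jnp

def toy_masked_attention(q, k, v, keep_mask):
  # q, k, v: [heads, seq, head_dim]; keep_mask: [q_len, kv_len] boolean.
  logits = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(q.shape[-1])
  logits = jnp.where(keep_mask[None, :, :], logits, -1e30)
  return jnp.einsum("hqk,hkd->hqd", jax.nn.softmax(logits, axis=-1), v)

b, h, s, d = 2, 4, 8, 16
q = k = v = jnp.ones((b, h, s, d))
# Additive index mask as the sparse indexer produces it: 0.0 = keep, -1e9 = drop.
index_mask = jnp.where(jnp.tril(jnp.ones((b, s, s), dtype=bool)), 0.0, -1e9)
keep = jnp.isclose(index_mask, 0.0)                   # boolean mask, as in the diff
out = jax.vmap(toy_masked_attention)(q, k, v, keep)   # map over the batch axis
print(out.shape)  # (2, 4, 8, 16)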

src/maxtext/configs/types.py

Lines changed: 6 additions & 2 deletions
@@ -2205,8 +2205,12 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     if self.use_sparse_indexer:
       if self.q_lora_rank == 0:
         raise NotImplementedError("Sparse indexer has not implemented for q_lora_rank = 0.")
-      if self.attention not in ("dot_product"):
-        raise ValueError("Sparse indexer is only supported dot_product attention")
+      supports_dot_product = self.attention == "dot_product"
+      supports_flash_splash = self.attention == "flash" and self.use_tokamax_splash
+      if not (supports_dot_product or supports_flash_splash):
+        raise NotImplementedError(
+            "Sparse indexer is only supported dot_product attention or flash attention with tokamax splash."
+        )
     if self.attention_type == AttentionType.CHUNK.value and (
         not isinstance(self.chunk_attn_window_size, int) or self.chunk_attn_window_size <= 0
     ):
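
A minimal sketch of the combinations the new check accepts; the helper below is illustrative only, not a MaxText API, and the field names simply mirror the config check above.

def sparse_indexer_supported(attention: str, use_tokamax_splash: bool) -> bool:
  # dot_product is always allowed; flash is allowed only with the tokamax splash kernel.
  return attention == "dot_product" or (attention == "flash" and use_tokamax_splash)

assert sparse_indexer_supported("dot_product", False)
assert sparse_indexer_supported("flash", True)
assert not sparse_indexer_supported("flash", False)  # would raise NotImplementedError
assert not sparse_indexer_supported("paged", True)   # likewise unsupported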

tests/unit/deepseek32_vs_reference_test.py

Lines changed: 71 additions & 22 deletions
@@ -82,6 +82,7 @@ class Config:
   qk_nope_head_dim: int = 128
   qk_rope_head_dim: int = 64
   v_head_dim: int = 128
+  use_tokamax_splash: bool = True
   # yarn
   rope_type: str = "yarn"
   original_max_position_embeddings: int = 4096
@@ -98,7 +99,6 @@ class Config:
   use_sparse_indexer: bool = True
   index_n_heads: int = 64
   index_head_dim: int = 128  # > qk_rope_head_dim
-  index_topk: int = 4
 
 
 class ModelArgs:
@@ -107,7 +107,7 @@ class ModelArgs:
   Maps MaxText Config keys to the specific variable names expected by the reference implementation.
   """
 
-  def __init__(self, config: Config, max_batch_size: int = 8):
+  def __init__(self, config: Config, max_batch_size: int = 8, index_topk: int = 4):
     self.max_batch_size = max_batch_size
     self.scale_fmt = None
     self.max_seq_len = config.max_position_embeddings
@@ -119,6 +119,7 @@ def __init__(self, config: Config, max_batch_size: int = 8):
     self.qk_nope_head_dim = config.qk_nope_head_dim
     self.qk_rope_head_dim = config.qk_rope_head_dim
     self.v_head_dim = config.v_head_dim
+    self.use_tokamax_splash = config.use_tokamax_splash
     # yarn
     self.original_seq_len = config.original_max_position_embeddings
     self.rope_theta = float(config.rope_max_timescale)
@@ -129,7 +130,7 @@ def __init__(self, config: Config, max_batch_size: int = 8):
     # indexer
     self.index_n_heads = config.index_n_heads
     self.index_head_dim = config.index_head_dim
-    self.index_topk = config.index_topk
+    self.index_topk = index_topk
 
 
 # -----------------------------------------------------------------------------
@@ -457,14 +458,14 @@ def rotate_activation(x: torch.Tensor) -> torch.Tensor:
 
 class Indexer(torch.nn.Module):  # pylint: disable=missing-class-docstring
 
-  def __init__(self, args: ModelArgs):
+  def __init__(self, args: ModelArgs, index_topk: int = 4):
     super().__init__()
     self.dim: int = args.dim
     self.n_heads: int = args.index_n_heads
     self.n_local_heads = args.index_n_heads // world_size
     self.head_dim: int = args.index_head_dim
     self.rope_head_dim: int = args.qk_rope_head_dim
-    self.index_topk: int = args.index_topk
+    self.index_topk: int = index_topk
     self.q_lora_rank: int = args.q_lora_rank
     self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
     self.wk = Linear(self.dim, self.head_dim)
@@ -580,7 +581,7 @@ class MLA(nn.Module):
       softmax_scale (float): Scaling factor for softmax in attention computation.
   """
 
-  def __init__(self, args: ModelArgs):
+  def __init__(self, args: ModelArgs, index_topk: int):
     super().__init__()
     self.dim = args.dim
     self.n_heads = args.n_heads
@@ -605,7 +606,7 @@ def __init__(self, args: ModelArgs):
       mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
       self.softmax_scale = self.softmax_scale * mscale * mscale
 
-    self.indexer = Indexer(args)
+    self.indexer = Indexer(args, index_topk)
 
     self.register_buffer(
         "kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False
@@ -750,7 +751,7 @@ def get_jax_mla_weights(pt_mla, cfg):
   }
 
 
-def get_cfg_and_mesh(config, run_name, dtype, batch_size, seq_len):
+def get_cfg_and_mesh(config, run_name, dtype, batch_size, seq_len, attention, index_topk):
   """Returns MaxText configuration and mesh."""
   cfg = pyconfig.initialize(
       [None, get_test_config_path()],
@@ -766,7 +767,8 @@ def get_cfg_and_mesh(config, run_name, dtype, batch_size, seq_len):
       per_device_batch_size=batch_size,
       max_target_length=seq_len,
       max_prefill_predict_length=seq_len,
-      attention="dot_product",
+      attention=attention,
+      index_topk=index_topk,
       **asdict(config),
   )
   devices_array = maxtext_utils.create_device_mesh(cfg)
@@ -785,7 +787,7 @@ def setUp(self):
     np.random.seed(42)
 
     self.dtype = "float32"
-    self.batch_size = 2
+    self.batch_size = 4
     self.start_pos = 0
     self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42))
     # jax config
@@ -861,6 +863,8 @@ def test_indexer_match(self, seq_len=8):
         dtype=self.dtype,
         batch_size=self.batch_size,
         seq_len=self.seq_len,
+        attention="dot_product",
+        index_topk=4,
     )
 
     # Indexer specific RoPE (interleave=False)
@@ -906,17 +910,53 @@ class DeepseekV32MLATest(DeepseekTestBase):
   """Tests for MLA Attention with Sparse Indexing."""
 
   @parameterized.named_parameters(
-      {"testcase_name": "seq_len=2 (index_topk=4)", "seq_len": 2},
-      {"testcase_name": "seq_len=8 (index_topk=4)", "seq_len": 8},
+      {
+          "testcase_name": "dot_product_s2_k4",
+          "attention": "dot_product",
+          "seq_len": 2,
+          "index_topk": 4,
+      },
+      {
+          "testcase_name": "dot_product_s8_k4",
+          "attention": "dot_product",
+          "seq_len": 8,
+          "index_topk": 4,
+      },
+      {
+          "testcase_name": "dot_product_s128_k4",
+          "attention": "dot_product",
+          "seq_len": 128,
+          "index_topk": 4,
+          "check_norm": True,
+      },
+      {
+          "testcase_name": "dot_product_s128_k128",
+          "attention": "dot_product",
+          "seq_len": 128,
+          "index_topk": 128,
+          "check_norm": True,
+      },
+      {
+          "testcase_name": "flash_s128_k4",
+          "attention": "flash",
+          "seq_len": 128,
+          "index_topk": 4,
+          "check_norm": True,
+      },
+      {
+          "testcase_name": "flash_s128_k128",
+          "attention": "flash",
+          "seq_len": 128,
+          "index_topk": 128,
+          "check_norm": True,
+      },
   )
-  # index_topk=4
-  def test_mla_match(self, seq_len=8):
-    """Verifies MLA output (train mode) matches PyTorch (MHA mode) with indexer."""
-
+  def test_mla_parity(self, attention, seq_len, index_topk, check_norm=False):
+    """Verifies JAX MLA output against the PyTorch reference implementation."""
     torch_inputs, jax_inputs = self.get_data(seq_len)
 
     # 1. PyTorch Run
-    pt_mla = MLA(self.pt_args)
+    pt_mla = MLA(self.pt_args, index_topk)
     init_torch_weights(pt_mla)
     pt_mla.eval()
 
@@ -936,6 +976,8 @@ def test_mla_match(self, seq_len=8):
         dtype=self.dtype,
         batch_size=self.batch_size,
         seq_len=self.seq_len,
+        attention=attention,
+        index_topk=index_topk,
     )
 
     jax_mla = attention_mla.MLA(
@@ -959,7 +1001,7 @@ def test_mla_match(self, seq_len=8):
         rope_factor=cfg.rope_factor,
         max_target_length=self.seq_len,
         mesh=mesh,
-        attention_kernel="dot_product",
+        attention_kernel=attention,
         inputs_q_shape=(self.batch_size, self.seq_len, cfg.emb_dim),
         inputs_kv_shape=(self.batch_size, self.seq_len, cfg.emb_dim),
         rngs=self.nnx_rng,
@@ -976,10 +1018,17 @@ def test_mla_match(self, seq_len=8):
         model_mode=MODEL_MODE_TRAIN,
     )
 
-    # 3 Compare
-    print("torch out", pt_out)
-    print("jax out", jax_out)
-    np.testing.assert_allclose(to_jax(pt_out), jax_out, rtol=1e-2, atol=1e-2)
+    # 3. Compare
+    if check_norm:
+      expected = to_jax(pt_out) / jnp.linalg.norm(to_jax(pt_out))
+      actual = jax_out / jnp.linalg.norm(jax_out)
+    else:
+      expected = to_jax(pt_out)
+      actual = jax_out
+
+    print("torch out", expected)
+    print("jax out", actual)
+    np.testing.assert_allclose(expected, actual, rtol=1e-2, atol=1e-2)
 
 
 if __name__ == "__main__":
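
The check_norm branch above compares direction rather than absolute magnitude: both outputs are L2-normalized before assert_allclose, so the comparison is insensitive to an overall scale difference between the two implementations at longer sequence lengths. A small self-contained illustration with made-up values:

import numpy as np

pt_out = np.array([2.0, 4.0, 6.0])
jax_out = np.array([1.0, 2.0, 3.0])   # same direction, different scale

expected = pt_out / np.linalg.norm(pt_out)
actual = jax_out / np.linalg.norm(jax_out)
np.testing.assert_allclose(expected, actual, rtol=1e-2, atol=1e-2)  # passes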

tests/unit/train_compile_test.py

Lines changed: 2 additions & 1 deletion
@@ -767,7 +767,8 @@ def test_deepseek32(self):
         "megablox=True",
         "per_device_batch_size=1",
         "max_target_length=1024",
-        "attention=dot_product",  # TODO: update to flash attention when it's available.
+        "attention=flash",
+        "use_tokamax_splash=True",
         "dtype=bfloat16",
         "weight_dtype=bfloat16",
         # without_device_limit
