@@ -227,6 +227,38 @@ def convert_to_tokamax_splash_config(
   )
 
 
+def _resolve_tpu_attention_block_sizes(
+    query_seq_len: int,
+    kv_seq_len: int,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype,
+    attention_kernel: str = "flash",
+) -> BlockSizes:
+  """Resolve TPU splash attention block sizes for self- and cross-attention."""
+  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  is_cross_attention = kv_seq_len != query_seq_len
+  if is_cross_attention:
+    kv_max_block_size = ((kv_seq_len + 127) // 128) * 128
+  else:
+    kv_max_block_size = q_max_block_size
+
+  if flash_block_sizes and not is_cross_attention:
+    return flash_block_sizes
+
+  block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
+  return splash_attention_kernel.BlockSizes(
+      block_q=block_size_q,
+      block_kv_compute=min(kv_max_block_size, kv_seq_len),
+      block_kv=min(kv_max_block_size, kv_seq_len),
+      block_q_dkv=block_size_q,
+      block_kv_dkv=min(kv_max_block_size, kv_seq_len),
+      block_kv_dkv_compute=min(kv_max_block_size, query_seq_len),
+      block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
+      block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query_seq_len),
+      use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
+  )
+
+
 def _tpu_flash_attention(
     query: jax.Array,
     key: jax.Array,
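For reference, here is a minimal, standalone sketch of the block-size arithmetic that the new helper encodes, using made-up sequence lengths. It deliberately avoids the splash kernel API so it can run anywhere as a sanity check of the rounding and capping behaviour; `sketch_block_sizes` is an illustrative name, not part of the PR.

```python
# Standalone sketch of the block-size arithmetic in _resolve_tpu_attention_block_sizes.
# Sequence lengths below are hypothetical; only the arithmetic mirrors the helper.
import jax.numpy as jnp


def sketch_block_sizes(query_seq_len, kv_seq_len, dtype=jnp.bfloat16):
  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
  if kv_seq_len != query_seq_len:
    # Cross-attention: round the KV length up to a multiple of 128.
    kv_max_block_size = ((kv_seq_len + 127) // 128) * 128
  else:
    kv_max_block_size = q_max_block_size
  # Blocks never exceed the actual KV length.
  return q_max_block_size, min(kv_max_block_size, kv_seq_len)


# Self-attention over 4096 tokens: both block caps land at 1024.
assert sketch_block_sizes(4096, 4096) == (1024, 1024)
# Cross-attention against 300 KV tokens: rounded up to 384, then clamped back to 300.
assert sketch_block_sizes(4096, 300) == (1024, 300)
```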
@@ -244,32 +276,17 @@ def _tpu_flash_attention(
 ) -> jax.Array:
   """TPU Flash Attention"""
 
-  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
-  # This is the case for cross-attn.
-  if key.shape[1] != query.shape[1]:
-    kv_max_block_size = ((key.shape[1] + 127) // 128) * 128
-  else:
-    kv_max_block_size = q_max_block_size
-  # ensure that for cross attention we override the block sizes.
-  if flash_block_sizes and key.shape[1] == query.shape[1]:
-    block_sizes = flash_block_sizes
-  else:
-    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
-    block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=block_size_q,
-        block_kv_compute=min(kv_max_block_size, key.shape[2]),
-        block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=block_size_q,
-        block_kv_dkv=min(kv_max_block_size, key.shape[2]),
-        block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
-        block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
-        use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
-    )
   num_context_shards = mesh.shape["context"]
   query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
   key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
   value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
+  block_sizes = _resolve_tpu_attention_block_sizes(
+      query_seq_len=query.shape[2],
+      kv_seq_len=key.shape[2],
+      flash_block_sizes=flash_block_sizes,
+      dtype=dtype,
+      attention_kernel=attention_kernel,
+  )
 
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
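One point worth noting: the helper is now called after `_reshape_data_for_flash`, whose output the surrounding code treats as `(batch, heads, seq, head_dim)` (the Ulysses hunk below reads `query.shape[1]` as the head count), so the sequence lengths come from axis 2. A toy check with hypothetical shapes:

```python
# Hypothetical post-reshape shapes, assuming the (batch, heads, seq, head_dim)
# layout implied by the surrounding code.
import jax.numpy as jnp

query = jnp.zeros((2, 16, 4096, 128), dtype=jnp.bfloat16)  # self-attention query
key = jnp.zeros((2, 16, 512, 128), dtype=jnp.bfloat16)     # shorter cross-attention KV

query_seq_len, kv_seq_len = query.shape[2], key.shape[2]
assert (query_seq_len, kv_seq_len) == (4096, 512)
```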
@@ -425,6 +442,7 @@ def ring_scan_body(carry, _):
 # Ulysses sequence-parallel attention
 # ---------------------------------------------------------------------------
 
+
 def _ulysses_attention(
     query: jax.Array,
     key: jax.Array,
@@ -456,53 +474,41 @@ def _ulysses_attention(
   query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_shards)
   key, _ = _reshape_data_for_flash(key, heads, num_shards)
   value, _ = _reshape_data_for_flash(value, heads, num_shards)
+  num_heads = query.shape[1]
+  # Ulysses only redistributes existing heads across the context mesh; unlike
+  # the earlier draft, we fail fast instead of padding synthetic heads.
+  if num_heads % num_shards != 0:
+    raise ValueError(
+        "Ulysses attention requires the number of heads to be divisible by the context shard count, "
+        f"got heads={num_heads} and context_shards={num_shards}."
+    )
 
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
-  # Pre-compute block sizes outside shard_map (uses global shapes).
-  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
-  if key.shape[2] != query.shape[2]:
-    kv_max_block_size = ((key.shape[2] + 127) // 128) * 128
-  else:
-    kv_max_block_size = q_max_block_size
-  if flash_block_sizes and key.shape[2] == query.shape[2]:
-    block_sizes = flash_block_sizes
-  else:
-    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
-    block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=block_size_q,
-        block_kv_compute=min(kv_max_block_size, key.shape[2]),
-        block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=block_size_q,
-        block_kv_dkv=min(kv_max_block_size, key.shape[2]),
-        block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=block_size_q,
-        block_kv_dq=min(kv_max_block_size, query.shape[2]),
-        use_fused_bwd_kernel=False,
-    )
+  block_sizes = _resolve_tpu_attention_block_sizes(
+      query_seq_len=query.shape[2],
+      kv_seq_len=key.shape[2],
+      flash_block_sizes=flash_block_sizes,
+      dtype=dtype,
+  )
 
   @functools.partial(
-      shard_map.shard_map,
+      jax.shard_map,
       mesh=mesh,
       in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
       out_specs=q_axis_names,
-      check_rep=False,
+      check_vma=False,
   )
   def wrap_ulysses_attention(query, key, value):
-    # --- Step 1: all-to-all sequence-sharded -> head-sharded ---
-    original_q_heads = query.shape[1]
-    head_pad = (-original_q_heads) % num_shards
-    if head_pad:
-      query = jnp.pad(query, ((0, 0), (0, head_pad), (0, 0), (0, 0)))
-      key = jnp.pad(key, ((0, 0), (0, head_pad), (0, 0), (0, 0)))
-      value = jnp.pad(value, ((0, 0), (0, head_pad), (0, 0), (0, 0)))
-
+    # Swap sharding modes: each device gives up a slice of sequence and gathers
+    # a slice of heads, so the local splash kernel sees the full sequence.
     query = jax.lax.all_to_all(query, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
     key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
     value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
 
-    # --- Step 2: local flash attention (full sequence, subset of heads) ---
+    # Run the same local splash kernel as standard TPU flash attention, but now
+    # on full-sequence / fewer-heads tensors produced by the all-to-all above.
     uses_fused_kernel = block_sizes.use_fused_bwd_kernel
     block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
     block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv)
@@ -530,6 +536,8 @@ def wrap_ulysses_attention(query, key, value):
     kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
     kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
 
+    # Reuse the standard flash-attention masking convention by zeroing invalid
+    # KV positions in the segment ids passed down to splash.
     if attention_mask is not None:
       mask_len = min(key_seq_len, attention_mask.shape[1])
       kv_mask_for_batch = attention_mask[0, :mask_len]
@@ -559,11 +567,9 @@ def wrap_ulysses_attention(query, key, value):
     attention_output = vmapped_splash(query, key, value, segment_ids)
     attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
 
-    # --- Step 3: all-to-all head-sharded -> sequence-sharded ---
-    attention_output = jax.lax.all_to_all(
-        attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True
-    )
-    attention_output = attention_output[:, :original_q_heads, :, :]
+    # Restore the original layout expected by the rest of the model:
+    # head-sharded / full-sequence -> sequence-sharded / full-heads.
+    attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True)
     return attention_output
 
   devices_in_data_context = mesh.shape["data"] * num_shards
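As a final reference, here is a minimal sketch of the sequence-to-head layout swap that `wrap_ulysses_attention` performs with `jax.lax.all_to_all`, assuming a hypothetical two-device "context" mesh and toy shapes (on CPU you can fake the devices with `XLA_FLAGS=--xla_force_host_platform_device_count=2`). It is only an illustration of the collective, not the PR's kernel path.

```python
# Sketch of the Ulysses layout swap: sequence-sharded in, head-sharded out.
# Assumes >= 2 devices on a "context" mesh axis; all shapes are hypothetical.
import functools
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices()[:2], ("context",))
batch, heads, seq, head_dim = 1, 4, 8, 2
x = jnp.arange(batch * heads * seq * head_dim, dtype=jnp.float32).reshape(batch, heads, seq, head_dim)


@functools.partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=P(None, None, "context", None),   # sequence-sharded input
    out_specs=P(None, "context", None, None),  # head-sharded output
    check_vma=False,                           # as in the diff above
)
def seq_to_head_sharded(x):
  # Locally (batch, heads, seq/2, head_dim) -> (batch, heads/2, seq, head_dim):
  # each device trades half of its heads for the other device's sequence slice.
  return jax.lax.all_to_all(x, axis_name="context", split_axis=1, concat_axis=2, tiled=True)


y = seq_to_head_sharded(x)
assert y.shape == x.shape  # global shape is unchanged; only the sharded axis moves
```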