@@ -507,6 +507,7 @@ def _ulysses_attention(
     mask_padding_tokens: bool = True,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
+    attention_kernel: str = "ulysses",
 ) -> jax.Array:
   """Ulysses sequence-parallel attention.
 
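For orientation (not part of the diff): Ulysses attention trades a sequence-sharded layout for a head-sharded one via an all-to-all, so each context shard can run a dense flash kernel over the full sequence for its slice of heads. Below is a minimal single-host sketch of that exchange, assuming the `[batch, heads, seq, head_dim]` layout used in this file; `ulysses_exchange` is a made-up helper that only emulates the collective with plain array ops, not the sharded implementation in this function.

```python
# Illustrative only -- not part of this commit. Emulates, on one host, the
# all-to-all that Ulysses attention performs across `ns` context shards:
# each shard trades its full set of heads over a local sequence chunk for a
# slice of heads over the full sequence.
import jax.numpy as jnp


def ulysses_exchange(chunks):
  """chunks: one array per context shard, each [batch, heads, seq/ns, dim]."""
  ns = len(chunks)
  h_local = chunks[0].shape[1] // ns
  out = []
  for i in range(ns):
    # Shard i keeps head slice i from every shard's sequence chunk and
    # concatenates the chunks back into the full sequence.
    pieces = [c[:, i * h_local : (i + 1) * h_local] for c in chunks]
    out.append(jnp.concatenate(pieces, axis=2))  # [batch, heads/ns, seq, dim]
  return out


x = jnp.arange(2 * 8 * 16 * 4, dtype=jnp.float32).reshape(2, 8, 16, 4)
shards = ulysses_exchange(jnp.split(x, 4, axis=2))  # 4 context shards
assert shards[0].shape == (2, 2, 16, 4)  # heads/4, full sequence
```

This is also why the divisibility check below exists: the head dimension must split evenly across the context shards.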
@@ -530,7 +531,9 @@ def _ulysses_attention(
         "Ulysses attention requires the number of heads to be divisible by the context shard count, "
         f"got heads={num_heads} and context_shards={num_shards}."
     )
-  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")
+
+  inner_kernel = "tokamax_flash" if attention_kernel == "tokamax_ulysses" else "flash"
+  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, inner_kernel)
 
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
@@ -597,14 +600,26 @@ def wrap_ulysses_attention(query, key, value):
     if not mask_padding_tokens:
       segment_ids = None
 
-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=1,
-        q_seq_shards=1,
-        block_sizes=block_sizes,
-        save_residuals=False,
-        residual_checkpoint_name=residual_checkpoint_name,
-    )
+    if attention_kernel == "tokamax_ulysses":
+      mask = tokamax_splash_attention_mask.FullMask(
+          _shape=(query.shape[2], key.shape[2]),
+      )
+      splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
+          mask=mask,
+          q_seq_shards=1,
+          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+          save_residuals=False,
+      )
+    else:
+      splash_kernel = splash_attention_kernel.make_splash_mha(
+          mask=multi_head_mask,
+          head_shards=1,
+          q_seq_shards=1,
+          block_sizes=block_sizes,
+          save_residuals=False,
+          residual_checkpoint_name=residual_checkpoint_name,
+      )
+
     vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
     attention_output = vmapped_splash(query, key, value, segment_ids)
     attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
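A note on the `jax.vmap(..., in_axes=(0, 0, 0, None))` call that both branches feed into: the kernel is mapped over the batch axis of query/key/value while `segment_ids` is broadcast to every batch element. Below is a toy stand-in with the same calling convention, for shape checking only; `toy_mha` is plain softmax attention, not the Pallas splash kernel, and it takes a flat per-token array where the real kernels take their own segment-ids structure.

```python
# Toy stand-in for the vmapped splash kernel above -- shape contract only.
import jax
import jax.numpy as jnp


def toy_mha(q, k, v, segment_ids):
  # Per-example shapes: [heads, seq, head_dim].
  logits = jnp.einsum("hqd,hkd->hqk", q, k)
  if segment_ids is not None:
    # Block attention across segment boundaries (e.g. padding vs. real tokens).
    same = segment_ids[:, None] == segment_ids[None, :]
    logits = jnp.where(same[None], logits, jnp.finfo(logits.dtype).min)
  return jnp.einsum("hqk,hkd->hqd", jax.nn.softmax(logits, axis=-1), v)


vmapped = jax.vmap(toy_mha, in_axes=(0, 0, 0, None))  # batch only on q, k, v
q = k = v = jnp.ones((2, 4, 8, 16))  # [batch, heads, seq, head_dim]
out = vmapped(q, k, v, None)
assert out.shape == (2, 4, 8, 16)
```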
@@ -747,7 +762,7 @@ def _apply_attention(
   seq_len_idx = 1
   if query.ndim == 4:
     seq_len_idx = 2
-  if attention_kernel in ["flash", "tokamax_flash", "ulysses"]:
+  if attention_kernel in ["flash", "tokamax_flash", "ulysses", "tokamax_ulysses"]:
     can_use_flash_attention = (
         query.shape[seq_len_idx] >= flash_min_seq_length
         and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -759,7 +774,7 @@ def _apply_attention(
     return _apply_attention_dot(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
-  elif attention_kernel == "ulysses":
+  elif attention_kernel in ["ulysses", "tokamax_ulysses"]:
     return _ulysses_attention(
         query,
         key * scale,
@@ -773,6 +788,7 @@ def _apply_attention(
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name=residual_checkpoint_name,
         attention_mask=attention_mask,
+        attention_kernel=attention_kernel,
     )
   elif attention_kernel in ["flash", "tokamax_flash"]:
     return _tpu_flash_attention(
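Putting the hunks together, the routing added by this commit can be condensed as below. This is a hypothetical restatement for readability, not code from the repo: `pick_attention_path` and its return convention are invented here, and the real `_apply_attention` also gates the flash paths on `flash_min_seq_length` before dispatching.

```python
# Hypothetical condensation of the dispatch added in this commit; the real
# _apply_attention takes many more arguments and falls back to dot attention
# when the sequence is too short for a flash kernel.
def pick_attention_path(attention_kernel: str) -> tuple[str, str]:
  """Returns (outer attention path, inner flash kernel) for a config string."""
  if attention_kernel in ("ulysses", "tokamax_ulysses"):
    inner = "tokamax_flash" if attention_kernel == "tokamax_ulysses" else "flash"
    return ("_ulysses_attention", inner)
  if attention_kernel in ("flash", "tokamax_flash"):
    return ("_tpu_flash_attention", attention_kernel)
  return ("_apply_attention_dot", "none")


assert pick_attention_path("tokamax_ulysses") == ("_ulysses_attention", "tokamax_flash")
assert pick_attention_path("ulysses") == ("_ulysses_attention", "flash")
```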