 from einops import rearrange
 from .. import common_types, max_logging

+from . import custom_splash_attention as custom_splash
 from . import quantizations
 from .modeling_flax_utils import get_activation

@@ -641,6 +642,94 @@ def wrap_ulysses_attention(query, key, value):
   return x


+def _ulysses_custom_attention(
+    query: jax.Array,
+    key: jax.Array,
+    value: jax.Array,
+    heads: int,
+    mesh: Mesh,
+    axis_names_q: AxisNames,
+    axis_names_kv: AxisNames,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype = jnp.float32,
+    mask_padding_tokens: bool = False,
+    residual_checkpoint_name: str | None = None,
+    attention_mask: jax.Array | None = None,
+) -> jax.Array:
659+ """Ulysses sequence-parallel attention with custom fast kernel."""
660+ axis_name = "context"
661+ num_shards = mesh .shape [axis_name ]
662+
663+ # Reshape to [b, h, s, d] and pad sequence for even context-axis splitting.
664+ query , orig_q_seq_len = _reshape_data_for_flash (query , heads , num_shards )
665+ key , _ = _reshape_data_for_flash (key , heads , num_shards )
666+ value , _ = _reshape_data_for_flash (value , heads , num_shards )
667+ num_heads = query .shape [1 ]
668+ if num_heads % num_shards != 0 :
669+ raise ValueError (
670+ "Ulysses attention requires the number of heads to be divisible by the context shard count, "
671+ f"got heads={ num_heads } and context_shards={ num_shards } ."
672+ )
673+
674+ q_axis_names = nn .logical_to_mesh_axes (axis_names_q )
675+ kv_axis_names = nn .logical_to_mesh_axes (axis_names_kv )
676+
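+  # shard_map runs the wrapped body once per device, exposing the context
+  # mesh axis for collectives; check_vma=False skips the per-output
+  # replication check.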
+  @functools.partial(
+      jax.shard_map,
+      mesh=mesh,
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
+      out_specs=q_axis_names,
+      check_vma=False,
+  )
+  def wrap_ulysses_attention(query, key, value):
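+    # Ulysses exchange: trade local heads for the full sequence, taking each
+    # device from [b, h, s/n, d] to [b, h/n, s, d].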
+    query = jax.lax.all_to_all(query, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
+    key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
+    value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
+
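+    # Fixed tile sizes for the custom splash kernel (hand-tuned constants;
+    # the flash_block_sizes argument is not used here).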
+    bq = 2048
+    bkv = 2048
+    bkv_compute = 1024
+    bkv_compute_in = 256
+    heads_per_tile = 1
+
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
+    value, _, _ = _pad_data_for_flash(value, heads, bkv)
+
+    # Pre-multiply by log2(e) so the kernel can compute the softmax with
+    # exp2. Scale after padding so the shape matches the padded key/value.
+    query_scaled = query * 1.44269504
+
+    bsizes = custom_splash._BlockSizes(block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute)
+
+    splash_kernel = custom_splash.make_splash_mha(
+        block_sizes=bsizes,
+        bkv_compute_in=bkv_compute_in,
+        orig_q_seq_len=query_seq_len,
+        orig_kv_seq_len=key_seq_len,
+        heads_per_tile=heads_per_tile,
+    )
+
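+    # vmap the single-example kernel over the batch dimension.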
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
+    attention_output = vmapped_splash(query_scaled, key, value)
+    attention_output = jnp.swapaxes(attention_output, 2, 3)
+
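+    # Drop the sequence and head-dim padding added for the kernel tiles.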
+    attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
+
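+    # Reverse the Ulysses exchange: each device goes back to all heads and a
+    # local slice of the sequence, [b, h/n, s, d] -> [b, h, s/n, d].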
+    attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True)
+    return attention_output
+
+  devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1)
+  if query.shape[0] % devices_in_batch_sharding != 0:
+    max_logging.log(
+        "Warning: the batch dimension should be evenly shardable across the devices on the data and fsdp"
+        f" axes; batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
+    )
+  x = wrap_ulysses_attention(query, key, value)
+  x = x[:, :, :orig_q_seq_len, :]
+  x = _reshape_heads_to_head_dim(x)
+
+  return x
+
+
 def _apply_attention_dot(
     query: Array,
     key: Array,
@@ -763,7 +852,7 @@ def _apply_attention(
   seq_len_idx = 1
   if query.ndim == 4:
     seq_len_idx = 2
-  if attention_kernel in ["flash", "tokamax_flash", "ulysses"]:
+  if attention_kernel in ["flash", "tokamax_flash", "ulysses", "ulysses_custom"]:
     can_use_flash_attention = (
         query.shape[seq_len_idx] >= flash_min_seq_length
         and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -775,6 +864,21 @@ def _apply_attention(
     return _apply_attention_dot(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
+  elif attention_kernel == "ulysses_custom":
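+    # Fold the softmax scale into the key here; inside the helper the query
+    # is separately multiplied by log2(e) for the exp2-based kernel.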
+    return _ulysses_custom_attention(
+        query,
+        key * scale,
+        value,
+        heads,
+        mesh,
+        axis_names_q,
+        axis_names_kv,
+        flash_block_sizes,
+        dtype,
+        mask_padding_tokens=mask_padding_tokens,
+        residual_checkpoint_name=residual_checkpoint_name,
+        attention_mask=attention_mask,
+    )
   elif attention_kernel == "ulysses":
     return _ulysses_attention(
         query,