@@ -31,6 +31,7 @@
 from einops import rearrange
 from .. import common_types, max_logging
 
+from . import custom_splash_attention as custom_splash
 from . import quantizations
 from .modeling_flax_utils import get_activation
 
@@ -521,6 +522,11 @@ def _ulysses_attention(
     mask_padding_tokens: bool = True,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
+    use_custom_kernel: bool = False,
 ) -> jax.Array:
   """Ulysses sequence-parallel attention.
 
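+  When `use_custom_kernel` is True, the local attention computed after the
+  all-to-all runs through the kernel in `custom_splash_attention` instead of
+  the standard splash kernel.
+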
@@ -544,7 +546,11 @@
544546 "Ulysses attention requires the number of heads to be divisible by the context shard count, "
545547 f"got heads={ num_heads } and context_shards={ num_shards } ."
546548 )
547- block_sizes = _select_flash_block_sizes (query , key , flash_block_sizes , dtype , "flash" )
549+
550+ if not use_custom_kernel :
551+ block_sizes = _select_flash_block_sizes (query , key , flash_block_sizes , dtype , "flash" )
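+  # On the custom-kernel path, tile sizes are hard-coded inside the branch
+  # below, so no flash block sizes need to be resolved here.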
 
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
@@ -563,65 +567,103 @@ def wrap_ulysses_attention(query, key, value):
     key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
     value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
 
-    # Run the same local splash kernel as standard TPU flash attention, but now
-    # on full-sequence / fewer-heads tensors produced by the all-to-all above.
-    uses_fused_kernel = block_sizes.use_fused_bwd_kernel
-    block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
-    block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv)
-    if uses_fused_kernel:
-      block_q_sizes += (block_sizes.block_q_dkv,)
-      block_kv_sizes += (block_sizes.block_kv_dkv,)
-    else:
-      block_q_sizes += (block_sizes.block_q_dq,)
-      block_kv_sizes += (block_sizes.block_kv_dq,)
+    if use_custom_kernel:
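+      # Fixed tile sizes for the custom splash kernel (not derived from
+      # flash_block_sizes); these appear to be hand-tuned for a specific shape.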
+      bq = 4864
+      bkv = 1024
+      bkv_compute = 1024
+      bkv_compute_in = 1024
+      heads_per_tile = 1
 
-    block_q = max(*block_q_sizes)
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
-    block_kv = max(*block_kv_sizes)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
-    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+      query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
+      key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
+      value, _, _ = _pad_data_for_flash(value, heads, bkv)
+      # Scale Q only after padding, so the scaled tensor matches the padded
+      # K/V shapes. The factor is log2(e), which suggests the custom kernel
+      # evaluates softmax with exp2 rather than exp.
+      query_scaled = query * 1.44269504
 
-    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
 
-    q_padded_len = query.shape[2]
-    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
-    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+      bsizes = custom_splash._BlockSizes(block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute)
 
-    kv_padded_len = key.shape[2]
-    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
-    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
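+      # The kernel receives the unpadded lengths so it can mask out the padded
+      # tail internally; note that attention_mask and mask_padding_tokens are
+      # not consulted on this path.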
+      splash_kernel = custom_splash.make_splash_mha(
+          block_sizes=bsizes,
+          bkv_compute_in=bkv_compute_in,
+          orig_q_seq_len=query_seq_len,
+          orig_kv_seq_len=key_seq_len,
+          heads_per_tile=heads_per_tile,
+      )
 
-    # Reuse the standard flash-attention masking convention by zeroing invalid
-    # KV positions in the segment ids passed down to splash.
-    if attention_mask is not None:
-      mask_len = min(key_seq_len, attention_mask.shape[1])
-      kv_mask_for_batch = attention_mask[0, :mask_len]
-      if key_seq_len > mask_len:
-        extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
-        kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)
-      if kv_padded_len > key_seq_len:
-        padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
-        kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)
+      vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
+      attention_output = vmapped_splash(query_scaled, key, value)
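+      # The custom kernel appears to return (batch, heads, head_dim, seq);
+      # swap the last two axes back to (batch, heads, seq, head_dim).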
+      attention_output = jnp.swapaxes(attention_output, 2, 3)
+      attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
+    else:
+      # Run the same local splash kernel as standard TPU flash attention, but now
+      # on full-sequence / fewer-heads tensors produced by the all-to-all above.
+      uses_fused_kernel = block_sizes.use_fused_bwd_kernel
+      block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
+      block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv)
+      if uses_fused_kernel:
+        block_q_sizes += (block_sizes.block_q_dkv,)
+        block_kv_sizes += (block_sizes.block_kv_dkv,)
       else:
-        kv_mask_padded = kv_mask_for_batch
-      kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
+        block_q_sizes += (block_sizes.block_q_dq,)
+        block_kv_sizes += (block_sizes.block_kv_dq,)
+
+      block_q = max(*block_q_sizes)
+      query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
+      block_kv = max(*block_kv_sizes)
+      key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+      value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+
+      mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+      multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+      q_padded_len = query.shape[2]
+      q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+      q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+      kv_padded_len = key.shape[2]
+      kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+      kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+
+      # Reuse the standard flash-attention masking convention by zeroing invalid
+      # KV positions in the segment ids passed down to splash.
+      if attention_mask is not None:
+        mask_len = min(key_seq_len, attention_mask.shape[1])
+        kv_mask_for_batch = attention_mask[0, :mask_len]
+        if key_seq_len > mask_len:
+          extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
+          kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)
+        if kv_padded_len > key_seq_len:
+          padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
+          kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)
+        else:
+          kv_mask_padded = kv_mask_for_batch
+        kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
 
-    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
-    if not mask_padding_tokens:
-      segment_ids = None
+      segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+      if not mask_padding_tokens:
+        segment_ids = None
 
-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=1,
-        q_seq_shards=1,
-        block_sizes=block_sizes,
-        save_residuals=False,
-        residual_checkpoint_name=residual_checkpoint_name,
-    )
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
-    attention_output = vmapped_splash(query, key, value, segment_ids)
-    attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
+      splash_kernel = splash_attention_kernel.make_splash_mha(
+          mask=multi_head_mask,
+          head_shards=1,
+          q_seq_shards=1,
+          block_sizes=block_sizes,
+          save_residuals=False,
+          residual_checkpoint_name=residual_checkpoint_name,
+      )
+      vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+      attention_output = vmapped_splash(query, key, value, segment_ids)
+      attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
 
     # Restore the original layout expected by the rest of the model:
     # head-sharded / full-sequence -> sequence-sharded / full-heads.
@@ -763,7 +795,7 @@ def _apply_attention(
   seq_len_idx = 1
   if query.ndim == 4:
     seq_len_idx = 2
-  if attention_kernel in ["flash", "tokamax_flash", "ulysses"]:
+  if attention_kernel in ["flash", "tokamax_flash", "ulysses", "ulysses_custom"]:
     can_use_flash_attention = (
         query.shape[seq_len_idx] >= flash_min_seq_length
         and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -775,6 +807,25 @@
     return _apply_attention_dot(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
+  elif attention_kernel == "ulysses_custom":
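+    # The softmax scale is folded into K rather than Q; this is equivalent,
+    # since only the Q.K^T product is affected, and Q itself is rescaled by
+    # log2(e) inside _ulysses_attention on the custom-kernel path.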
+    return _ulysses_attention(
+        query,
+        key * scale,
+        value,
+        heads,
+        mesh,
+        axis_names_q,
+        axis_names_kv,
+        flash_block_sizes,
+        dtype,
+        mask_padding_tokens=mask_padding_tokens,
+        residual_checkpoint_name=residual_checkpoint_name,
+        attention_mask=attention_mask,
+        use_custom_kernel=True,
+    )
   elif attention_kernel == "ulysses":
     return _ulysses_attention(
         query,