@@ -522,6 +522,7 @@ def _ulysses_attention(
     mask_padding_tokens: bool = True,
     residual_checkpoint_name: str | None = None,
     attention_mask: jax.Array = None,
+    use_custom_kernel: bool = False,
 ) -> jax.Array:
   """Ulysses sequence-parallel attention.

@@ -545,7 +546,9 @@ def _ulysses_attention(
545546 "Ulysses attention requires the number of heads to be divisible by the context shard count, "
546547 f"got heads={ num_heads } and context_shards={ num_shards } ."
547548 )
548- block_sizes = _select_flash_block_sizes (query , key , flash_block_sizes , dtype , "flash" )
549+
550+ if not use_custom_kernel :
551+ block_sizes = _select_flash_block_sizes (query , key , flash_block_sizes , dtype , "flash" )
549552
550553 q_axis_names = nn .logical_to_mesh_axes (axis_names_q )
551554 kv_axis_names = nn .logical_to_mesh_axes (axis_names_kv )
@@ -564,65 +567,93 @@ def wrap_ulysses_attention(query, key, value):
     key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
     value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)

-    # Run the same local splash kernel as standard TPU flash attention, but now
-    # on full-sequence / fewer-heads tensors produced by the all-to-all above.
-    uses_fused_kernel = block_sizes.use_fused_bwd_kernel
-    block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
-    block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv)
-    if uses_fused_kernel:
-      block_q_sizes += (block_sizes.block_q_dkv,)
-      block_kv_sizes += (block_sizes.block_kv_dkv,)
-    else:
-      block_q_sizes += (block_sizes.block_q_dq,)
-      block_kv_sizes += (block_sizes.block_kv_dq,)
+    if use_custom_kernel:
+      bq = 4864
+      bkv = 1024
+      bkv_compute = 1024
+      bkv_compute_in = 1024
+      heads_per_tile = 1

-    block_q = max(*block_q_sizes)
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
-    block_kv = max(*block_kv_sizes)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
-    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+      query_scaled = query * 1.44269504

-    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+      query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
+      key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
+      value, _, _ = _pad_data_for_flash(value, heads, bkv)

-    q_padded_len = query.shape[2]
-    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
-    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+      bsizes = custom_splash._BlockSizes(block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute)

-    kv_padded_len = key.shape[2]
-    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
-    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+      splash_kernel = custom_splash.make_splash_mha(
+          block_sizes=bsizes,
+          bkv_compute_in=bkv_compute_in,
+          orig_q_seq_len=query_seq_len,
+          orig_kv_seq_len=key_seq_len,
+          heads_per_tile=heads_per_tile,
+      )

-    # Reuse the standard flash-attention masking convention by zeroing invalid
-    # KV positions in the segment ids passed down to splash.
-    if attention_mask is not None:
-      mask_len = min(key_seq_len, attention_mask.shape[1])
-      kv_mask_for_batch = attention_mask[0, :mask_len]
-      if key_seq_len > mask_len:
-        extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
-        kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)
-      if kv_padded_len > key_seq_len:
-        padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
-        kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)
+      vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
+      attention_output = vmapped_splash(query_scaled, key, value)
+      attention_output = jnp.swapaxes(attention_output, 2, 3)
+      attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
+    else:
+      # Run the same local splash kernel as standard TPU flash attention, but now
+      # on full-sequence / fewer-heads tensors produced by the all-to-all above.
+      uses_fused_kernel = block_sizes.use_fused_bwd_kernel
+      block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv)
+      block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv)
+      if uses_fused_kernel:
+        block_q_sizes += (block_sizes.block_q_dkv,)
+        block_kv_sizes += (block_sizes.block_kv_dkv,)
       else:
-        kv_mask_padded = kv_mask_for_batch
-      kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)
+        block_q_sizes += (block_sizes.block_q_dq,)
+        block_kv_sizes += (block_sizes.block_kv_dq,)
+
+      block_q = max(*block_q_sizes)
+      query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
+      block_kv = max(*block_kv_sizes)
+      key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+      value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+
+      mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+      multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+      q_padded_len = query.shape[2]
+      q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+      q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+      kv_padded_len = key.shape[2]
+      kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+      kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+
+      # Reuse the standard flash-attention masking convention by zeroing invalid
+      # KV positions in the segment ids passed down to splash.
+      if attention_mask is not None:
+        mask_len = min(key_seq_len, attention_mask.shape[1])
+        kv_mask_for_batch = attention_mask[0, :mask_len]
+        if key_seq_len > mask_len:
+          extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32)
+          kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0)
+        if kv_padded_len > key_seq_len:
+          padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32)
+          kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0)
+        else:
+          kv_mask_padded = kv_mask_for_batch
+        kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32)

-    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
-    if not mask_padding_tokens:
-      segment_ids = None
+      segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+      if not mask_padding_tokens:
+        segment_ids = None

-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=1,
-        q_seq_shards=1,
-        block_sizes=block_sizes,
-        save_residuals=False,
-        residual_checkpoint_name=residual_checkpoint_name,
-    )
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
-    attention_output = vmapped_splash(query, key, value, segment_ids)
-    attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
+      splash_kernel = splash_attention_kernel.make_splash_mha(
+          mask=multi_head_mask,
+          head_shards=1,
+          q_seq_shards=1,
+          block_sizes=block_sizes,
+          save_residuals=False,
+          residual_checkpoint_name=residual_checkpoint_name,
+      )
+      vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+      attention_output = vmapped_splash(query, key, value, segment_ids)
+      attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

     # Restore the original layout expected by the rest of the model:
     # head-sharded / full-sequence -> sequence-sharded / full-heads.
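
For reference, the data movement both branches depend on is the tiled all_to_all at the top of this hunk: each context shard starts with all heads and a slice of the sequence, and after the exchange holds one head group and the full sequence (the layout the closing comment above describes). A minimal sketch of that exchange in plain jnp, with toy shapes and no mesh or shard_map (everything here is illustrative, not taken from this file):

import jax.numpy as jnp

N, b, h, s, d = 2, 1, 4, 8, 3  # assumed toy sizes: N context shards
full = jnp.arange(b * h * s * d, dtype=jnp.float32).reshape(b, h, s, d)
shards_in = jnp.split(full, N, axis=2)  # per-shard inputs: all heads, a sequence slice
hg = h // N
shards_out = [
    # shard j keeps head group j and gathers the full sequence from every peer
    jnp.concatenate([x[:, j * hg:(j + 1) * hg] for x in shards_in], axis=2)
    for j in range(N)
]
assert all(o.shape == (b, hg, s, d) for o in shards_out)
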
@@ -642,94 +673,6 @@ def wrap_ulysses_attention(query, key, value):
     return x


-def _ulysses_custom_attention(
-    query: jax.Array,
-    key: jax.Array,
-    value: jax.Array,
-    heads: int,
-    mesh: Mesh,
-    axis_names_q: AxisNames,
-    axis_names_kv: AxisNames,
-    flash_block_sizes: BlockSizes,
-    dtype: jnp.dtype = jnp.float32,
-    mask_padding_tokens: bool = False,
-    residual_checkpoint_name: str | None = None,
-    attention_mask: jax.Array = None,
-) -> jax.Array:
-  """Ulysses sequence-parallel attention with custom fast kernel."""
-  axis_name = "context"
-  num_shards = mesh.shape[axis_name]
-
-  # Reshape to [b, h, s, d] and pad sequence for even context-axis splitting.
-  query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_shards)
-  key, _ = _reshape_data_for_flash(key, heads, num_shards)
-  value, _ = _reshape_data_for_flash(value, heads, num_shards)
-  num_heads = query.shape[1]
-  if num_heads % num_shards != 0:
-    raise ValueError(
-        "Ulysses attention requires the number of heads to be divisible by the context shard count, "
-        f"got heads={num_heads} and context_shards={num_shards}."
-    )
-
-  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
-
-  @functools.partial(
-      jax.shard_map,
-      mesh=mesh,
-      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
-      out_specs=q_axis_names,
-      check_vma=False,
-  )
-  def wrap_ulysses_attention(query, key, value):
-    query = jax.lax.all_to_all(query, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
-    key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
-    value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True)
-
-    bq = 2048
-    bkv = 2048
-    bkv_compute = 1024
-    bkv_compute_in = 256
-    heads_per_tile = 1
-
-    query_scaled = query * 1.44269504
-
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
-    value, _, _ = _pad_data_for_flash(value, heads, bkv)
-
-    bsizes = custom_splash._BlockSizes(block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute)
-
-    splash_kernel = custom_splash.make_splash_mha(
-        block_sizes=bsizes,
-        bkv_compute_in=bkv_compute_in,
-        orig_q_seq_len=query_seq_len,
-        orig_kv_seq_len=key_seq_len,
-        heads_per_tile=heads_per_tile,
-    )
-
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
-    attention_output = vmapped_splash(query_scaled, key, value)
-    attention_output = jnp.swapaxes(attention_output, 2, 3)
-
-    attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
-
-    attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True)
-    return attention_output
-
-  devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1)
-  if not (query.shape[0] / devices_in_batch_sharding).is_integer():
-    max_logging.log(
-        "Warning, batch dimension should be shardable among the devices in data and fsdp"
-        f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
-    )
-  x = wrap_ulysses_attention(query, key, value)
-  x = x[:, :, :orig_q_seq_len, :]
-  x = _reshape_heads_to_head_dim(x)
-
-  return x
-
-
 def _apply_attention_dot(
     query: Array,
     key: Array,
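
A note on the 1.44269504 factor that both the removed _ulysses_custom_attention and the new use_custom_kernel branch apply to the query: that constant is log2(e), which suggests the custom splash kernel evaluates its softmax with exp2 rather than exp (a common TPU optimization), so pre-scaling the query by log2(e) leaves the result unchanged. That reading is an inference from the constant, not something stated in the diff. A quick check of the underlying identity:

import jax.numpy as jnp

LOG2_E = 1.44269504  # log2(e)
x = jnp.array([0.3, -1.2, 2.0], dtype=jnp.float32)
# exp2(x * log2(e)) == e**x, so a kernel can take the cheaper exp2 path
assert jnp.allclose(jnp.exp(x), jnp.exp2(x * LOG2_E))
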
@@ -865,7 +808,7 @@ def _apply_attention(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
   elif attention_kernel == "ulysses_custom":
-    return _ulysses_custom_attention(
+    return _ulysses_attention(
         query,
         key * scale,
         value,
@@ -878,6 +821,7 @@ def _apply_attention(
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name=residual_checkpoint_name,
         attention_mask=attention_mask,
+        use_custom_kernel=True,
     )
   elif attention_kernel == "ulysses":
     return _ulysses_attention(
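
One detail visible at this call site: the attention scale is folded into the keys (key * scale) before _ulysses_attention runs, which is equivalent to scaling the q @ k.T logits, since a scalar factor commutes through the matmul. A small sketch of that equivalence with illustrative shapes:

import jax.numpy as jnp

q = jnp.arange(8.0).reshape(2, 4)
k = jnp.arange(12.0).reshape(3, 4)
scale = 0.125
# scaling the keys first gives the same logits as scaling the product
assert jnp.allclose(q @ (k * scale).T, scale * (q @ k.T))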