 from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
+from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
+from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
 from einops import rearrange
 from .. import common_types, max_logging

 EMBED = common_types.EMBED
 Quant = quantizations.AqtQuantization

+SELF_ATTN_HEAD = common_types.SELF_ATTN_HEAD
+SELF_ATTN_Q_LENGTH = common_types.SELF_ATTN_Q_LENGTH
+SELF_ATTN_KV_LENGTH = common_types.SELF_ATTN_KV_LENGTH
+CROSS_ATTN_HEAD = common_types.CROSS_ATTN_HEAD
+CROSS_ATTN_Q_LENGTH = common_types.CROSS_ATTN_Q_LENGTH
+CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH
+

 def _maybe_aqt_einsum(quant: Quant):
   return jnp.einsum if quant is None else quant.einsum()
@@ -163,6 +172,40 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):

   return tensor, kv_size, seq_len

+def convert_to_tokamax_splash_config(
+    block_sizes: BlockSizes,
+    q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
+    k_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
+    v_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
+    residual_checkpoint_name: str | None = None,
+    attn_logits_soft_cap: float | None = None,
+    fuse_reciprocal: bool = True,
+    use_base2_exp: bool = False,
+    max_logit_const: float | None = None,
+    interpret: bool = False,
+    dq_reduction_steps: int | None = None,
+) -> tokamax_splash_attention_kernel.SplashConfig:
+  assert block_sizes.use_fused_bwd_kernel, "Tokamax Splash attention only supports fused bwd kernel."
+  return tokamax_splash_attention_kernel.SplashConfig(
+      block_q=block_sizes.block_q,
+      block_kv=block_sizes.block_kv,
+      block_kv_compute=block_sizes.block_kv_compute,
+      block_q_dkv=block_sizes.block_q_dkv,
+      block_kv_dkv=block_sizes.block_kv_dkv,
+      block_kv_dkv_compute=block_sizes.block_kv_dkv_compute,
+      block_q_dq=None if block_sizes.use_fused_bwd_kernel else block_sizes.block_q_dq,
+      block_kv_dq=None if block_sizes.use_fused_bwd_kernel else block_sizes.block_kv_dq,
+      use_fused_bwd_kernel=block_sizes.use_fused_bwd_kernel,
+      q_layout=q_layout,
+      k_layout=k_layout,
+      v_layout=v_layout,
+      residual_checkpoint_name=residual_checkpoint_name,
+      attn_logits_soft_cap=attn_logits_soft_cap,
+      fuse_reciprocal=fuse_reciprocal,
+      use_base2_exp=use_base2_exp,
+      max_logit_const=max_logit_const,
+      interpret=interpret,
+      dq_reduction_steps=dq_reduction_steps,
+  )
+

 def _tpu_flash_attention(
     query: jax.Array,
@@ -175,6 +218,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
+    mask_padding_tokens: bool = True,
     residual_checkpoint_name: str | None = None,
 ) -> jax.Array:
   """TPU Flash Attention"""
@@ -186,17 +230,19 @@ def _tpu_flash_attention(
     kv_max_block_size = key.shape[1]
   else:
     kv_max_block_size = q_max_block_size
-  if flash_block_sizes:
+  # ensure that for cross attention we override the block sizes.
+  if flash_block_sizes and key.shape[1] == query.shape[1]:
     block_sizes = flash_block_sizes
   else:
+    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=min(q_max_block_size, query.shape[2]),
+        block_q=block_size_q,
         block_kv_compute=min(kv_max_block_size, key.shape[2]),
         block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=min(q_max_block_size, query.shape[2]),
+        block_q_dkv=block_size_q,
         block_kv_dkv=min(kv_max_block_size, key.shape[2]),
         block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq,
+        block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
         block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
         use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
     )
@@ -253,17 +299,28 @@ def wrap_flash_attention(query, key, value):

     # make_splash_mha is wrapped around shardmap and seq and head is already
     # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=1,  # the sizes of the axis is sharding over heads
-        q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
-        block_sizes=block_sizes,
-        save_residuals=True if attention_kernel == "ring" else False,
-        residual_checkpoint_name=residual_checkpoint_name,
-    )
+    if attention_kernel == "tokamax_flash":
+      mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+      splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
+          mask=mask,
+          q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
+          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+          save_residuals=True if attention_kernel == "ring" else False,
+      )
+    else:
+      splash_kernel = splash_attention_kernel.make_splash_mha(
+          mask=multi_head_mask,
+          head_shards=1,  # the sizes of the axis is sharding over heads
+          q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
+          block_sizes=block_sizes,
+          save_residuals=True if attention_kernel == "ring" else False,
+          residual_checkpoint_name=residual_checkpoint_name
+      )
     vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))

-    if attention_kernel == "flash":
+    if not mask_padding_tokens:
+      segment_ids = None
+    if attention_kernel in ["flash", "tokamax_flash"]:
       attention_output = vmapped_splash(query, key, value, segment_ids)
     else:
       if num_fsdp_shards > 1:
@@ -302,6 +359,8 @@ def ring_scan_body(carry, _):
         (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)

         attention_output = o_final / l_final[..., None]
+      else:
+        raise ValueError("ring attention requires fsdp > 1")

     return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

@@ -442,14 +501,15 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
+    mask_padding_tokens: bool = True,
     residual_checkpoint_name: str | None = None,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
   seq_len_idx = 1
   if query.ndim == 4:
     seq_len_idx = 2
-  if attention_kernel == "flash":
+  if attention_kernel in ["flash", "tokamax_flash"]:
     can_use_flash_attention = (
         query.shape[seq_len_idx] >= flash_min_seq_length
         and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -461,7 +521,7 @@ def _apply_attention(
     return _apply_attention_dot(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
-  elif attention_kernel == "flash":
+  elif attention_kernel in ["flash", "tokamax_flash"]:
     return _tpu_flash_attention(
         query,
         key * scale,
@@ -472,11 +532,14 @@ def _apply_attention(
         axis_names_kv,
         flash_block_sizes,
         dtype,
+        attention_kernel,
+        mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name=residual_checkpoint_name,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
+        mask_padding_tokens=mask_padding_tokens,
     )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
@@ -607,6 +670,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
+      mask_padding_tokens: bool = True,
       residual_checkpoint_name: str | None = None,
   ):
     self.dpa_layer = None
@@ -627,6 +691,7 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
+    self.mask_padding_tokens = mask_padding_tokens
     self.residual_checkpoint_name = residual_checkpoint_name

   def apply_attention(self, query: Array, key: Array, value: Array):
@@ -648,6 +713,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
+        mask_padding_tokens=self.mask_padding_tokens,
         residual_checkpoint_name=self.residual_checkpoint_name,
     )

@@ -737,6 +803,8 @@ def __init__(
       precision: jax.lax.Precision = None,
       qkv_bias: bool = False,
       quant: Quant = None,
+      is_self_attention: bool = True,
+      mask_padding_tokens: bool = True,
       residual_checkpoint_name: str | None = None,
       enable_jax_named_scopes: bool = False,
   ):
@@ -750,11 +818,19 @@ def __init__(
     self.inner_dim = dim_head * heads
     scale = dim_head**-0.5
     self.qk_norm = qk_norm
-    self.enable_jax_named_scopes = enable_jax_named_scopes
+
     self.query_axis_names = query_axis_names
     self.key_axis_names = key_axis_names
     self.value_axis_names = value_axis_names
     self.out_axis_names = out_axis_names
+    self.enable_jax_named_scopes = enable_jax_named_scopes
+
+    if is_self_attention:
+      axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV)
+      axis_names_kv = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH, D_KV)
+    else:
+      axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV)
+      axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV)

     self.attention_op = NNXAttentionOp(
         mesh=mesh,
@@ -765,10 +841,13 @@ def __init__(
         use_memory_efficient_attention=use_memory_efficient_attention,
         split_head_dim=split_head_dim,
         float32_qk_product=False,
+        axis_names_q=axis_names_q,
+        axis_names_kv=axis_names_kv,
         flash_min_seq_length=flash_min_seq_length,
         flash_block_sizes=flash_block_sizes,
         dtype=dtype,
         quant=quant,
+        mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name=residual_checkpoint_name,
     )
     # None axes corresponds to the stacked weights across all blocks
@@ -1579,4 +1658,4 @@ def setup(self):
   def __call__(self, hidden_states, deterministic=True):
     hidden_states = self.proj(hidden_states)
     hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
+    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
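
# --- Usage sketch (not part of the commit) ---
# A minimal illustration of how the new convert_to_tokamax_splash_config helper
# maps a splash-attention BlockSizes onto a tokamax SplashConfig. The block
# values below are placeholders, and the helper is assumed to be in scope from
# this module; only the BlockSizes fields shown in the diff are used.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

example_block_sizes = splash_attention_kernel.BlockSizes(
    block_q=512,
    block_kv=512,
    block_kv_compute=512,
    block_q_dkv=512,
    block_kv_dkv=512,
    block_kv_dkv_compute=512,
    block_q_dq=None,  # left as None because the fused backward kernel is used
    block_kv_dq=None,
    use_fused_bwd_kernel=True,  # required: the helper asserts this flag is set
)

# q/k/v layouts default to HEAD_DIM_MINOR, matching the helper's signature.
example_config = convert_to_tokamax_splash_config(example_block_sizes)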