 from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
+from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
+from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
 from einops import rearrange
 from .. import common_types, max_logging
 
@@ -169,6 +171,40 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
 
   return tensor, kv_size, seq_len
 
+def convert_to_tokamax_splash_config(
+    block_sizes: BlockSizes,
+    q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
+    k_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
+    v_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
+    residual_checkpoint_name: str | None = None,
+    attn_logits_soft_cap: float | None = None,
+    fuse_reciprocal: bool = True,
+    use_base2_exp: bool = False,
+    max_logit_const: float | None = None,
+    interpret: bool = False,
+    dq_reduction_steps: int | None = None,
+) -> tokamax_splash_attention_kernel.SplashConfig:
+  """Maps a splash-attention BlockSizes onto a tokamax SplashConfig."""
+  assert block_sizes.use_fused_bwd_kernel, "Tokamax splash attention only supports the fused bwd kernel."
+  return tokamax_splash_attention_kernel.SplashConfig(
+      block_q=block_sizes.block_q,
+      block_kv=block_sizes.block_kv,
+      block_kv_compute=block_sizes.block_kv_compute,
+      block_q_dkv=block_sizes.block_q_dkv,
+      block_kv_dkv=block_sizes.block_kv_dkv,
+      block_kv_dkv_compute=block_sizes.block_kv_dkv_compute,
+      # The assert above guarantees the fused bwd kernel, so the standalone
+      # dq blocks are always None.
+      block_q_dq=None if block_sizes.use_fused_bwd_kernel else block_sizes.block_q_dq,
+      block_kv_dq=None if block_sizes.use_fused_bwd_kernel else block_sizes.block_kv_dq,
+      use_fused_bwd_kernel=block_sizes.use_fused_bwd_kernel,
+      q_layout=q_layout,
+      k_layout=k_layout,
+      v_layout=v_layout,
+      residual_checkpoint_name=residual_checkpoint_name,
+      attn_logits_soft_cap=attn_logits_soft_cap,
+      fuse_reciprocal=fuse_reciprocal,
+      use_base2_exp=use_base2_exp,
+      max_logit_const=max_logit_const,
+      interpret=interpret,
+      dq_reduction_steps=dq_reduction_steps,
+  )
+
 
 def _tpu_flash_attention(
     query: jax.Array,
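# --- Editor's example (not part of the diff): a minimal sketch of how the
# converter above might be called. Block values are illustrative; BlockSizes
# is the jax.experimental.pallas.ops.tpu.splash_attention dataclass.
block_sizes = splash_attention_kernel.BlockSizes(
    block_q=512,
    block_kv=512,
    block_kv_compute=512,
    block_q_dkv=512,
    block_kv_dkv=512,
    block_kv_dkv_compute=512,
    use_fused_bwd_kernel=True,  # the converter asserts on this flag
)
config = convert_to_tokamax_splash_config(block_sizes)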
@@ -203,8 +239,9 @@ def _tpu_flash_attention(
       block_q_dkv=min(q_max_block_size, query.shape[2]),
       block_kv_dkv=min(kv_max_block_size, key.shape[2]),
       block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-      block_q_dq=min(q_max_block_size, query.shape[2]),
-      block_kv_dq=min(kv_max_block_size, query.shape[2]),
+      block_q_dq=None if attention_kernel == "tokamax_flash" else min(q_max_block_size, query.shape[2]),
+      block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
+      use_fused_bwd_kernel=attention_kernel == "tokamax_flash",
   )
   num_fsdp_shards = mesh.shape["fsdp"]
   query = _reshape_data_for_flash(query, heads)
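# --- Editor's example (not part of the diff): the clamping pattern used for
# every block size above, with illustrative values. Blocks are capped at the
# actual sequence length so short sequences never get oversized blocks.
q_max_block_size, q_seq_len = 512, 256
block_q = min(q_max_block_size, q_seq_len)  # -> 256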
@@ -240,18 +277,27 @@ def wrap_flash_attention(query, key, value):
 
     # make_splash_mha is wrapped in shard_map, and seq and heads are already
     # sharded via in_specs, so set head_shards=1 and q_seq_shards=1.
-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=1,  # the size of the axis sharded over heads
-        q_seq_shards=1,  # the size of the axis sharded over seq_len
-        block_sizes=block_sizes,
-        save_residuals=True if attention_kernel == "ring" else False,
-    )
+    if attention_kernel == "tokamax_flash":
+      mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+      splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
+          mask=mask,
+          q_seq_shards=1,  # the size of the axis sharded over seq_len
+          config=convert_to_tokamax_splash_config(block_sizes),
+          save_residuals=False,  # this branch is never the "ring" kernel
+      )
+    else:
+      splash_kernel = splash_attention_kernel.make_splash_mha(
+          mask=multi_head_mask,
+          head_shards=1,  # the size of the axis sharded over heads
+          q_seq_shards=1,  # the size of the axis sharded over seq_len
+          block_sizes=block_sizes,
+          save_residuals=attention_kernel == "ring",
+      )
     vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
 
     if not mask_padding_tokens:
       segment_ids = None
-    if attention_kernel == "flash":
+    if attention_kernel in ["flash", "tokamax_flash"]:
       attention_output = vmapped_splash(query, key, value, segment_ids)
     else:
       if num_fsdp_shards > 1:
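# --- Editor's example (not part of the diff): the tokamax path above, pulled
# out of shard_map for clarity. Assumes q/k/v have already been reshaped to
# [batch, heads, seq_len, head_dim], so query.shape[2] is the sequence length,
# and that make_splash_mha behaves as shown in the diff.
mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
    mask=mask,
    q_seq_shards=1,
    config=convert_to_tokamax_splash_config(block_sizes),
)
vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))  # map over batch
out = vmapped_splash(query, key, value, None)  # segment_ids=None: no padding mask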
@@ -439,7 +485,7 @@ def _apply_attention(
   seq_len_idx = 1
   if query.ndim == 4:
     seq_len_idx = 2
-  if attention_kernel == "flash":
+  if attention_kernel in ["flash", "tokamax_flash"]:
     can_use_flash_attention = (
         query.shape[seq_len_idx] >= flash_min_seq_length
         and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -451,7 +497,7 @@ def _apply_attention(
     return _apply_attention_dot(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
-  elif attention_kernel == "flash":
+  elif attention_kernel in ["flash", "tokamax_flash"]:
     return _tpu_flash_attention(
         query,
         key * scale,