 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
 from tokamax._src.ops.experimental.tpu.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
+from tokamax._src.ops.experimental.tpu.splash_attention import base as tokamax_base
 from einops import rearrange
 from .. import common_types, max_logging
 
@@ -279,17 +280,49 @@ def _tpu_flash_attention(
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
+
+  # Pre-padding and Ring Kernel creation outside shard_map
+  if attention_kernel == "tokamax_ring":
+    block_q = max(block_sizes.block_q, block_sizes.block_q_dkv)
+    block_kv = max(block_sizes.block_kv, block_sizes.block_kv_dkv)
+
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q, num_shards=num_fsdp_shards)
+    key, _, _ = _pad_data_for_flash(key, heads, block_kv, num_shards=num_fsdp_shards)
+    value, _, _ = _pad_data_for_flash(value, heads, block_kv, num_shards=num_fsdp_shards)
+
+    mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    ring_kernel = tokamax_ring_attention_kernel.make_ring_attention(
+        mask=mask,
+        is_mqa=False,
+        config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+        save_residuals=True,
+        ring_axis="fsdp",
+        q_seq_shards=num_fsdp_shards,
+        kv_seq_shards=num_fsdp_shards,
+    )
+    kernel_spec = ring_kernel.manual_sharding_spec()
+  else:
+    # Logic for other kernels remains unchanged regarding local padding
+    ring_kernel = None
+    kernel_spec = None
+
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
   @functools.partial(
       shard_map.shard_map,
       mesh=mesh,
-      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names, kernel_spec),
       out_specs=q_axis_names,
       check_rep=False,
   )
-  def wrap_flash_attention(query, key, value):
+  def wrap_flash_attention(query, key, value, ring_kernel):
+
+    if attention_kernel == "tokamax_ring":
+      # For bidirectional attention, segment_ids can be None to hit the performance shortcut
+      segment_ids = None
+      vmapped_splash = jax.vmap(ring_kernel, in_axes=(0, 0, 0, None))
+      return vmapped_splash(query, key, value, segment_ids)
 
     uses_fused_kernel = block_sizes.use_fused_bwd_kernel
     block_q_sizes = (
@@ -324,6 +357,7 @@ def wrap_flash_attention(query, key, value):
     kv_padded_len = key.shape[2]
     kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
     kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+
     segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
 
     # make_splash_mha is wrapped around shardmap and seq and head is already
@@ -336,15 +370,6 @@ def wrap_flash_attention(query, key, value):
           config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
           save_residuals=True if "ring" in attention_kernel else False,
       )
-    elif attention_kernel == "tokamax_ring":
-      mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
-      splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
-          mask=mask,
-          is_mqa=False,
-          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-          save_residuals=True,
-          ring_axis="fsdp",
-      )
     else:
       splash_kernel = splash_attention_kernel.make_splash_mha(
           mask=multi_head_mask,
@@ -360,7 +385,7 @@ def wrap_flash_attention(query, key, value):
 
     if not mask_padding_tokens:
       segment_ids = None
-    if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]:
+    if attention_kernel in ["flash", "tokamax_flash"]:
       attention_output = vmapped_splash(query, key, value, segment_ids)
     else:
       if num_fsdp_shards > 1:
@@ -412,7 +437,11 @@ def ring_scan_body(carry, _):
412437 "Warning, batch dimension should be shardable among the devices in data and fsdp"
413438 f" axis, batch dimension: { query .shape [0 ]} , devices_in_data_fsdp: { devices_in_data_fsdp } "
414439 )
415- x = wrap_flash_attention (query , key , value )
440+ x = wrap_flash_attention (query , key , value , ring_kernel )
441+
442+ if attention_kernel == "tokamax_ring" :
443+ x = x [:, :, :query_seq_len , :kv_size ].astype (query .dtype )
444+
416445 x = _reshape_heads_to_head_dim (x )
417446
418447 return x
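
Note for reviewers: the tokamax_ring path pads query/key/value up to block-size multiples before entering shard_map and later crops the attention output back with x[:, :, :query_seq_len, :kv_size]. Below is a minimal sketch of that pad-then-crop pattern; the helper name pad_to_block_multiple and the exact padding rule are illustrative assumptions, not the actual _pad_data_for_flash implementation used in this change.

import jax.numpy as jnp

def pad_to_block_multiple(x, axis, block, num_shards=1):
  # Hypothetical helper for illustration only: zero-pad `axis` of x up to a
  # multiple of block * num_shards, and return the original size so the caller
  # can crop the kernel output afterwards.
  size = x.shape[axis]
  multiple = block * num_shards
  padded_size = -(-size // multiple) * multiple  # ceil to the next multiple
  pad_widths = [(0, 0)] * x.ndim
  pad_widths[axis] = (0, padded_size - size)
  return jnp.pad(x, pad_widths), size

# Example: a (batch, heads, seq, head_dim) query with seq=1000, block_q=512, 4 fsdp shards.
q = jnp.zeros((2, 8, 1000, 128))
q_padded, q_seq_len = pad_to_block_multiple(q, axis=2, block=512, num_shards=4)
# q_padded.shape[2] == 2048; after the kernel runs, crop with out[:, :, :q_seq_len, :].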