from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
+from tokamax._src.ops.experimental.tpu.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
from einops import rearrange
from .. import common_types, max_logging

@@ -305,62 +306,92 @@ def wrap_flash_attention(query, key, value):
          mask=mask,
          q_seq_shards=1,  # number of shards over the seq_len axis
          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-         save_residuals=True if attention_kernel == "ring" else False,
+         save_residuals="ring" in attention_kernel,
+     )
+   elif attention_kernel == "tokamax_ring":
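+     # Ring attention over the "fsdp" mesh axis via the tokamax ring kernel
+     # (assumed API): a FullMask spans the per-device query/key lengths, and
+     # the kernel itself rotates K/V shards around the ring.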
+     mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+     splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
+         mask=mask,
+         is_mqa=False,
+         config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+         save_residuals=True,
+         ring_axis="fsdp",
      )
    else:
      splash_kernel = splash_attention_kernel.make_splash_mha(
          mask=multi_head_mask,
          head_shards=1,  # number of shards over the heads axis
          q_seq_shards=1,  # number of shards over the seq_len axis
          block_sizes=block_sizes,
-         save_residuals=True if attention_kernel == "ring" else False,
+         save_residuals="ring" in attention_kernel,
          residual_checkpoint_name=residual_checkpoint_name
      )
-   vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))

-   if not mask_padding_tokens:
-     segment_ids = None
-   if attention_kernel in ["flash", "tokamax_flash"]:
-     attention_output = vmapped_splash(query, key, value, segment_ids)
+   if attention_kernel == "tokamax_ring":
+     # For tokamax_ring, use the kernel directly without vmap.
+     # The ring attention kernel handles the ring topology internally.
+     if not mask_padding_tokens:
+       segment_ids = None
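+     # The keyword arguments below mirror the assumed tokamax ring kernel
+     # signature: no mask_info/mask_function is passed because the FullMask
+     # above attends everywhere (fwd_mask_sparsity=1.0), and masked logits
+     # would take the value -inf.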
+     attention_output = splash_kernel(
+         fwd_mask_info=None,
+         dkv_mask_info=None,
+         q=query,
+         k=key,
+         v=value,
+         segment_ids=segment_ids,
+         is_mqa=False,
+         config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+         mask_value=-jnp.inf,
+         mask_function=None,
+         fwd_mask_sparsity=1.0,
+         save_residuals=True,
+     )
    else:
-     if num_fsdp_shards > 1:
-       out, (lse,) = vmapped_splash(query, key, value, segment_ids)
-       m = lse.astype(jnp.float32)
-       l = jnp.exp(lse - m)
-       o = out.astype(jnp.float32) * l[..., None]
+     vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))

-       perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
+     if not mask_padding_tokens:
+       segment_ids = None
+     if attention_kernel in ["flash", "tokamax_flash"]:
+       attention_output = vmapped_splash(query, key, value, segment_ids)
+     else:
+       if num_fsdp_shards > 1:
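+         # First pass over the local K/V shard. With save_residuals=True the
+         # splash kernel also returns the per-row log-sum-exp (lse); m, l and o
+         # hold the running max of the lse values, the accumulated softmax
+         # weight, and the weighted (unnormalized) output.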
+         out, (lse,) = vmapped_splash(query, key, value, segment_ids)
+         m = lse.astype(jnp.float32)
+         l = jnp.exp(lse - m)
+         o = out.astype(jnp.float32) * l[..., None]

-       k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
-       v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
+         perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]

-       def ring_scan_body(carry, _):
-         m, l, o, k_current, v_current = carry
-         k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
-         v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
+         k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
+         v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)

-         out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
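+         # Each scan step attends the local queries to the K/V chunk currently
+         # held on this device, then rotates the chunk one hop around the
+         # "fsdp" ring; partial results are merged with an online softmax.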
+         def ring_scan_body(carry, _):
+           m, l, o, k_current, v_current = carry
+           k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
+           v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)

-         m_chunk = lse_chunk.astype(jnp.float32)
-         m_old = m
-         m = jnp.maximum(m_old, m_chunk)
+           out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)

-         exp_m_diff = jnp.exp(m_old - m)
-         exp_m_chunk_diff = jnp.exp(m_chunk - m)
+           m_chunk = lse_chunk.astype(jnp.float32)
+           m_old = m
+           m = jnp.maximum(m_old, m_chunk)

-         l = l * exp_m_diff + jnp.exp(lse_chunk - m)
-         o = o * exp_m_diff[..., None]
-         o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
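+           # Rescale both running sums to the new max before adding the chunk:
+           #   l <- l * exp(m_old - m) + exp(lse_chunk - m)
+           #   o <- o * exp(m_old - m) + exp(m_chunk - m) * out_chunk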
+           exp_m_diff = jnp.exp(m_old - m)
+           exp_m_chunk_diff = jnp.exp(m_chunk - m)

-         # Return the updated state for the next iteration
-         return (m, l, o, k_next, v_next), None
+           l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+           o = o * exp_m_diff[..., None]
+           o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)

-       initial_carry = (m, l, o, k1, v1)
-       (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
+           # Return the updated state for the next iteration
+           return (m, l, o, k_next, v_next), None

-       attention_output = o_final / l_final[..., None]
-     else:
-       raise ValueError("ring attention requires fsdp > 1")
+         initial_carry = (m, l, o, k1, v1)
+         (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
+
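+         # Dividing the accumulated weighted output by the accumulated softmax
+         # weight recovers the exact (globally normalized) attention output.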
+         attention_output = o_final / l_final[..., None]
+       else:
+         raise ValueError("ring attention requires fsdp > 1")

    return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

@@ -536,7 +567,7 @@ def _apply_attention(
        mask_padding_tokens=mask_padding_tokens,
        residual_checkpoint_name=residual_checkpoint_name,
    )
-  elif attention_kernel == "ring":
+  elif "ring" in attention_kernel:
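+    # Both "ring" and "tokamax_ring" are routed here; wrap_flash_attention
+    # selects the matching kernel from the exact attention_kernel value.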
    return _tpu_flash_attention(
        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
        mask_padding_tokens=mask_padding_tokens,
@@ -547,6 +578,7 @@ def _apply_attention(
    raise ValueError(f"Unexpected attention kernel {attention_kernel=}.")


+
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
  """Multi-head dot product attention with a limited number of queries."""
  num_kv, num_heads, k_features = key.shape[-3:]