 import flax.linen as nn
 from flax import nnx
 import jax
-from jax.sharding import PartitionSpec
+from jax.experimental.pjit import pjit
+from jax.sharding import NamedSharding, PartitionSpec
 import jax.numpy as jnp
 from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
@@ -153,6 +154,216 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1
   return tensor, kv_size, seq_len


+def _tpu_ring_flash_attention_v1(
+    query: jax.Array,
+    key: jax.Array,
+    value: jax.Array,
+    heads: int,
+    mesh: Mesh,
+    axis_names_q: AxisNames,
+    axis_names_kv: AxisNames,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype = jnp.float32,
+) -> jax.Array:
+  """TPU Ring Flash Attention with correct padding, transposition, and sharding."""
+  from ringattention import ringattention
+  from einops import rearrange
+
+  # --- Step 1: Initialize Block Sizes ---
+  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
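+  # Larger chunks are presumably affordable in bfloat16 because activations take
+  # half the on-chip memory of float32; the default chunk size is halved otherwise.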
+  blockwise_kwargs = {
+      # CRITICAL: Ensures non-causal attention to match FullMask
+      "causal_block_size": None,
+      "deterministic": True,
+      "dropout_rng": None,
+      "attn_pdrop": 0.0,
+      "policy": jax.checkpoint_policies.nothing_saveable,
+      "dtype": dtype,
+      "precision": None,
+      "prevent_cse": True,
+  }
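+  # Note: nothing_saveable above keeps no attention intermediates for the backward
+  # pass, so they are fully rematerialized (lower memory at the cost of extra compute).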
+  if flash_block_sizes:
+    blockwise_kwargs["query_chunk_size"] = flash_block_sizes.block_q
+    blockwise_kwargs["key_chunk_size"] = flash_block_sizes.block_kv
+  else:
+    # Get seq_len from shape[2] of the original (b, h, s, d) tensor
+    blockwise_kwargs["query_chunk_size"] = min(max_block_size, query.shape[2])
+    blockwise_kwargs["key_chunk_size"] = min(max_block_size, key.shape[2])
+
+  # --- Step 2: Pad and Preprocess Tensors (CRITICAL) ---
+  num_fsdp_shards = mesh.shape["fsdp"]
+  query_padded, kv_size, query_seq_len = _reshape_data_for_flash(
+      query, heads, blockwise_kwargs["query_chunk_size"], num_fsdp_shards
+  )
+  key_padded, _, _ = _reshape_data_for_flash(
+      key, heads, blockwise_kwargs["key_chunk_size"], num_fsdp_shards
+  )
+  value_padded, _, _ = _reshape_data_for_flash(
+      value, heads, blockwise_kwargs["key_chunk_size"], num_fsdp_shards
+  )
+
+  # Transpose the *padded* inputs to the (b, s, h, d) format the library expects
+  query_t = rearrange(query_padded, 'b h s d -> b s h d')
+  key_t = rearrange(key_padded, 'b h s d -> b s h d')
+  value_t = rearrange(value_padded, 'b h s d -> b s h d')
+
+  # --- Step 3: Define Sharding and the Sharded Function ---
+  # The PartitionSpec must match the transposed (b, s, h, d) layout.
+  q_axis_names_physical = nn.logical_to_mesh_axes(axis_names_q)
+  transposed_spec = PartitionSpec(
+      q_axis_names_physical[0],  # BATCH -> 'data'
+      q_axis_names_physical[2],  # LENGTH -> 'fsdp'
+      q_axis_names_physical[1],  # HEAD -> 'tensor'
+      q_axis_names_physical[3],  # D_KV -> None
+  )
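+  # For example, with the logical-to-physical mapping annotated above
+  # (BATCH -> 'data', HEAD -> 'tensor', LENGTH -> 'fsdp', D_KV -> None), this
+  # evaluates to PartitionSpec('data', 'fsdp', 'tensor', None) for the (b, s, h, d) layout.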
+  ring_axis_name = q_axis_names_physical[2]
+
+  ring_attention_sharded = shard_map.shard_map(
+      functools.partial(
+          ringattention,
+          attn_bias=None,
+          segment_ids=None,
+          axis_name=ring_axis_name,
+          float32_logits=True,
+          cache_idx=None,
+          blockwise_kwargs=blockwise_kwargs,
+      ),
+      mesh=mesh,
+      in_specs=(transposed_spec, transposed_spec, transposed_spec),
+      # The library's output is (b, s, h, d), so its sharding matches the transposed input
+      out_specs=transposed_spec,
+      check_rep=False,
+  )
+
+  # --- Step 4: Execute and Post-process ---
+  attention_output = ring_attention_sharded(query_t, key_t, value_t)
+
+  # Transpose the output from (b, s, h, d) back to the original (b, h, s, d) convention
+  attention_output_t = rearrange(attention_output, 'b s h d -> b h s d')
+
+  # Unpad the output to match the original sequence and head dimension
+  attention_output_cropped = attention_output_t[:, :, :query_seq_len, :kv_size]
+
+  # Reshape to the final (b, s, h*d) format, just like in _tpu_flash_attention
+  final_output = _reshape_heads_to_head_dim(attention_output_cropped)
+
+  return final_output
+
+def _tpu_ring_flash_attention(
+    query: jax.Array,
+    key: jax.Array,
+    value: jax.Array,
+    heads: int,
+    mesh: Mesh,
+    axis_names_q: AxisNames,
+    axis_names_kv: AxisNames,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype = jnp.float32,
+    usp_degree: Optional[int] = 1,
+) -> jax.Array:
+  """TPU Ring Flash Attention with USP (all-to-all) support, correct padding, transposition, and sharding."""
+  from ringattention import ringattention
+  from einops import rearrange
+
+  # --- Step 1: Initialize Block Sizes ---
+  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  blockwise_kwargs = {
+      # CRITICAL: Ensures non-causal attention to match FullMask
+      "causal_block_size": None,
+      "deterministic": True,
+      "dropout_rng": None,
+      "attn_pdrop": 0.0,
+      "policy": jax.checkpoint_policies.nothing_saveable,
+      "dtype": dtype,
+      "precision": None,
+      "prevent_cse": True,
+  }
+  if flash_block_sizes:
+    blockwise_kwargs["query_chunk_size"] = flash_block_sizes.block_q
+    blockwise_kwargs["key_chunk_size"] = flash_block_sizes.block_kv
+  else:
+    # Get seq_len from shape[2] of the original (b, h, s, d) tensor
+    blockwise_kwargs["query_chunk_size"] = min(max_block_size, query.shape[2])
+    blockwise_kwargs["key_chunk_size"] = min(max_block_size, key.shape[2])
+
+  # --- Step 2: Pad and Preprocess Tensors (CRITICAL) ---
+  num_fsdp_shards = mesh.shape["fsdp"]
+  query_padded, kv_size, query_seq_len = _reshape_data_for_flash(
+      query, heads, blockwise_kwargs["query_chunk_size"], num_fsdp_shards
+  )
+  key_padded, _, _ = _reshape_data_for_flash(
+      key, heads, blockwise_kwargs["key_chunk_size"], num_fsdp_shards
+  )
+  value_padded, _, _ = _reshape_data_for_flash(
+      value, heads, blockwise_kwargs["key_chunk_size"], num_fsdp_shards
+  )
+
+  # --- Step 3: Define the Logical Mesh, Sharding, and the Sharded Function ---
+  # 3.1: Split the physical 'fsdp' axis into a 'usp' axis and a 'ring' axis
+  num_fsdp_devices = mesh.shape["fsdp"]
+  if num_fsdp_devices % usp_degree != 0:
+    raise ValueError("fsdp axis size must be divisible by usp_degree")
+  ring_degree = num_fsdp_devices // usp_degree
+  logical_mesh_shape = (mesh.shape['data'], usp_degree, ring_degree, mesh.shape["tensor"])
+  reshaped_devices = mesh.devices.reshape(logical_mesh_shape)
+  logical_mesh = Mesh(reshaped_devices, ('data', 'usp', 'ring', 'tensor'))
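+  # e.g. with 8 devices on 'fsdp' and usp_degree=2 this yields usp=2 and ring=4,
+  # so the original 'fsdp' dimension is covered by the ('usp', 'ring') pair of logical axes.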
+
+  # 3.2: Define the kernel containing all sharded logic
+  def _kernel(q_padded, k_padded, v_padded):
+    # Step A: All-to-all swap (a no-op if usp_degree=1)
+    q_swapped = jax.lax.all_to_all(q_padded, 'usp', split_axis=1, concat_axis=2, tiled=True)
+    k_swapped = jax.lax.all_to_all(k_padded, 'usp', split_axis=1, concat_axis=2, tiled=True)
+    v_swapped = jax.lax.all_to_all(v_padded, 'usp', split_axis=1, concat_axis=2, tiled=True)
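+    # With tiled=True this splits the local head dimension (axis 1) across the 'usp'
+    # group and concatenates the received sequence chunks along axis 2, so each device
+    # ends up with its usp group's full sequence span for a subset of the heads.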
+
+    # Step B: Transpose the *swapped* tensors for the library
+    q_t = rearrange(q_swapped, 'b h s d -> b s h d')
+    k_t = rearrange(k_swapped, 'b h s d -> b s h d')
+    v_t = rearrange(v_swapped, 'b h s d -> b s h d')
+
+    # Step C: Ring attention (always communicates along the 'ring' axis)
+    attn_out_t = ringattention(
+        q_t, k_t, v_t, axis_name='ring',
+        blockwise_kwargs=blockwise_kwargs,
+        attn_bias=None,
+        segment_ids=None,
+        cache_idx=None,
+        float32_logits=True,
+    )
+    return attn_out_t
+
+  # 3.3: Define sharding for the (b, h, s, d) inputs to the kernel
+  initial_spec = PartitionSpec('data', 'tensor', ('usp', 'ring'), None)
+
+  # 3.4: Define sharding for the output of the kernel, which is transposed to (b, s, h, d)
+  output_spec = PartitionSpec('data', 'ring', ('usp', 'tensor'), None)
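+  # The all_to_all inside the kernel is not undone, so on output the heads remain
+  # sharded over both 'usp' and 'tensor' while the sequence is split only over 'ring';
+  # the sharding constraint below moves the result back to the physical layout.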
+
+  # 3.5: Create and call the sharded function
+  sharded_attention_fn = shard_map.shard_map(
+      _kernel,
+      mesh=logical_mesh,
+      in_specs=(initial_spec, initial_spec, initial_spec),
+      out_specs=output_spec,
+      check_rep=False,
+  )
+
+  attention_output_t = sharded_attention_fn(query_padded, key_padded, value_padded)
+
+  # --- Step 4: Reshard back to the Physical Layout and Post-process ---
+
+  # 4.1: Transpose back to (b, h, s, d)
+  attention_output = rearrange(attention_output_t, 'b s h d -> b h s d')
+
+  # 4.2: Define the target sharding using the ORIGINAL physical mesh
+  physical_spec = PartitionSpec('data', 'tensor', 'fsdp', None)
+  attention_output_resharded = jax.lax.with_sharding_constraint(
+      attention_output,
+      NamedSharding(mesh, physical_spec),
+  )
+
+  # 4.3: Unpad and reshape
+  attention_output_cropped = attention_output_resharded[:, :, :query_seq_len, :kv_size]
+  final_output = _reshape_heads_to_head_dim(attention_output_cropped)
+
+  return final_output
+
 def _tpu_flash_attention(
     query: jax.Array,
     key: jax.Array,
@@ -372,7 +583,7 @@ def _apply_attention(
   seq_len_idx = 1
   if query.ndim == 4:
     seq_len_idx = 2
-  if attention_kernel == "flash":
+  if attention_kernel in ["flash", "ring_flash"]:
     can_use_flash_attention = (
         query.shape[seq_len_idx] >= flash_min_seq_length
         and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -388,6 +599,11 @@ def _apply_attention(
     return _tpu_flash_attention(
         query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
     )
+  elif attention_kernel == "ring_flash":
+    max_logging.log("USING RING ATTENTION")
+    return _tpu_ring_flash_attention_v1(
+        query, key, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+    )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else: