 import flax.linen as nn
 from flax import nnx
 import jax
-from jax.sharding import PartitionSpec
+from jax.experimental.pjit import pjit
+from jax.sharding import NamedSharding, PartitionSpec
 import jax.numpy as jnp
 from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
@@ -153,6 +154,198 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1
   return tensor, kv_size, seq_len


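+# Ring attention shards the sequence across devices and rotates key/value
+# blocks around the ring, overlapping the collective transfers with blockwise
+# flash attention (Liu et al., 2023).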
+def _tpu_ring_flash_attention_v1(
+    query: jax.Array,
+    key: jax.Array,
+    value: jax.Array,
+    heads: int,
+    mesh: Mesh,
+    axis_names_q: AxisNames,
+    axis_names_kv: AxisNames,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype = jnp.float32,
+) -> jax.Array:
+  """TPU Ring Flash Attention with correct padding, transposition, and sharding."""
+  from ringattention import ringattention
+  from einops import rearrange
+
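+  # Chunk sizes default to the sequence length, capped at 1024 for bf16 and
+  # 512 for fp32: bf16 halves the bytes per element, so larger blocks fit on-chip.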
+  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  blockwise_kwargs = {
+      # Full (non-causal) mask
+      "causal_block_size": None,
+      "deterministic": True,
+      "dropout_rng": None,
+      "attn_pdrop": 0.0,
+      "policy": jax.checkpoint_policies.nothing_saveable,
+      "dtype": dtype,
+      "precision": None,
+      "prevent_cse": True,
+  }
+  if flash_block_sizes:
+    blockwise_kwargs["query_chunk_size"] = flash_block_sizes.block_q
+    blockwise_kwargs["key_chunk_size"] = flash_block_sizes.block_kv
+  else:
+    blockwise_kwargs["query_chunk_size"] = min(max_block_size, query.shape[2])
+    blockwise_kwargs["key_chunk_size"] = min(max_block_size, key.shape[2])
+
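+  # Pad q/k/v so each fsdp shard holds a whole number of blocks;
+  # _reshape_data_for_flash also returns the original head dim (kv_size) and
+  # sequence length so the output can be unpadded below.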
+  num_fsdp_shards = mesh.shape["fsdp"]
+  query_padded, kv_size, query_seq_len = _reshape_data_for_flash(
+      query, heads, blockwise_kwargs["query_chunk_size"], num_fsdp_shards
+  )
+  key_padded, _, _ = _reshape_data_for_flash(
+      key, heads, blockwise_kwargs["key_chunk_size"], num_fsdp_shards
+  )
+  value_padded, _, _ = _reshape_data_for_flash(
+      value, heads, blockwise_kwargs["key_chunk_size"], num_fsdp_shards
+  )
+
+  # ringattention expects (b, s, h, d); our tensors are (b, h, s, d).
+  query_t = rearrange(query_padded, 'b h s d -> b s h d')
+  key_t = rearrange(key_padded, 'b h s d -> b s h d')
+  value_t = rearrange(value_padded, 'b h s d -> b s h d')
+
+  q_axis_names_physical = nn.logical_to_mesh_axes(axis_names_q)
+  transposed_spec = PartitionSpec(
+      q_axis_names_physical[0],  # BATCH  -> 'data'
+      q_axis_names_physical[2],  # LENGTH -> 'fsdp'
+      q_axis_names_physical[1],  # HEAD   -> 'tensor'
+      q_axis_names_physical[3],  # D_KV   -> None
+  )
+  ring_axis_name = q_axis_names_physical[2]
+
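+  # shard_map runs ringattention per device: each shard keeps its local queries
+  # while key/value blocks rotate along the sequence ('fsdp') ring axis.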
+  ring_attention_sharded = shard_map.shard_map(
+      functools.partial(
+          ringattention,
+          attn_bias=None,
+          segment_ids=None,
+          axis_name=ring_axis_name,
+          float32_logits=True,
+          cache_idx=None,
+          blockwise_kwargs=blockwise_kwargs,
+      ),
+      mesh=mesh,
+      in_specs=(transposed_spec, transposed_spec, transposed_spec),
+      # Output is also (b, s, h, d).
+      out_specs=transposed_spec,
+      check_rep=False,
+  )
+
+  attention_output = ring_attention_sharded(query_t, key_t, value_t)
+
+  attention_output_t = rearrange(attention_output, 'b s h d -> b h s d')
+
+  # Unpad the output to recover the original sequence length and head dim.
+  attention_output_cropped = attention_output_t[:, :, :query_seq_len, :kv_size]
+
+  # Reshape to (b, s, h*d).
+  final_output = _reshape_heads_to_head_dim(attention_output_cropped)
+
+  return final_output
+
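+
+# USP-style variant: an all-to-all swaps head shards for sequence shards over
+# the 'usp' axis (Ulysses-style), then ring attention runs over the 'ring' axis.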
+def _tpu_ring_flash_attention(
+    query: jax.Array,
+    key: jax.Array,
+    value: jax.Array,
+    heads: int,
+    mesh: Mesh,
+    axis_names_q: AxisNames,
+    axis_names_kv: AxisNames,
+    flash_block_sizes: BlockSizes,
+    dtype: jnp.dtype = jnp.float32,
+    usp_degree: int = 1,
+) -> jax.Array:
+  """TPU Ring/USP Flash Attention with correct padding, transposition, and sharding."""
+  from ringattention import ringattention
+  from einops import rearrange
+
+  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  blockwise_kwargs = {
+      # Full (non-causal) mask
+      "causal_block_size": None,
+      "deterministic": True,
+      "dropout_rng": None,
+      "attn_pdrop": 0.0,
+      "policy": jax.checkpoint_policies.nothing_saveable,
+      "dtype": dtype,
+      "precision": None,
+      "prevent_cse": True,
+  }
+  if flash_block_sizes:
+    blockwise_kwargs["query_chunk_size"] = flash_block_sizes.block_q
+    blockwise_kwargs["key_chunk_size"] = flash_block_sizes.block_kv
+  else:
+    # seq_len is shape[2] of the original (b, h, s, d) tensors.
+    blockwise_kwargs["query_chunk_size"] = min(max_block_size, query.shape[2])
+    blockwise_kwargs["key_chunk_size"] = min(max_block_size, key.shape[2])
+
+  # Pad sequences to be divisible by the block size across fsdp shards.
+  num_fsdp_shards = mesh.shape["fsdp"]
+  query_padded, kv_size, query_seq_len = _reshape_data_for_flash(
+      query, heads, blockwise_kwargs["query_chunk_size"], num_fsdp_shards
+  )
+  key_padded, _, _ = _reshape_data_for_flash(
+      key, heads, blockwise_kwargs["key_chunk_size"], num_fsdp_shards
+  )
+  value_padded, _, _ = _reshape_data_for_flash(
+      value, heads, blockwise_kwargs["key_chunk_size"], num_fsdp_shards
+  )
+
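+  # Split the physical 'fsdp' axis into a logical (usp, ring) pair: 'usp'
+  # carries the head/sequence all-to-all, 'ring' carries the K/V rotation.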
+  if num_fsdp_shards % usp_degree != 0:
+    raise ValueError("fsdp axis size must be divisible by usp_degree")
+  ring_degree = num_fsdp_shards // usp_degree
+  logical_mesh_shape = (mesh.shape['data'], usp_degree, ring_degree, mesh.shape["tensor"])
+  reshaped_devices = mesh.devices.reshape(logical_mesh_shape)
+  logical_mesh = Mesh(reshaped_devices, ('data', 'usp', 'ring', 'tensor'))
+
+  def _kernel(q_padded, k_padded, v_padded):
+    # All-to-all swap (a no-op if usp_degree == 1): trade head shards for
+    # sequence shards across the 'usp' axis.
+    q_swapped = jax.lax.all_to_all(q_padded, 'usp', split_axis=1, concat_axis=2, tiled=True)
+    k_swapped = jax.lax.all_to_all(k_padded, 'usp', split_axis=1, concat_axis=2, tiled=True)
+    v_swapped = jax.lax.all_to_all(v_padded, 'usp', split_axis=1, concat_axis=2, tiled=True)
+
+    q_t = rearrange(q_swapped, 'b h s d -> b s h d')
+    k_t = rearrange(k_swapped, 'b h s d -> b s h d')
+    v_t = rearrange(v_swapped, 'b h s d -> b s h d')
+
+    attn_out_t = ringattention(
+        q_t, k_t, v_t,
+        axis_name='ring',
+        blockwise_kwargs=blockwise_kwargs,
+        attn_bias=None,
+        segment_ids=None,
+        cache_idx=None,
+        float32_logits=True,
+    )
+    return attn_out_t
+
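+  # Inputs are (b, h, s, d): heads on 'tensor', sequence split over ('usp', 'ring').
+  # The kernel's all-to-all and transpose yield (b, s, h, d) outputs: sequence on
+  # 'ring', heads over ('usp', 'tensor').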
+  initial_spec = PartitionSpec('data', 'tensor', ('usp', 'ring'), None)
+  output_spec = PartitionSpec('data', 'ring', ('usp', 'tensor'), None)
+
+  sharded_attention_fn = shard_map.shard_map(
+      _kernel,
+      mesh=logical_mesh,
+      in_specs=(initial_spec, initial_spec, initial_spec),
+      out_specs=output_spec,
+      check_rep=False,
+  )
+
+  attention_output_t = sharded_attention_fn(query_padded, key_padded, value_padded)
+
+  attention_output = rearrange(attention_output_t, 'b s h d -> b h s d')
+
+  # Reshard back to the original physical mesh layout.
+  physical_spec = PartitionSpec('data', 'tensor', 'fsdp', None)
+  attention_output_resharded = jax.lax.with_sharding_constraint(
+      attention_output,
+      NamedSharding(mesh, physical_spec),
+  )
+
+  # Unpad the sequence and reshape to (b, s, h*d).
+  attention_output_cropped = attention_output_resharded[:, :, :query_seq_len, :kv_size]
+  final_output = _reshape_heads_to_head_dim(attention_output_cropped)
+
+  return final_output
+
 def _tpu_flash_attention(
     query: jax.Array,
     key: jax.Array,
@@ -372,7 +565,7 @@ def _apply_attention(
   seq_len_idx = 1
   if query.ndim == 4:
     seq_len_idx = 2
-  if attention_kernel == "flash":
+  if attention_kernel in ["flash", "ring_flash"]:
     can_use_flash_attention = (
         query.shape[seq_len_idx] >= flash_min_seq_length
         and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -388,6 +581,11 @@ def _apply_attention(
     return _tpu_flash_attention(
         query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
     )
+  elif attention_kernel == "ring_flash":
+    max_logging.log("USING RING ATTENTION")
+    return _tpu_ring_flash_attention_v1(
+        query, key, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+    )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else: