@@ -112,6 +112,7 @@ def _unflatten_heads(tensor, heads):
   tensor = jnp.transpose(tensor, (0, 2, 1, 3))
   return tensor

+
 def _reshape_data_for_flash(tensor, heads):
   """
   Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
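For reference, the `jnp.transpose(tensor, (0, 2, 1, 3))` above is the final step of splitting a fused head dimension out of a sequence-major layout. A minimal sketch, assuming a `[batch, seq_len, heads * head_dim] -> [batch, heads, seq_len, head_dim]` convention (the actual `_unflatten_heads` body is not shown in this hunk):

import jax.numpy as jnp

def unflatten_heads_sketch(tensor, heads):
  # Assumed layout: [batch, seq_len, heads * head_dim] -> [batch, heads, seq_len, head_dim].
  batch, seq_len, dim = tensor.shape
  tensor = tensor.reshape(batch, seq_len, heads, dim // heads)
  return jnp.transpose(tensor, (0, 2, 1, 3))  # move heads in front of seq_len

x = jnp.zeros((2, 16, 8 * 64))
assert unflatten_heads_sketch(x, 8).shape == (2, 8, 16, 64)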
@@ -122,6 +123,7 @@ def _reshape_data_for_flash(tensor, heads):
   tensor = _unflatten_heads(tensor, heads)
   return tensor

+
 def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
   """
   Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
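`_pad_data_for_flash`'s body is also outside this hunk; as a rough sketch of the padding its docstring describes (an assumption about the layout, not the actual implementation), rounding the sequence axis of a `[batch, heads, seq_len, head_dim]` tensor up to a multiple of the flash block size could look like:

import jax.numpy as jnp

def pad_seq_to_block_multiple(tensor, flash_block_size):
  # Hypothetical helper: zero-pad axis 2 (seq_len) so its length becomes a
  # multiple of flash_block_size, as the pallas flash kernel expects.
  seq_len = tensor.shape[2]
  padded_len = -(-seq_len // flash_block_size) * flash_block_size  # ceiling division
  return jnp.pad(tensor, ((0, 0), (0, 0), (0, padded_len - seq_len), (0, 0)))

x = jnp.zeros((2, 8, 1000, 64))
assert pad_seq_to_block_multiple(x, 512).shape == (2, 8, 1024, 64)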
@@ -171,6 +173,7 @@ def _tpu_flash_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
+    attention_kernel: str = "flash",
 ) -> jax.Array:
   """TPU Flash Attention"""

@@ -179,7 +182,6 @@ def _tpu_flash_attention(
   if key.shape[1] != query.shape[1]:
     assert key.shape[1] % 128 == 0
     kv_max_block_size = key.shape[1]
-    # q_max_block_size = kv_max_block_size
   else:
     kv_max_block_size = q_max_block_size
   if flash_block_sizes:
@@ -217,8 +219,14 @@ def wrap_flash_attention(query, key, value):

     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-    q_segment_ids = jnp.where(jnp.arange(query.shape[2]) < query_seq_len, 1, 0)
-    kv_segment_ids = jnp.where(jnp.arange(key.shape[2]) < key_seq_len, 1, 0)
+
+    q_padded_len = query.shape[2]
+    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+    kv_padded_len = key.shape[2]
+    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
     segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)

     # make_splash_mha is wrapped around shardmap and seq and head is already
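Both the removed `jnp.where(jnp.arange(...))` construction and the new `broadcasted_iota` one produce the same 1-for-real-token, 0-for-padding segment ids; a small standalone check (with made-up example lengths):

import jax
import jax.numpy as jnp

padded_len, valid_len = 8, 5
ids_where = jnp.where(jnp.arange(padded_len) < valid_len, 1, 0)
ids_iota = (jax.lax.broadcasted_iota(jnp.int32, (padded_len,), 0) < valid_len).astype(jnp.int32)
assert bool(jnp.all(ids_where == ids_iota))  # [1 1 1 1 1 0 0 0] either way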
@@ -228,51 +236,51 @@ def wrap_flash_attention(query, key, value):
         head_shards=1,  # the size of the axis sharded over heads
         q_seq_shards=1,  # the size of the axis sharded over seq_len
         block_sizes=block_sizes,
-        save_residuals=True
+        save_residuals=(attention_kernel == "ring"),
     )
-    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0,0, 0, None))
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))

-    def ring_scan_body(carry, _):
-      m, l, o, k_current, v_current = carry
-      perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
-      k_next = jax.lax.ppermute(k_current, axis_name='fsdp', perm=perm)
-      v_next = jax.lax.ppermute(v_current, axis_name='fsdp', perm=perm)
+    if attention_kernel == "flash":
+      attention_output = vmapped_splash(query, key, value, segment_ids)
+    else:
+      if num_fsdp_shards > 1:
+        out, (lse,) = vmapped_splash(query, key, value, segment_ids)
+        m = lse.astype(jnp.float32)
+        l = jnp.exp(lse - m)
+        o = out.astype(jnp.float32) * l[..., None]

-      out_chunk, (lse_chunk,) = vmapped_splash(
-        query, k_current, v_current, segment_ids
-      )
+        perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]

-      m_chunk = lse_chunk.astype(jnp.float32)
-      m_old = m
-      m = jnp.maximum(m_old, m_chunk)
-
-      exp_m_diff = jnp.exp(m_old - m)
-      exp_m_chunk_diff = jnp.exp(m_chunk - m)
+        k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
+        v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)

-      l = l * exp_m_diff + jnp.exp(lse_chunk - m)
-      o = o * exp_m_diff[..., None]
-      o += (exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32))
+        def ring_scan_body(carry, _):
+          m, l, o, k_current, v_current = carry
+          k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
+          v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)

-      # Return the updated state for the next iteration
-      return (m, l, o, k_next, v_next), None
+          out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)

-    lse_shape = query.shape[:-1]
-    m_init = jnp.full(lse_shape, -jnp.inf, dtype=jnp.float32)
-    l_init = jnp.zeros(lse_shape, dtype=jnp.float32)
-    o_init = jnp.zeros_like(query, dtype=jnp.float32)
+          m_chunk = lse_chunk.astype(jnp.float32)
+          m_old = m
+          m = jnp.maximum(m_old, m_chunk)

-    initial_carry = (m_init, l_init, o_init, key, value)
+          exp_m_diff = jnp.exp(m_old - m)
+          exp_m_chunk_diff = jnp.exp(m_chunk - m)

-    (m_final, l_final, o_final, _, _), _ = jax.lax.scan(
-      ring_scan_body,
-      initial_carry,
-      None,
-      length=num_fsdp_shards
-    )
+          l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+          o = o * exp_m_diff[..., None]
+          o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
+
+          # Return the updated state for the next iteration
+          return (m, l, o, k_next, v_next), None
+
+        initial_carry = (m, l, o, k1, v1)
+        (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)

-    attention_output = o_final / l_final[..., None]
+        attention_output = o_final / l_final[..., None]

-    return attention_output[:,:, :query_seq_len,:kv_size].astype(query.dtype)
+    return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
   # This warning might show up when doing model eval for example, when calculating model flops
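For context on the ring branch above: each step produces one k/v chunk's softmax-normalized output plus its log-sum-exp, and the scan folds successive chunks into a running `(m, l, o)` triple that is normalized once at the end. A single-device sketch of that merge (plain `jnp` in place of the splash kernel and `ppermute`; helper names are made up), checked against attention over the concatenated keys:

import jax
import jax.numpy as jnp

def chunk_attention(q, k, v):
  # Stand-in for one vmapped_splash call on a single k/v chunk: returns the
  # softmax-normalized output and the log-sum-exp of the raw scores.
  s = q @ k.T
  return jax.nn.softmax(s, axis=-1) @ v, jax.nn.logsumexp(s, axis=-1)

def merge(m, l, o, out_chunk, lse_chunk):
  # The same log-sum-exp merge ring_scan_body performs: rescale the running
  # accumulators to the new running max and fold in the normalized chunk output.
  m_new = jnp.maximum(m, lse_chunk)
  scale_old = jnp.exp(m - m_new)
  scale_new = jnp.exp(lse_chunk - m_new)
  return m_new, l * scale_old + scale_new, o * scale_old[..., None] + scale_new[..., None] * out_chunk

q = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (16, 8))
v = jax.random.normal(jax.random.PRNGKey(2), (16, 8))

# Initialize the accumulators from the first chunk, then fold in the second,
# mirroring the initialization-plus-scan structure of the ring branch.
out0, lse0 = chunk_attention(q, k[:8], v[:8])
m, l, o = lse0, jnp.exp(lse0 - lse0), out0 * jnp.exp(lse0 - lse0)[..., None]
out1, lse1 = chunk_attention(q, k[8:], v[8:])
m, l, o = merge(m, l, o, out1, lse1)

full, _ = chunk_attention(q, k, v)
assert bool(jnp.allclose(o / l[..., None], full, atol=1e-5))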
@@ -433,6 +441,10 @@ def _apply_attention(
     return _tpu_flash_attention(
         query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
     )
+  elif attention_kernel == "ring":
+    return _tpu_flash_attention(
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel
+    )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else: