@@ -195,7 +195,8 @@ def _tpu_flash_attention(
       block_q_dq=min(q_max_block_size, query.shape[2]),
       block_kv_dq=min(kv_max_block_size, query.shape[2]),
   )
-
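+  # One ring step per device along the "fsdp" mesh axis.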
+  num_fsdp_shards = mesh.shape["fsdp"]
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
@@ -218,9 +219,8 @@ def wrap_flash_attention(query, key, value):
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
     q_segment_ids = jnp.where(jnp.arange(query.shape[2]) < query_seq_len, 1, 0)
-    q_segment_ids = jnp.broadcast_to(q_segment_ids, (query.shape[0], q_segment_ids.shape[0]))
     kv_segment_ids = jnp.where(jnp.arange(key.shape[2]) < key_seq_len, 1, 0)
-    kv_segment_ids = jnp.broadcast_to(kv_segment_ids, (query.shape[0], kv_segment_ids.shape[0]))
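+    # No broadcast to the batch dimension: segment ids are shared across the batch and
+    # passed to the vmapped splash kernel with in_axes=None.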
     segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)

     # make_splash_mha is wrapped around shardmap and seq and head is already
@@ -230,9 +230,67 @@ def wrap_flash_attention(query, key, value):
         head_shards=1,  # the size of the axis sharding over heads
         q_seq_shards=1,  # the size of the axis sharding over seq_len
         block_sizes=block_sizes,
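+        # Also return the per-row logsumexp so partial outputs from each ring step can be merged.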
+        save_residuals=True,
     )
-    attention_output = jax.vmap(splash_kernel)(query, key, value, segment_ids=segment_ids)
-    return attention_output[:, :, :query_seq_len, :kv_size]
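+    # vmap over the batch dimension; segment ids are shared, so they are not batched.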
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+
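+    # Ring attention over the "fsdp" axis: each scan step attends to the KV chunk currently
+    # held on this device and passes it along the ring, so after num_fsdp_shards steps every
+    # device has seen every chunk. Partial results are merged with online-softmax rescaling:
+    #   m_new = max(m, lse_chunk)
+    #   l_new = l * exp(m - m_new) + exp(lse_chunk - m_new)
+    #   o_new = o * exp(m - m_new) + out_chunk * exp(lse_chunk - m_new)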
+    def ring_scan_body(carry, _):
+      m, l, o, k_current, v_current = carry
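+      # Send the local K/V chunk to the next device in the ring for the next step.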
+      perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
+      k_next = jax.lax.ppermute(k_current, axis_name='fsdp', perm=perm)
+      v_next = jax.lax.ppermute(v_current, axis_name='fsdp', perm=perm)
+
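+      # Attend to the KV chunk currently held by this device; lse_chunk is the per-row logsumexp.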
+      out_chunk, (lse_chunk,) = vmapped_splash(
+          query, k_current, v_current, segment_ids
+      )
+
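+      # Merge this chunk into the running accumulators with the online-softmax update.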
+      m_chunk = lse_chunk.astype(jnp.float32)
+      m_old = m
+      m = jnp.maximum(m_old, m_chunk)
+
+      exp_m_diff = jnp.exp(m_old - m)
+      exp_m_chunk_diff = jnp.exp(m_chunk - m)
+
+      l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+      o = o * exp_m_diff[..., None]
+      o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
+
+      # Return the updated state for the next iteration.
+      return (m, l, o, k_next, v_next), None
+
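+    # Running max (m), softmax denominator (l) and unnormalised output (o),
+    # accumulated in float32 for numerical stability.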
+    lse_shape = query.shape[:-1]
+    m_init = jnp.full(lse_shape, -jnp.inf, dtype=jnp.float32)
+    l_init = jnp.zeros(lse_shape, dtype=jnp.float32)
+    o_init = jnp.zeros_like(query, dtype=jnp.float32)
+
+    initial_carry = (m_init, l_init, o_init, key, value)
+
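+    # One scan iteration per shard on the "fsdp" axis.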
+    (m_final, l_final, o_final, _, _), _ = jax.lax.scan(
+        ring_scan_body,
+        initial_carry,
+        None,
+        length=num_fsdp_shards,
+    )
+
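+    # Normalise by the accumulated softmax denominator and restore the input dtype.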
+    attention_output = o_final / l_final[..., None]
+
+    return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
   # This warning might show up when doing model eval for example, when calculating model flops