@@ -188,8 +188,8 @@ def _tpu_flash_attention(
   )
 
   num_fsdp_shards = mesh.shape["fsdp"]
-  query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
-  key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
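+  # Capture the original (pre-reshape) sequence lengths; the query length is
+  # used at the end of this function to slice the output back to its original size.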
+  query, kv_size, original_query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
+  key, _, original_key_seq_len = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
@@ -206,6 +206,10 @@ def _tpu_flash_attention(
       check_rep=False,
   )
   def wrap_flash_attention(query, key, value):
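+    # Note: inside shard_map the inputs here are the per-device shards, so the
+    # shapes printed below are local (per fsdp shard), not the global shapes.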
+    jax.debug.print("query.shape: {x}", x=query.shape)
+    jax.debug.print("key.shape: {x}", x=key.shape)
+    jax.debug.print("value.shape: {x}", x=value.shape)
+
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
     # make_splash_mha is wrapped around shardmap and seq and head is already
@@ -215,8 +219,45 @@ def wrap_flash_attention(query, key, value):
         head_shards=1,  # the sizes of the axis is sharding over heads
         q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
         block_sizes=block_sizes,
+        save_residuals=True,
     )
-    attention_output = jax.vmap(splash_kernel)(query, key, value)
+    out, (lse,) = jax.vmap(splash_kernel)(query, key, value)
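+    # With save_residuals=True the kernel also returns the per-row log-sum-exp
+    # (lse) of its local softmax. Set up an online-softmax accumulator over KV
+    # shards: m tracks the running max of the chunk lse values, l the running
+    # softmax denominator rescaled by exp(-m), and o the matching rescaled
+    # numerator. For this first chunk exp(lse - m) == 1, so l starts at 1.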
+    m = lse.astype(jnp.float32)
+    l = jnp.exp(lse.astype(jnp.float32) - m)
+    o = out.astype(jnp.float32) * l[..., None]
+
+    k_ring = key
+    v_ring = value
+
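+    # Ring pass over the fsdp axis: each step rotates the key/value shards one
+    # device along the ring with ppermute, so after num_fsdp_shards - 1 steps
+    # every device has attended to every key/value shard.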
+    for i in range(1, num_fsdp_shards):
+      k_ring = jax.lax.ppermute(k_ring, axis_name='fsdp', perm=[(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)])
+      v_ring = jax.lax.ppermute(v_ring, axis_name='fsdp', perm=[(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)])
+
+      out_chunk, (lse_chunk,) = jax.vmap(splash_kernel)(query, k_ring, v_ring)
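+      # Fold this shard's result into the running accumulator with the usual
+      # log-sum-exp merge: rescale both sides to the new running max m_new and
+      # add. (p_chunk is exp(lse_chunk - lse_chunk) == 1; it is kept to mirror
+      # the standard flash-attention update.)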
+      m_chunk = lse_chunk.astype(jnp.float32)
+      p_chunk = jnp.exp(lse_chunk.astype(jnp.float32) - m_chunk)
+
+      m_new = jnp.maximum(m, m_chunk)
+
+      l = l * jnp.exp(m - m_new)
+      p_chunk_rescaled = p_chunk * jnp.exp(m_chunk - m_new)
+
+      l_new = l + p_chunk_rescaled
+
+      o = o * jnp.exp(m - m_new)[..., None]
+      o += p_chunk_rescaled[..., None] * out_chunk
+
+      m = m_new
+      l = l_new
+      jax.debug.print("Loop {i}: max(m)={m_max}, max(l)={l_max}, max(o)={o_max}",
+                      i=i,
+                      m_max=m.max(),
+                      l_max=l.max(),
+                      o_max=o.max())
+
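+    # Normalize the accumulated numerator by the global softmax denominator.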
+    attention_output = o / l[..., None]
+
     return attention_output
 
   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
@@ -228,7 +269,7 @@ def wrap_flash_attention(query, key, value):
         f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
     )
   x = wrap_flash_attention(query, key, value)
-  x = x[:, :, :query_seq_len, :kv_size]
+  x = x[:, :, :original_query_seq_len, :kv_size]
   x = _reshape_heads_to_head_dim(x)
 
   return x