@@ -113,14 +113,24 @@ def _unflatten_heads(tensor, heads):
   return tensor


-def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
+def _reshape_data_for_flash(tensor, heads):
   """
-  Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
-  Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
-  blocks is divisible by the number of shards.
+  Reshapes tensors for pallas flash attention, unflattening the heads dimension so the
+  tensor is rank 4 (batch, heads, seq_len, head_dim). Padding is handled separately by
+  _pad_data_for_flash.
   """
   if tensor.ndim != 4:
     tensor = _unflatten_heads(tensor, heads)
+  return tensor
+
+
+def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
+  """
+  Reshapes tensors via _reshape_data_for_flash and pads both seq_len and head_dim for
+  the pallas flash attention kernel. seq_len is padded to a multiple of flash_block_size,
+  and the resulting number of blocks is made divisible by the number of shards.
+  """
+  tensor = _reshape_data_for_flash(tensor, heads)

   # Pad head_dim to 128 if less than that.
   kv_size = tensor.shape[-1]
@@ -148,8 +158,7 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1

   if kv_size < 128 or seq_len_pad != 0:
     npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
-    padded_tensor = jnp.pad(tensor, npad)
-    tensor = jax.lax.with_sharding_constraint(padded_tensor, PartitionSpec("data", "tensor", "fsdp", None))
+    tensor = jnp.pad(tensor, npad)

   return tensor, kv_size, seq_len

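For reference, a minimal, self-contained sketch of the padding arithmetic described in the `_pad_data_for_flash` docstring. The helper name and exact rounding below are illustrative; the real function also returns the original `kv_size` and `seq_len` so the output can be un-padded later:

```python
import jax.numpy as jnp

def pad_seq_and_head_dim(tensor, flash_block_size, num_shards=1):
  """Illustrative only: pad seq_len to a multiple of flash_block_size, make the
  block count divisible by num_shards, and pad head_dim up to 128."""
  _, _, seq_len, head_dim = tensor.shape

  num_blocks = -(-seq_len // flash_block_size)             # ceil division
  num_blocks = -(-num_blocks // num_shards) * num_shards   # round block count up to a shard multiple
  seq_len_pad = num_blocks * flash_block_size - seq_len
  head_dim_pad = max(0, 128 - head_dim)                    # "Pad head_dim to 128 if less than that."

  npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
  return jnp.pad(tensor, npad)

# e.g. seq_len=500, flash_block_size=512, num_shards=2: one block is rounded up to two,
# so the padded seq_len is 1024 and each fsdp shard gets one full block.
x = jnp.zeros((1, 8, 500, 64))
print(pad_seq_and_head_dim(x, 512, num_shards=2).shape)  # (1, 8, 1024, 128)
```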
@@ -164,12 +173,14 @@ def _tpu_flash_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
+    attention_kernel: str = "flash",
 ) -> jax.Array:
   """TPU Flash Attention"""
+
   q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
-  # Cross-attention where kv dims are much smaller due to encoder_hidden_states.
-  # If kv seq_len is padded too much, it causes issues in attention calculations.
+  # Cross-attention: kv seq_len comes from encoder_hidden_states and is much smaller, so use it as the max kv block size.
   if key.shape[1] != query.shape[1]:
+    assert key.shape[1] % 128 == 0, "cross-attention kv seq_len must be a multiple of 128"
     kv_max_block_size = key.shape[1]
   else:
     kv_max_block_size = q_max_block_size
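As a rough worked example of the block-size caps above (the rule is copied from the two branches in the diff; shapes and the helper name are hypothetical):

```python
import jax.numpy as jnp

def max_block_sizes(q_seq_len, kv_seq_len, dtype=jnp.bfloat16):
  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
  if kv_seq_len != q_seq_len:
    # Cross-attention: keep the kv block no larger than the short kv sequence.
    assert kv_seq_len % 128 == 0
    kv_max_block_size = kv_seq_len
  else:
    kv_max_block_size = q_max_block_size
  return q_max_block_size, kv_max_block_size

print(max_block_sizes(4096, 4096))               # self-attention in bf16 -> (1024, 1024)
print(max_block_sizes(4096, 256))                # cross-attention with 256 text tokens -> (1024, 256)
print(max_block_sizes(4096, 4096, jnp.float32))  # fp32 halves the q block cap -> (512, 512)
```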
@@ -186,38 +197,90 @@ def _tpu_flash_attention(
       block_q_dq=min(q_max_block_size, query.shape[2]),
       block_kv_dq=min(kv_max_block_size, query.shape[2]),
   )
-
   num_fsdp_shards = mesh.shape["fsdp"]
-  query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
-  key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
-  value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
+  query = _reshape_data_for_flash(query, heads)
+  key = _reshape_data_for_flash(key, heads)
+  value = _reshape_data_for_flash(value, heads)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

   @functools.partial(
       shard_map.shard_map,
       mesh=mesh,
-      in_specs=(
-          q_axis_names,
-          kv_axis_names,
-          kv_axis_names,
-      ),
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
       out_specs=q_axis_names,
       check_rep=False,
   )
   def wrap_flash_attention(query, key, value):
+
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_sizes.block_q)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv_compute)
+    value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv_compute)
+
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+    # Segment ids mark the padded tails of q and kv so the kernel does not attend to padding.
+    q_padded_len = query.shape[2]
+    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+    kv_padded_len = key.shape[2]
+    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+
     # make_splash_mha is wrapped in shard_map and seq_len and heads are already
     # sharded based on in_specs, therefore head_shards=1 and q_seq_shards=1.
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
         head_shards=1,  # size of the axis sharded over heads
         q_seq_shards=1,  # size of the axis sharded over seq_len
         block_sizes=block_sizes,
+        save_residuals=True if attention_kernel == "ring" else False,
     )
-    attention_output = jax.vmap(splash_kernel)(query, key, value)
-    return attention_output
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+
+    if attention_kernel == "flash" or num_fsdp_shards == 1:  # a single fsdp shard makes ring identical to flash
+      attention_output = vmapped_splash(query, key, value, segment_ids)
+    else:
+      if num_fsdp_shards > 1:
+        out, (lse,) = vmapped_splash(query, key, value, segment_ids)
+        m = lse.astype(jnp.float32)  # running max of the per-chunk logsumexp
+        l = jnp.exp(lse - m)  # running normalizer, sum of exp(lse - m)
+        o = out.astype(jnp.float32) * l[..., None]  # running unnormalized output
+        # Ring attention: rotate kv chunks around the fsdp axis and fold each chunk's partial result into (m, l, o).
+        perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
+
+        k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
+        v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
+
+        def ring_scan_body(carry, _):
+          m, l, o, k_current, v_current = carry
+          k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
+          v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
+
+          out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
+
+          m_chunk = lse_chunk.astype(jnp.float32)
+          m_old = m
+          m = jnp.maximum(m_old, m_chunk)
+
+          exp_m_diff = jnp.exp(m_old - m)
+          exp_m_chunk_diff = jnp.exp(m_chunk - m)
+
+          l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+          o = o * exp_m_diff[..., None]
+          o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
+
+          # Return the updated state for the next iteration.
+          return (m, l, o, k_next, v_next), None
+
+        initial_carry = (m, l, o, k1, v1)
+        (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)
+
+        attention_output = o_final / l_final[..., None]
+
+    return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

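The accumulation above is the standard log-sum-exp merge of per-chunk attention results. A small self-contained check of that identity in plain JAX, without the splash kernel or any sharding (function names and shapes here are illustrative only):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def attention_with_lse(q, k, v):
  """Single-chunk attention that also returns the logsumexp of its scores,
  analogous to what the splash kernel returns with save_residuals=True."""
  scores = q @ k.T                            # (q_len, kv_len)
  lse = logsumexp(scores, axis=-1)            # (q_len,)
  out = jax.nn.softmax(scores, axis=-1) @ v   # (q_len, head_dim)
  return out, lse

def merged_chunked_attention(q, k_chunks, v_chunks):
  """Combine per-chunk (out, lse) pairs exactly like the ring scan body above."""
  out0, lse0 = attention_with_lse(q, k_chunks[0], v_chunks[0])
  m = lse0                        # running max of the logsumexp
  l = jnp.exp(lse0 - m)           # running normalizer (starts at 1)
  o = out0 * l[:, None]           # running unnormalized output
  for k_c, v_c in zip(k_chunks[1:], v_chunks[1:]):
    out_c, lse_c = attention_with_lse(q, k_c, v_c)
    m_old, m = m, jnp.maximum(m, lse_c)
    l = l * jnp.exp(m_old - m) + jnp.exp(lse_c - m)
    o = o * jnp.exp(m_old - m)[:, None] + jnp.exp(lse_c - m)[:, None] * out_c
  return o / l[:, None]

q = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
k = jax.random.normal(jax.random.PRNGKey(1), (32, 16))
v = jax.random.normal(jax.random.PRNGKey(2), (32, 16))

full, _ = attention_with_lse(q, k, v)                                   # attend to all kv at once
chunked = merged_chunked_attention(q, jnp.split(k, 4), jnp.split(v, 4)) # chunk, then merge
assert jnp.allclose(full, chunked, atol=1e-5)
```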
   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
   # This warning might show up when doing model eval for example, when calculating model flops
@@ -228,7 +291,6 @@ def wrap_flash_attention(query, key, value):
         f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
     )
   x = wrap_flash_attention(query, key, value)
-  x = x[:, :, :query_seq_len, :kv_size]
   x = _reshape_heads_to_head_dim(x)

   return x
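The ring path inside `wrap_flash_attention` is built from `jax.lax.ppermute` plus `jax.lax.scan` under `shard_map`. A toy, self-contained check of that rotation pattern (a 1D "fsdp" mesh over whatever devices are available; the shapes and the sum reduction are placeholders, not the attention math):

```python
import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import shard_map
from jax.sharding import Mesh, PartitionSpec

# Hypothetical setup: a 1D "fsdp" mesh over however many devices are present.
n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))

@functools.partial(
    shard_map.shard_map,
    mesh=mesh,
    in_specs=(PartitionSpec("fsdp"),),
    out_specs=PartitionSpec("fsdp"),
    check_rep=False,
)
def ring_sum(x_chunk):
  # Each device starts from its own chunk and accumulates every other chunk by
  # passing chunks one hop around the ring, mirroring the kv rotation above.
  perm = [(j, (j + 1) % n) for j in range(n)]
  acc = x_chunk
  incoming = jax.lax.ppermute(x_chunk, axis_name="fsdp", perm=perm)

  def body(carry, _):
    acc, cur = carry
    acc = acc + cur
    cur = jax.lax.ppermute(cur, axis_name="fsdp", perm=perm)
    return (acc, cur), None

  (acc, _), _ = jax.lax.scan(body, (acc, incoming), None, length=n - 1)
  return acc

x = jnp.arange(4 * n, dtype=jnp.float32).reshape(n, 4)
print(ring_sum(x))  # every row equals x.sum(axis=0)
```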
@@ -379,6 +441,10 @@ def _apply_attention(
     return _tpu_flash_attention(
         query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
     )
+  elif attention_kernel == "ring":
+    return _tpu_flash_attention(
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel
+    )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else: