@@ -189,34 +189,39 @@ def _tpu_flash_attention(
 
   num_fsdp_shards = mesh.shape["fsdp"]
   query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
-  key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
+  key, _, key_seq_len = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
+  # Build segment ids so attention only covers non-padded tokens.
+  segment_axis_names_q = nn.logical_to_mesh_axes((BATCH, LENGTH))
+  segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH, KV_LENGTH))
+  q_segment_ids = jnp.where(jnp.arange(query.shape[2]) < query_seq_len, 1, 0)
+  q_segment_ids = jnp.broadcast_to(q_segment_ids, (query.shape[0], q_segment_ids.shape[0]))
+  kv_segment_ids = jnp.where(jnp.arange(key.shape[2]) < key_seq_len, 1, 0)
+  kv_segment_ids = jnp.broadcast_to(kv_segment_ids, (query.shape[0], kv_segment_ids.shape[0]))
+
   @functools.partial(
       shard_map.shard_map,
       mesh=mesh,
-      in_specs=(
-          q_axis_names,
-          kv_axis_names,
-          kv_axis_names,
-      ),
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names, segment_axis_names_q, segment_axis_names_kv),
       out_specs=q_axis_names,
       check_rep=False,
   )
-  def wrap_flash_attention(query, key, value):
+  def wrap_flash_attention(query, key, value, q_segment_ids, kv_segment_ids):
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
     # make_splash_mha is called inside shard_map, so seq and head are already
     # sharded according to in_specs; hence head_shards=1 and q_seq_shards=1.
+    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
         head_shards=1,  # size of the axis sharded over heads
         q_seq_shards=1,  # size of the axis sharded over seq_len
         block_sizes=block_sizes,
     )
-    attention_output = jax.vmap(splash_kernel)(query, key, value)
+    attention_output = jax.vmap(splash_kernel)(query, key, value, segment_ids=segment_ids)
     return attention_output
 
   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
@@ -227,7 +232,7 @@ def wrap_flash_attention(query, key, value):
227232 "Warning, batch dimension should be shardable among the devices in data and fsdp"
228233 f" axis, batch dimension: { query .shape [0 ]} , devices_in_data_fsdp: { devices_in_data_fsdp } "
229234 )
230- x = wrap_flash_attention (query , key , value )
235+ x = wrap_flash_attention (query , key , value , q_segment_ids , kv_segment_ids )
231236 x = x [:, :, :query_seq_len , :kv_size ]
232237 x = _reshape_heads_to_head_dim (x )
233238
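For context, here is a small reference sketch (not the Splash kernel itself) of what the segment ids introduced in this diff encode. The batch size and the padded/real sequence lengths below are made-up values for illustration only.

```python
import jax.numpy as jnp

batch, q_pad, kv_pad = 2, 8, 8       # illustrative padded sequence lengths
query_seq_len, key_seq_len = 5, 6    # illustrative real (non-padded) lengths

# Same construction as in the diff: 1 for real tokens, 0 for padding.
q_segment_ids = jnp.where(jnp.arange(q_pad) < query_seq_len, 1, 0)
q_segment_ids = jnp.broadcast_to(q_segment_ids, (batch, q_pad))
kv_segment_ids = jnp.where(jnp.arange(kv_pad) < key_seq_len, 1, 0)
kv_segment_ids = jnp.broadcast_to(kv_segment_ids, (batch, kv_pad))

# The kernel only attends where the q and kv segment ids match; a dense
# equivalent of that mask is:
allowed = q_segment_ids[:, :, None] == kv_segment_ids[:, None, :]
print(allowed.shape)  # (2, 8, 8): real tokens (id 1) never attend to padding (id 0)
```

Padded query rows (id 0) still attend among themselves under this mask, but their outputs are discarded afterwards by the `x[:, :, :query_seq_len, :kv_size]` slice in the diff.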