@@ -173,25 +173,54 @@ def _tpu_flash_attention(
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute)

   axis_names = nn.logical_to_mesh_axes(flash_axis_names)
+  kv_axis_names = nn.logical_to_mesh_axes((BATCH, HEAD, None, D_KV))
+  flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
+  axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
+  named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
+
+  cp_size = 8

   @functools.partial(
-      shard_map.shard_map,
-      mesh=mesh,
-      in_specs=(
-          axis_names,
-          axis_names,
-          axis_names,
-      ),
-      out_specs=axis_names,
-      check_rep=False,
+      jax.jit,
+      static_argnames=[
+          "multi_head_mask",
+          "shard_head_size"
+      ],
   )
-  def wrap_flash_attention(query, key, value):
-    masks = [splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) for _ in range(query.shape[1])]
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks)
+  def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
     splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes
+        mask=multi_head_mask,
+        head_shards=shard_head_size,  # number of shards along the head axis
+        q_seq_shards=cp_size,
+        block_sizes=block_sizes,
     )
-    return jax.vmap(splash_kernel)(query, key, value)
+    return splash_kernel
+
+  shard_head_size = 1
+  mask = splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2]))
+  mask &= splash_attention_mask.LocalMask(
+      shape=(query.shape[2], key.shape[2]),
+      window_size=(query.shape[2], query.shape[2]),
+      offset=0
+  )
+  multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+  splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
+  segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
+  @functools.partial(
+      shard_map.shard_map,
+      mesh=mesh,
+      in_specs=(
+          axis_names,
+          kv_axis_names,
+          kv_axis_names,
+          segment_axis_names_splash_kernel,
+      ),
+      out_specs=axis_names,
+      check_rep=False
+  )
+  def wrap_flash_attention(query, key, value, splash_kernel):
+    attention_output = jax.vmap(splash_kernel)(query, key, value)
+    return attention_output

   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
   # This warning might show up when doing model eval for example, when calculating model flops
@@ -201,7 +230,7 @@ def wrap_flash_attention(query, key, value):
         "Warning, batch dimension should be shardable among the devices in data and fsdp"
         f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
     )
-  x = wrap_flash_attention(query, key, value)
+  x = wrap_flash_attention(query, key, value, splash_kernel)
   x = x[:, :, :query_seq_len, :kv_size]
   x = _reshape_heads_to_head_dim(x)

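For readers following the change: the commit splits kernel construction out of the sharded attention call. A MultiHeadMask is built once from a FullMask ANDed with a LocalMask, wrap_splash_kernel turns it into a splash kernel sharded over cp_size query-sequence shards, and manual_sharding_spec produces the in_spec that lets the kernel object be passed through shard_map. The snippet below is a minimal single-device sketch of the same mask and kernel construction, outside any mesh or shard_map; the shapes, dtype, and the BlockSizes.get_default() call are illustrative assumptions, not values taken from this commit (if that helper differs in your JAX version, pass an explicit BlockSizes).

# Minimal sketch: build the same kind of mask and splash kernel on one TPU
# device, without mesh/shard_map. Shapes, dtype, and the default BlockSizes
# below are assumptions for illustration only.
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)

batch, heads, seq_len, head_dim = 1, 8, 1024, 128

# Full mask ANDed with a local window spanning the whole sequence, so the
# combined mask stays dense; replicate it per head, as the commit does.
mask = splash_attention_mask.FullMask(_shape=(seq_len, seq_len))
mask &= splash_attention_mask.LocalMask(
    shape=(seq_len, seq_len),
    window_size=(seq_len, seq_len),
    offset=0,
)
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * heads)

# head_shards / q_seq_shards describe how the kernel is partitioned across
# devices; 1 / 1 corresponds to an unsharded, single-device run.
splash_kernel = splash_attention_kernel.make_splash_mha(
    mask=multi_head_mask,
    head_shards=1,
    q_seq_shards=1,
    block_sizes=splash_attention_kernel.BlockSizes.get_default(),
)

q = jnp.zeros((batch, heads, seq_len, head_dim), dtype=jnp.float32)
k = jnp.zeros((batch, heads, seq_len, head_dim), dtype=jnp.float32)
v = jnp.zeros((batch, heads, seq_len, head_dim), dtype=jnp.float32)

# The kernel operates on [heads, seq, head_dim]; vmap adds the batch axis,
# mirroring jax.vmap(splash_kernel)(query, key, value) in the commit.
out = jax.vmap(splash_kernel)(q, k, v)  # (batch, heads, seq_len, head_dim)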