 AxisNames = common_types.AxisNames
 BATCH = common_types.BATCH
 LENGTH = common_types.LENGTH
+Q_LENGTH = common_types.Q_LENGTH
+KV_LENGTH = common_types.KV_LENGTH
 HEAD = common_types.HEAD
 D_KV = common_types.D_KV
 EMBED = common_types.EMBED
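For context on how these new logical axis names get consumed further down in the diff: nn.logical_to_mesh_axes resolves a tuple of logical names into a PartitionSpec via the active logical-axis rules. A minimal sketch, assuming hypothetical rule strings (the real values live in common_types and the model config, not in this PR):

import flax.linen as nn

# Hypothetical rules: the query sequence is sharded across the "context" mesh
# axis for context parallelism, while keys/values stay replicated.
rules = (
    ("activation_batch", ("data", "fsdp")),
    ("activation_heads", None),
    ("activation_q_length", "context"),
    ("activation_kv_length", None),
    ("activation_kv", None),
)
with nn.logical_axis_rules(rules):
  spec = nn.logical_to_mesh_axes(
      ("activation_batch", "activation_heads", "activation_q_length", "activation_kv")
  )
# spec == PartitionSpec(("data", "fsdp"), None, "context", None)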
@@ -139,50 +141,83 @@ def _tpu_flash_attention(
     value: jax.Array,
     heads: int,
     mesh: Mesh,
-    flash_axis_names: AxisNames,
-    flash_block_sizes: BlockSizes,
+    flash_block_sizes: BlockSizes = None,
+    flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
+    flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
+    flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH),
     dtype: jnp.dtype = jnp.float32) -> jax.Array:
   """TPU Flash Attention"""
 
-  max_block_size = 1024 if dtype == jnp.bfloat16 else 512
+  cp_size = mesh.shape["context"]
+  axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
+  axis_names_q = nn.logical_to_mesh_axes(flash_axis_names_q)
+  axis_names_kv = nn.logical_to_mesh_axes(flash_axis_names_kv)
+  max_logging.log(f"axis_names_q: {axis_names_q}")
+  max_logging.log(f"axis_names_kv: {axis_names_kv}")
+  max_logging.log(f"axis_names_splash_kernel: {axis_names_splash_kernel}")
+
+  max_block_size = 256 if dtype == jnp.bfloat16 else 128
   if flash_block_sizes:
     block_sizes = flash_block_sizes
   else:
     block_sizes = splash_attention_kernel.BlockSizes(
         block_q=min(max_block_size, query.shape[2]),
-        block_kv_compute=min(max_block_size, key.shape[2]),
         block_kv=min(max_block_size, key.shape[2]),
+        block_kv_compute=min(max_block_size, key.shape[2]),
         block_q_dkv=min(max_block_size, query.shape[2]),
         block_kv_dkv=min(max_block_size, key.shape[2]),
         block_kv_dkv_compute=min(max_block_size, query.shape[2]),
         block_q_dq=min(max_block_size, query.shape[2]),
         block_kv_dq=min(max_block_size, query.shape[2]),
+        q_layout=splash_attention_kernel.QKVLayout["HEAD_DIM_MINOR"],
+        k_layout=splash_attention_kernel.QKVLayout["HEAD_DIM_MINOR"],
+        v_layout=splash_attention_kernel.QKVLayout["HEAD_DIM_MINOR"],
     )
 
   query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q)
   key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute)
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute)
 
-  axis_names = nn.logical_to_mesh_axes(flash_axis_names)
-
   @functools.partial(
-      shard_map.shard_map,
-      mesh=mesh,
-      in_specs=(
-          axis_names,
-          axis_names,
-          axis_names,
-      ),
-      out_specs=axis_names,
-      check_rep=False,
+      jax.jit,
+      static_argnames=[
+          "multi_head_mask",
+          "shard_head_size",
+      ],
   )
-  def wrap_flash_attention(query, key, value):
-    masks = [splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) for _ in range(query.shape[1])]
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks)
+  def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
     splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes
+        mask=multi_head_mask,
+        head_shards=shard_head_size,  # size of the mesh axis sharded over heads
+        q_seq_shards=cp_size,  # split the query sequence across the context axis
+        block_sizes=block_sizes,
     )
-    return jax.vmap(splash_kernel)(query, key, value)
+    return splash_kernel
+
+  # Heads are not sharded here; this could be derived from the logical axis rules for HEAD.
+  shard_head_size = 1
+
+  masks = [splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) for _ in range(query.shape[1])]
+  multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks)
+  splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
+  named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
+  segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)
+  @functools.partial(
+      shard_map.shard_map,
+      mesh=mesh,
+      in_specs=(
+          axis_names_q,
+          axis_names_kv,
+          axis_names_kv,
+          segment_axis_names_splash_kernel,
+          None,
+      ),
+      out_specs=axis_names_q,
+      check_rep=False,
+  )
+  def wrap_flash_attention(query, key, value, splash_kernel, cp_size):
+    attention_output = jax.vmap(splash_kernel)(query, key, value)
+    return attention_output
 
   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
   # This warning might show up when doing model eval for example, when calculating model flops
@@ -192,7 +227,7 @@ def wrap_flash_attention(query, key, value):
192231 "Warning, batch dimension should be shardable among the devices in data and fsdp"
193232 f" axis, batch dimension: { query .shape [0 ]} , devices_in_data_fsdp: { devices_in_data_fsdp } "
194233 )
195- x = wrap_flash_attention (query , key , value )
234+ x = wrap_flash_attention (query , key , value , splash_kernel , cp_size )
196235 x = x [:, :, :query_seq_len , :kv_size ]
197236 x = _reshape_heads_to_head_dim (x )
198237
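Taken together, the new `_tpu_flash_attention` body splits the query sequence across a "context" mesh axis while keys and values stay replicated, and the splash kernel itself is sharded and passed through shard_map as an argument. A minimal self-contained sketch of that pattern, with hypothetical shapes and a 1-D mesh (illustration only, not code from this PR; TPU-only, since splash attention is a Pallas TPU kernel):

import functools
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils, shard_map
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

cp_size = jax.device_count()
batch, heads, head_dim = 2, 4, 128
seq = 512 * cp_size  # keep the sequence evenly divisible across the context axis
mesh = Mesh(mesh_utils.create_device_mesh((cp_size,)), ("context",))

mask = splash_attention_mask.MultiHeadMask(
    masks=[splash_attention_mask.FullMask(_shape=(seq, seq)) for _ in range(heads)]
)
kernel = splash_attention_kernel.make_splash_mha(
    mask=mask,
    head_shards=1,         # heads are replicated in this sketch
    q_seq_shards=cp_size,  # the query sequence is split across "context"
    block_sizes=splash_attention_kernel.BlockSizes.get_default(),
)
# The kernel is a pytree carrying per-shard mask state, so it needs its own
# sharding spec over (head_shards, q_seq_shards) -> (None, "context") here.
kernel_spec = kernel.manual_sharding_spec(NamedSharding(mesh, P(None, "context")))

@functools.partial(
    shard_map.shard_map,
    mesh=mesh,
    in_specs=(
        P(None, None, "context", None),  # query: sharded on its sequence axis
        P(None, None, None, None),       # key: replicated
        P(None, None, None, None),       # value: replicated
        kernel_spec,
    ),
    out_specs=P(None, None, "context", None),
    check_rep=False,
)
def sharded_attention(q, k, v, kernel):
  return jax.vmap(kernel)(q, k, v)  # vmap over the batch dimension

q = jnp.zeros((batch, heads, seq, head_dim), jnp.bfloat16)
out = sharded_attention(q, q, q, kernel)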
@@ -343,7 +378,15 @@ def _apply_attention(
   if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention:
     return _apply_attention_dot(query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention)
   elif attention_kernel == "flash":
-    return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes, dtype)
+    return _tpu_flash_attention(
+        query=query,
+        key=key * scale,
+        value=value,
+        heads=heads,
+        mesh=mesh,
+        flash_block_sizes=flash_block_sizes,
+        dtype=dtype,
+    )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else:
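With flash_axis_names gone from the signature, call sites pass keyword arguments and rely on the new axis-name defaults; the mesh must now define a "context" axis (alongside "data" and "fsdp"), since `_tpu_flash_attention` reads `mesh.shape["context"]`. A hypothetical call site, assuming `query`, `key`, `value`, `heads`, `scale`, and `mesh` already exist:

out = _tpu_flash_attention(
    query=query,             # [batch, heads, q_len, head_dim]
    key=key * scale,         # keys are pre-scaled, as in _apply_attention above
    value=value,
    heads=heads,
    mesh=mesh,               # must define "data", "fsdp", and "context" axes
    flash_block_sizes=None,  # fall back to the defaults derived from dtype
    dtype=jnp.bfloat16,
)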