@@ -270,38 +270,6 @@ def convert_to_tokamax_splash_config(
   )


-def _resolve_tpu_attention_block_sizes(
-    query_seq_len: int,
-    kv_seq_len: int,
-    flash_block_sizes: BlockSizes,
-    dtype: jnp.dtype,
-    attention_kernel: str = "flash",
-) -> BlockSizes:
-  """Resolve TPU splash attention block sizes for self- and cross-attention."""
-  q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
-  is_cross_attention = kv_seq_len != query_seq_len
-  if is_cross_attention:
-    kv_max_block_size = ((kv_seq_len + 127) // 128) * 128
-  else:
-    kv_max_block_size = q_max_block_size
-
-  if flash_block_sizes and not is_cross_attention:
-    return flash_block_sizes
-
-  block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
-  return splash_attention_kernel.BlockSizes(
-      block_q=block_size_q,
-      block_kv_compute=min(kv_max_block_size, kv_seq_len),
-      block_kv=min(kv_max_block_size, kv_seq_len),
-      block_q_dkv=block_size_q,
-      block_kv_dkv=min(kv_max_block_size, kv_seq_len),
-      block_kv_dkv_compute=min(kv_max_block_size, query_seq_len),
-      block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
-      block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query_seq_len),
-      use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
-  )
-
-
 def _tpu_flash_attention(
     query: jax.Array,
     key: jax.Array,
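For context, the deleted helper capped the query block at 1024 for `jnp.bfloat16` (512 otherwise) and, on the cross-attention path (`kv_seq_len != query_seq_len`), rounded the KV block cap up to the next multiple of 128. A minimal standalone sketch of that rounding step (the function name below is ours, for illustration only, not from the codebase):

    # Round a KV sequence length up to the next multiple of 128, as the
    # deleted cross-attention branch did: ((kv_seq_len + 127) // 128) * 128.
    def round_up_to_128(kv_seq_len: int) -> int:
      return ((kv_seq_len + 127) // 128) * 128

    assert round_up_to_128(1000) == 1024  # rounds up to alignment
    assert round_up_to_128(1024) == 1024  # already aligned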
@@ -319,18 +287,11 @@ def _tpu_flash_attention(
 ) -> jax.Array:
   """TPU Flash Attention"""

-  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)
   num_context_shards = mesh.shape["context"]
   query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
   key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
   value, _ = _reshape_data_for_flash(value, heads, num_context_shards)
-  block_sizes = _resolve_tpu_attention_block_sizes(
-      query_seq_len=query.shape[2],
-      kv_seq_len=key.shape[2],
-      flash_block_sizes=flash_block_sizes,
-      dtype=dtype,
-      attention_kernel=attention_kernel,
-  )
+  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel)

   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
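Both call sites now pass the reshaped `query` and `key` arrays rather than pre-extracted sequence lengths, so the lengths must be read inside `_select_flash_block_sizes`. Only the call signature appears in this diff; as a hedged sketch, assuming the selector simply folds in the deleted helper's logic unchanged, it could look like:

    import jax.numpy as jnp
    from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

    def _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel="flash"):
      """Sketch only; the real body is not shown in this diff."""
      # Sequence lengths come from the post-reshape [batch, heads, seq, head_dim] layout.
      query_seq_len, kv_seq_len = query.shape[2], key.shape[2]
      q_max = 1024 if dtype == jnp.bfloat16 else 512
      is_cross_attention = kv_seq_len != query_seq_len
      kv_max = ((kv_seq_len + 127) // 128) * 128 if is_cross_attention else q_max
      if flash_block_sizes and not is_cross_attention:
        return flash_block_sizes
      block_q = flash_block_sizes.block_q if flash_block_sizes else q_max
      fused = attention_kernel == "tokamax_flash"  # fused bwd kernel takes no separate dq blocks
      return splash_attention_kernel.BlockSizes(
          block_q=block_q,
          block_kv_compute=min(kv_max, kv_seq_len),
          block_kv=min(kv_max, kv_seq_len),
          block_q_dkv=block_q,
          block_kv_dkv=min(kv_max, kv_seq_len),
          block_kv_dkv_compute=min(kv_max, query_seq_len),
          block_q_dq=None if fused else block_q,
          block_kv_dq=None if fused else min(kv_max, query_seq_len),
          use_fused_bwd_kernel=fused,
      )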
@@ -530,12 +491,7 @@ def _ulysses_attention(
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

-  block_sizes = _resolve_tpu_attention_block_sizes(
-      query_seq_len=query.shape[2],
-      kv_seq_len=key.shape[2],
-      flash_block_sizes=flash_block_sizes,
-      dtype=dtype,
-  )
+  block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")

   @functools.partial(
       jax.shard_map,
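Note that the `_ulysses_attention` call site pins the kernel argument to "flash" explicitly; the deleted helper defaulted to "flash" as well, so this path keeps the unfused backward configuration. A tiny self-contained illustration of that branch, reusing the deleted helper's names:

    # With attention_kernel="flash" (the Ulysses path), the deleted helper's
    # backward-kernel branch resolved to explicit dq blocks and no fused kernel.
    attention_kernel = "flash"
    block_size_q = 1024  # bfloat16 query-block cap from the deleted helper
    use_fused_bwd_kernel = attention_kernel == "tokamax_flash"   # False
    block_q_dq = None if use_fused_bwd_kernel else block_size_q  # 1024, dq blocks kept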