@@ -233,14 +233,15 @@ def _tpu_flash_attention(
     if flash_block_sizes and key.shape[1] == query.shape[1]:
       block_sizes = flash_block_sizes
     else:
+      block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
       block_sizes = splash_attention_kernel.BlockSizes(
-          block_q=min(q_max_block_size, query.shape[2]),
+          block_q=block_size_q,
           block_kv_compute=min(kv_max_block_size, key.shape[2]),
           block_kv=min(kv_max_block_size, key.shape[2]),
-          block_q_dkv=min(q_max_block_size, query.shape[2]),
+          block_q_dkv=block_size_q,
           block_kv_dkv=min(kv_max_block_size, key.shape[2]),
           block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-          block_q_dq=None if attention_kernel == "tokamax_flash" else min(q_max_block_size, query.shape[2]),
+          block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
           block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
           use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
       )
0 commit comments