Skip to content

Commit 0dda0cf

Browse files
committed
Test the attention type and automatically modify the flash block sizes object when 'tokamax_flash' is requested
Signed-off-by: Kunjan Patel <kunjanp@google.com>
1 parent 3386550 commit 0dda0cf

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,14 +233,15 @@ def _tpu_flash_attention(
233233
if flash_block_sizes and key.shape[1] == query.shape[1]:
234234
block_sizes = flash_block_sizes
235235
else:
236+
block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
236237
block_sizes = splash_attention_kernel.BlockSizes(
237-
block_q=min(q_max_block_size, query.shape[2]),
238+
block_q=block_size_q,
238239
block_kv_compute=min(kv_max_block_size, key.shape[2]),
239240
block_kv=min(kv_max_block_size, key.shape[2]),
240-
block_q_dkv=min(q_max_block_size, query.shape[2]),
241+
block_q_dkv=block_size_q,
241242
block_kv_dkv=min(kv_max_block_size, key.shape[2]),
242243
block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
243-
block_q_dq=None if attention_kernel == "tokamax_flash" else min(q_max_block_size, query.shape[2]),
244+
block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
244245
block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
245246
use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
246247
)

0 commit comments

Comments (0)