
Commit aa442f9

force splash attention for cross attention.
1 parent e9eb4ca commit aa442f9

1 file changed

Lines changed: 2 additions & 2 deletions

src/maxdiffusion/models/attention_flax.py
@@ -380,7 +380,6 @@ def _apply_attention(
     )
   else:
     can_use_flash_attention = True
-    can_use_flash_attention=True
   if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention:
     return _apply_attention_dot(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
@@ -513,7 +512,8 @@ def __init__(
       float32_qk_product: bool = True,
       axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
       axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
-      flash_min_seq_length: int = 4096,
+      # Uses splash attention on cross attention.
+      flash_min_seq_length: int = 0,
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
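
For context, a minimal hypothetical sketch of why lowering the flash_min_seq_length default from 4096 to 0 forces the splash attention path for cross attention. Cross-attention key/value sequences typically come from the text encoder and are far shorter than 4096 tokens, so with the old default a sequence-length check would fail and _apply_attention would fall back to _apply_attention_dot; with a threshold of 0 the check always passes. The helper name and the exact gating logic below are illustrative, not the repository's actual code.

# Hypothetical sketch of a sequence-length gate for the flash/splash kernel.
# Names here are illustrative, not the exact identifiers in attention_flax.py.
def can_use_flash(q_seq_len: int, kv_seq_len: int, flash_min_seq_length: int) -> bool:
  # The splash/flash path is taken only when both sequence lengths reach the minimum;
  # otherwise the code falls back to the dot-product implementation.
  return q_seq_len >= flash_min_seq_length and kv_seq_len >= flash_min_seq_length

# Cross attention in a text-to-image model: long image queries, short text keys/values
# (e.g., 77 tokens for a CLIP text encoder).
print(can_use_flash(q_seq_len=4096, kv_seq_len=77, flash_min_seq_length=4096))  # False -> dot-product fallback
print(can_use_flash(q_seq_len=4096, kv_seq_len=77, flash_min_seq_length=0))     # True  -> splash attention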
