Skip to content

Commit 618cb2e

Browse files
committed
Add SplashAttention scheduler config and related improvements
1 parent f938dcb commit 618cb2e

3 files changed

Lines changed: 3 additions & 0 deletions

File tree

src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -909,6 +909,7 @@ use_max_logit_estimate: -1 # -1 means no estimate, any > 0 value will be used as
909909
cost_estimate_flops_fwd: -1 # -1 means using splash default cost estimation, any >= 0 value will be used as cost estimation for splash to overlap for communication (forward)
910910
cost_estimate_flops_bwd: -1 # -1 means using splash default cost estimation, any >= 0 value will be used as cost estimation for splash to overlap for communication (backward)
911911
dq_reduction_steps: 0 # the number of reduction steps. For now, only 3 or all the kv steps are supported.
912+
use_splash_scheduler: False # whether to use the experimental tokamax splash attention scheduler.
912913
### Determine if we want to use load balance for context parallelism
913914
context_parallel_load_balance: True
914915
context_parallel_strategy: "all_gather" # "all_gather" or "ring"

src/MaxText/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ class SplashAttention(BaseModel):
561561
0,
562562
description="the number of reduction steps. For now, only 3 or all the kv steps are supported.",
563563
)
564+
use_splash_scheduler: bool = Field(False, description="Use experimental splash attention scheduler.")
564565

565566

566567
class PagedAttention(BaseModel):

src/MaxText/layers/attention_op.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,6 +1123,7 @@ def create_sa_config(config, query, key, attn_logits_soft_cap):
11231123
if config.cost_estimate_flops_bwd >= 0
11241124
else None,
11251125
dq_reduction_steps=config.dq_reduction_steps if config.dq_reduction_steps > 0 else None,
1126+
use_experimental_scheduler=config.use_splash_scheduler,
11261127
)
11271128
else:
11281129
sa_config = splash_attention_kernel.BlockSizes(

0 commit comments

Comments
 (0)