Commit 436e7d1

Test on attention type and automatically modify flash block sizes object when 'tokamax_flash' requested
Signed-off-by: Kunjan Patel <kunjanp@google.com>
1 parent a79a49c commit 436e7d1

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/max_utils.py

@@ -501,7 +501,7 @@ def get_flash_block_sizes(config):
   """Create custom flash attention BlockSizes."""
   flash_block_sizes = None
   if len(config.flash_block_sizes.keys()) > 0:
-    attention_is_tokamax = "tokamax" in config.attention_kernel
+    attention_is_tokamax = "tokamax" in config.attention
     user_block_sizes:Dict[str, int] = config.flash_block_sizes
     if attention_is_tokamax:
       max_logging.log("Tokamax kernel specified, Note: Tokamax only supports fused backward kernel."
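The fixed line is a substring test on the configured attention type: any attention value containing "tokamax" (e.g. "tokamax_flash") triggers the Tokamax-specific block-size handling. A minimal sketch of that check, assuming a simplified stand-in `Config` object (the real maxdiffusion config class is not reproduced here, and the default values are hypothetical):

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class Config:
    # Stand-in for the maxdiffusion config; field names mirror the diff,
    # default values are illustrative only.
    attention: str = "dot_product"
    flash_block_sizes: Dict[str, int] = field(default_factory=dict)


def attention_is_tokamax(config: Config) -> bool:
    # Mirrors the corrected line in get_flash_block_sizes: a substring
    # test, so "tokamax_flash" matches but plain "flash" does not.
    return "tokamax" in config.attention


cfg = Config(attention="tokamax_flash", flash_block_sizes={"block_q": 128})
print(attention_is_tokamax(cfg))  # → True
print(attention_is_tokamax(Config(attention="flash")))  # → False
```

The pre-fix code read the non-existent `config.attention_kernel` attribute, so the Tokamax branch could never be selected from the actual `attention` setting.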

0 commit comments