Commit 0abc904

Tokamax splash attn
Signed-off-by: Kunjan Patel <kunjanp@google.com>
1 parent 54ae6e7 commit 0abc904

2 files changed: 62 additions & 14 deletions

src/maxdiffusion/max_utils.py

Lines changed: 4 additions & 2 deletions
@@ -494,15 +494,17 @@ def get_flash_block_sizes(config):
   """Create custom flash attention BlockSizes."""
   flash_block_sizes = None
   if len(config.flash_block_sizes.keys()) > 0:
+    use_fused_bwd_kernel = config.flash_block_sizes.get("use_fused_bwd_kernel", False)
     flash_block_sizes = splash_attention_kernel.BlockSizes(
         block_q=int(config.flash_block_sizes["block_q"]),
         block_kv_compute=int(config.flash_block_sizes["block_kv_compute"]),
         block_kv=int(config.flash_block_sizes["block_kv"]),
         block_q_dkv=config.flash_block_sizes.get("block_q_dkv"),
         block_kv_dkv=config.flash_block_sizes.get("block_kv_dkv"),
         block_kv_dkv_compute=config.flash_block_sizes.get("block_kv_dkv_compute"),
-        block_q_dq=config.flash_block_sizes.get("block_q_dq"),
-        block_kv_dq=config.flash_block_sizes.get("block_kv_dq"),
+        block_q_dq=config.flash_block_sizes.get("block_q_dq") if not use_fused_bwd_kernel else None,
+        block_kv_dq=config.flash_block_sizes.get("block_kv_dq") if not use_fused_bwd_kernel else None,
+        use_fused_bwd_kernel=use_fused_bwd_kernel,
     )
   return flash_block_sizes
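
For reference, a minimal sketch of a config that opts into the new key (field names are those read by the diff above; the block values are hypothetical). With use_fused_bwd_kernel set, the dq-specific blocks come out as None regardless of what the config supplies:

from types import SimpleNamespace
from maxdiffusion.max_utils import get_flash_block_sizes

config = SimpleNamespace(
    flash_block_sizes={
        "block_q": 512,
        "block_kv_compute": 512,
        "block_kv": 512,
        "block_q_dkv": 512,
        "block_kv_dkv": 512,
        "block_kv_dkv_compute": 512,
        "use_fused_bwd_kernel": True,  # new key introduced by this commit
    }
)
block_sizes = get_flash_block_sizes(config)  # block_q_dq and block_kv_dq end up None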

src/maxdiffusion/models/attention_flax.py

Lines changed: 58 additions & 12 deletions
@@ -24,6 +24,8 @@
 from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
+from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
+from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
 from einops import rearrange
 from .. import common_types, max_logging

@@ -169,6 +171,40 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
 
   return tensor, kv_size, seq_len
 
+def convert_to_tokamax_splash_config(block_sizes: BlockSizes,
+    q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
+    k_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
+    v_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
+    residual_checkpoint_name: str | None = None,
+    attn_logits_soft_cap: float | None = None,
+    fuse_reciprocal: bool = True,
+    use_base2_exp: bool = False,
+    max_logit_const: float | None = None,
+    interpret: bool = False,
+    dq_reduction_steps: int | None = None) -> tokamax_splash_attention_kernel.SplashConfig:
+  assert block_sizes.use_fused_bwd_kernel, "Tokamax Splash attention only supports fused bwd kernel."
+  return tokamax_splash_attention_kernel.SplashConfig(
+      block_q=block_sizes.block_q,
+      block_kv=block_sizes.block_kv,
+      block_kv_compute=block_sizes.block_kv_compute,
+      block_q_dkv=block_sizes.block_q_dkv,
+      block_kv_dkv=block_sizes.block_kv_dkv,
+      block_kv_dkv_compute=block_sizes.block_kv_dkv_compute,
+      block_q_dq=None if block_sizes.use_fused_bwd_kernel else block_sizes.block_q_dq,
+      block_kv_dq=None if block_sizes.use_fused_bwd_kernel else block_sizes.block_kv_dq,
+      use_fused_bwd_kernel=block_sizes.use_fused_bwd_kernel,
+      q_layout=q_layout,
+      k_layout=k_layout,
+      v_layout=v_layout,
+      residual_checkpoint_name=residual_checkpoint_name,
+      attn_logits_soft_cap=attn_logits_soft_cap,
+      fuse_reciprocal=fuse_reciprocal,
+      use_base2_exp=use_base2_exp,
+      max_logit_const=max_logit_const,
+      interpret=interpret,
+      dq_reduction_steps=dq_reduction_steps,
+  )
+
 
 def _tpu_flash_attention(
     query: jax.Array,
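
A hypothetical usage sketch of the new helper: take the BlockSizes built by get_flash_block_sizes above (the config must therefore set use_fused_bwd_kernel, or the assert fires) and turn it into a Tokamax SplashConfig, leaving the layouts and tuning knobs at their defaults:

splash_config = convert_to_tokamax_splash_config(block_sizes)

# equivalent, with two of the optional knobs spelled out at their defaults:
splash_config = convert_to_tokamax_splash_config(
    block_sizes,
    attn_logits_soft_cap=None,  # no logit capping
    use_base2_exp=False,        # standard exp in the softmax
)
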
@@ -203,8 +239,9 @@ def _tpu_flash_attention(
       block_q_dkv=min(q_max_block_size, query.shape[2]),
       block_kv_dkv=min(kv_max_block_size, key.shape[2]),
       block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-      block_q_dq=min(q_max_block_size, query.shape[2]),
-      block_kv_dq=min(kv_max_block_size, query.shape[2]),
+      block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq,
+      block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
+      use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
   )
   num_fsdp_shards = mesh.shape["fsdp"]
   query = _reshape_data_for_flash(query, heads)
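
The invariant the new defaults encode, as a standalone sketch (hypothetical sizes; BlockSizes is the Pallas dataclass already imported in this file): with the fused backward kernel, dq is computed together with dkv in one pass, so the standalone dq blocks are unused and stay None, which is exactly what the "tokamax_flash" branch above produces:

block_sizes = splash_attention_kernel.BlockSizes(
    block_q=512,
    block_kv=512,
    block_kv_compute=512,
    block_q_dkv=512,
    block_kv_dkv=512,
    block_kv_dkv_compute=512,
    block_q_dq=None,            # unused when the bwd kernel is fused
    block_kv_dq=None,           # unused when the bwd kernel is fused
    use_fused_bwd_kernel=True,  # what the "tokamax_flash" branch sets
)
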
@@ -240,18 +277,27 @@ def wrap_flash_attention(query, key, value):
 
     # make_splash_mha is wrapped around shardmap and seq and head is already
     # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=1,  # the sizes of the axis is sharding over heads
-        q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
-        block_sizes=block_sizes,
-        save_residuals=True if attention_kernel == "ring" else False,
-    )
+    if attention_kernel == "tokamax_flash":
+      mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
+      splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
+          mask=mask,
+          q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
+          config=convert_to_tokamax_splash_config(block_sizes),
+          save_residuals=True if attention_kernel == "ring" else False,
+      )
+    else:
+      splash_kernel = splash_attention_kernel.make_splash_mha(
+          mask=multi_head_mask,
+          head_shards=1,  # the sizes of the axis is sharding over heads
+          q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
+          block_sizes=block_sizes,
+          save_residuals=True if attention_kernel == "ring" else False,
+      )
     vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
 
     if not mask_padding_tokens:
       segment_ids = None
-    if attention_kernel == "flash":
+    if attention_kernel in ["flash", "tokamax_flash"]:
       attention_output = vmapped_splash(query, key, value, segment_ids)
     else:
       if num_fsdp_shards > 1:
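
Condensed into a standalone sketch, the selection logic above looks like the following (make_kernel is a hypothetical helper, not in the diff; it assumes a TPU plus the tokamax package and mirrors only the calls shown in this hunk):

def make_kernel(attention_kernel, block_sizes, q_seq_len, kv_seq_len, num_heads):
  if attention_kernel == "tokamax_flash":
    # Tokamax path: a single FullMask plus a SplashConfig; head sharding is
    # handled outside via shard_map, hence no head_shards argument here.
    mask = tokamax_splash_attention_mask.FullMask(_shape=(q_seq_len, kv_seq_len))
    return tokamax_splash_attention_kernel.make_splash_mha(
        mask=mask,
        q_seq_shards=1,
        config=convert_to_tokamax_splash_config(block_sizes),
        save_residuals=False,
    )
  # Pallas path: per-head masks bundled into a MultiHeadMask, BlockSizes config.
  multi_head_mask = splash_attention_mask.MultiHeadMask(
      masks=[splash_attention_mask.FullMask(_shape=(q_seq_len, kv_seq_len))] * num_heads
  )
  return splash_attention_kernel.make_splash_mha(
      mask=multi_head_mask,
      head_shards=1,
      q_seq_shards=1,
      block_sizes=block_sizes,
      save_residuals=False,
  )

Either way the resulting kernel has the same (q, k, v, segment_ids) call signature, which is why the shared jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) over the batch axis is unchanged.
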
@@ -439,7 +485,7 @@ def _apply_attention(
   seq_len_idx = 1
   if query.ndim == 4:
     seq_len_idx = 2
-  if attention_kernel == "flash":
+  if attention_kernel in ["flash", "tokamax_flash"]:
     can_use_flash_attention = (
         query.shape[seq_len_idx] >= flash_min_seq_length
         and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -451,7 +497,7 @@ def _apply_attention(
     return _apply_attention_dot(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
-  elif attention_kernel == "flash":
+  elif attention_kernel in ["flash", "tokamax_flash"]:
     return _tpu_flash_attention(
         query,
         key * scale,
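
One detail worth noting in the call above: the softmax scale is folded into the key (key * scale) rather than applied to the attention logits. The two are equivalent, because scaling K scales the QK^T logits linearly. A quick check (hypothetical small shapes):

import jax.numpy as jnp

q = jnp.ones((4, 8))
k = jnp.ones((6, 8))
scale = 0.125
assert jnp.allclose(q @ (k * scale).T, scale * (q @ k.T))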
