Commit 4af41f6

add split self/cross attention sharding rules and configuration
1 parent 015ccc2 commit 4af41f6

1 file changed: src/maxdiffusion/models/attention_flax.py
Lines changed: 97 additions & 18 deletions
@@ -25,6 +25,8 @@
 from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
+from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
+from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
 from einops import rearrange
 from .. import common_types, max_logging

@@ -46,6 +48,13 @@
 EMBED = common_types.EMBED
 Quant = quantizations.AqtQuantization

+SELF_ATTN_HEAD = common_types.SELF_ATTN_HEAD
+SELF_ATTN_Q_LENGTH = common_types.SELF_ATTN_Q_LENGTH
+SELF_ATTN_KV_LENGTH = common_types.SELF_ATTN_KV_LENGTH
+CROSS_ATTN_HEAD = common_types.CROSS_ATTN_HEAD
+CROSS_ATTN_Q_LENGTH = common_types.CROSS_ATTN_Q_LENGTH
+CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH
+

 def _maybe_aqt_einsum(quant: Quant):
   return jnp.einsum if quant is None else quant.einsum()
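These constants give self- and cross-attention their own logical axis names, so the two can be sharded differently (a cross-attention KV sequence, e.g. text embeddings, is typically much shorter than the image-token query sequence). Below is a minimal sketch of how such logical names could be mapped onto mesh axes; the rule values and string names are hypothetical, not the mappings maxdiffusion ships:

import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec

# Hypothetical logical-to-physical rules: shard self attention along its long
# sequence axes, and cross attention over heads, since its KV side is short.
LOGICAL_AXIS_RULES = {
    "self_attn_q_length": "fsdp",
    "self_attn_kv_length": "fsdp",
    "self_attn_head": None,
    "cross_attn_q_length": "fsdp",
    "cross_attn_kv_length": None,
    "cross_attn_head": "tensor",
}

def spec_for(logical_names):
  # Resolve a tuple of logical axis names to a PartitionSpec.
  return PartitionSpec(*(LOGICAL_AXIS_RULES.get(name) for name in logical_names))

mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), axis_names=("fsdp", "tensor"))
q_spec = spec_for(("batch", "self_attn_head", "self_attn_q_length", "head_dim"))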
@@ -163,6 +172,40 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):

   return tensor, kv_size, seq_len

+def convert_to_tokamax_splash_config(block_sizes: BlockSizes,
+                                     q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
+                                     k_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
+                                     v_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
+                                     residual_checkpoint_name: str | None = None,
+                                     attn_logits_soft_cap: float | None = None,
+                                     fuse_reciprocal: bool = True,
+                                     use_base2_exp: bool = False,
+                                     max_logit_const: float | None = None,
+                                     interpret: bool = False,
+                                     dq_reduction_steps: int | None = None) -> tokamax_splash_attention_kernel.SplashConfig:
+  assert block_sizes.use_fused_bwd_kernel, "Tokamax Splash attention only supports the fused bwd kernel."
+  return tokamax_splash_attention_kernel.SplashConfig(
+      block_q=block_sizes.block_q,
+      block_kv=block_sizes.block_kv,
+      block_kv_compute=block_sizes.block_kv_compute,
+      block_q_dkv=block_sizes.block_q_dkv,
+      block_kv_dkv=block_sizes.block_kv_dkv,
+      block_kv_dkv_compute=block_sizes.block_kv_dkv_compute,
+      block_q_dq=None if block_sizes.use_fused_bwd_kernel else block_sizes.block_q_dq,
+      block_kv_dq=None if block_sizes.use_fused_bwd_kernel else block_sizes.block_kv_dq,
+      use_fused_bwd_kernel=block_sizes.use_fused_bwd_kernel,
+      q_layout=q_layout,
+      k_layout=k_layout,
+      v_layout=v_layout,
+      residual_checkpoint_name=residual_checkpoint_name,
+      attn_logits_soft_cap=attn_logits_soft_cap,
+      fuse_reciprocal=fuse_reciprocal,
+      use_base2_exp=use_base2_exp,
+      max_logit_const=max_logit_const,
+      interpret=interpret,
+      dq_reduction_steps=dq_reduction_steps,
+  )
+

 def _tpu_flash_attention(
     query: jax.Array,
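For reference, a usage sketch of the converter; the block values are illustrative and the checkpoint name is hypothetical, but the assert above does require a fused-backward BlockSizes:

block_sizes = splash_attention_kernel.BlockSizes(
    block_q=512,
    block_kv=512,
    block_kv_compute=512,
    block_q_dkv=512,
    block_kv_dkv=512,
    block_kv_dkv_compute=512,
    block_q_dq=None,  # unused when the fused bwd kernel is on
    block_kv_dq=None,
    use_fused_bwd_kernel=True,
)
# "attn_residual" is an illustrative checkpoint name, not one from this repo.
config = convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name="attn_residual")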
@@ -175,6 +218,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
+    mask_padding_tokens: bool = True,
     residual_checkpoint_name: str | None = None,
 ) -> jax.Array:
   """TPU Flash Attention"""
@@ -186,17 +230,19 @@ def _tpu_flash_attention(
     kv_max_block_size = key.shape[1]
   else:
     kv_max_block_size = q_max_block_size
-  if flash_block_sizes:
+  # Use the configured block sizes only for self attention; cross attention overrides them below.
+  if flash_block_sizes and key.shape[1] == query.shape[1]:
     block_sizes = flash_block_sizes
   else:
+    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=min(q_max_block_size, query.shape[2]),
+        block_q=block_size_q,
         block_kv_compute=min(kv_max_block_size, key.shape[2]),
         block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=min(q_max_block_size, query.shape[2]),
+        block_q_dkv=block_size_q,
         block_kv_dkv=min(kv_max_block_size, key.shape[2]),
         block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq,
+        block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q,
         block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
         use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
     )
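The new condition, as a standalone predicate (a sketch; the dimension indexing follows the surrounding code): configured flash_block_sizes are tuned for self attention, so they are honored only when the query and key lengths match, and cross attention derives its KV blocks from the shorter encoder sequence instead.

def uses_configured_block_sizes(flash_block_sizes, query, key) -> bool:
  # A KV length differing from the query length signals cross attention,
  # which must derive its own block sizes rather than use the tuned ones.
  return bool(flash_block_sizes) and key.shape[1] == query.shape[1]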
@@ -253,17 +299,28 @@ def wrap_flash_attention(query, key, value):

     # make_splash_mha is wrapped in shard_map, and seq and heads are already
     # sharded based on in_specs, so set head_shards=1 and q_seq_shards=1.
-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=1,  # the size of the axis sharded over heads
-        q_seq_shards=1,  # the size of the axis sharded over seq_len
-        block_sizes=block_sizes,
-        save_residuals=True if attention_kernel == "ring" else False,
-        residual_checkpoint_name=residual_checkpoint_name,
-    )
+    if attention_kernel == "tokamax_flash":
+      mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+      splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
+          mask=mask,
+          q_seq_shards=1,  # the size of the axis sharded over seq_len
+          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+          save_residuals=True if attention_kernel == "ring" else False,
+      )
+    else:
+      splash_kernel = splash_attention_kernel.make_splash_mha(
+          mask=multi_head_mask,
+          head_shards=1,  # the size of the axis sharded over heads
+          q_seq_shards=1,  # the size of the axis sharded over seq_len
+          block_sizes=block_sizes,
+          save_residuals=True if attention_kernel == "ring" else False,
+          residual_checkpoint_name=residual_checkpoint_name,
+      )
     vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))

-    if attention_kernel == "flash":
+    if not mask_padding_tokens:
+      segment_ids = None
+    if attention_kernel in ["flash", "tokamax_flash"]:
       attention_output = vmapped_splash(query, key, value, segment_ids)
     else:
       if num_fsdp_shards > 1:
@@ -302,6 +359,8 @@ def ring_scan_body(carry, _):
        (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)

        attention_output = o_final / l_final[..., None]
+      else:
+        raise ValueError("ring attention requires fsdp > 1")

   return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

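The ring scan carries a running max m, normalizer l, and unnormalized output o, so the final o_final / l_final[..., None] applies the softmax denominator exactly once. A minimal sketch of the same online-softmax bookkeeping for merging one block into the carry (standard flash-attention math, not code from this file):

import jax.numpy as jnp

def combine_block(m, l, o, m_blk, l_blk, o_blk):
  # Merge one block's (running max, denominator, unnormalized output).
  m_new = jnp.maximum(m, m_blk)
  alpha = jnp.exp(m - m_new)      # rescales the old accumulators
  beta = jnp.exp(m_blk - m_new)   # rescales the incoming block
  l_new = alpha * l + beta * l_blk
  o_new = alpha[..., None] * o + beta[..., None] * o_blk
  return m_new, l_new, o_new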
@@ -442,14 +501,15 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
+    mask_padding_tokens: bool = True,
     residual_checkpoint_name: str | None = None,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
   seq_len_idx = 1
   if query.ndim == 4:
     seq_len_idx = 2
-  if attention_kernel == "flash":
+  if attention_kernel in ["flash", "tokamax_flash"]:
     can_use_flash_attention = (
         query.shape[seq_len_idx] >= flash_min_seq_length
         and key.shape[seq_len_idx] >= flash_min_seq_length
@@ -461,7 +521,7 @@
     return _apply_attention_dot(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
     )
-  elif attention_kernel == "flash":
+  elif attention_kernel in ["flash", "tokamax_flash"]:
     return _tpu_flash_attention(
         query,
         key * scale,
@@ -472,11 +532,14 @@
         axis_names_kv,
         flash_block_sizes,
         dtype,
+        attention_kernel,
+        mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name=residual_checkpoint_name,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
+        mask_padding_tokens=mask_padding_tokens,
     )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
@@ -607,6 +670,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
+      mask_padding_tokens: bool = True,
       residual_checkpoint_name: str | None = None,
   ):
     self.dpa_layer = None
@@ -627,6 +691,7 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
+    self.mask_padding_tokens = mask_padding_tokens
     self.residual_checkpoint_name = residual_checkpoint_name

   def apply_attention(self, query: Array, key: Array, value: Array):
@@ -648,6 +713,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
+        mask_padding_tokens=self.mask_padding_tokens,
         residual_checkpoint_name=self.residual_checkpoint_name,
     )

@@ -737,6 +803,8 @@ def __init__(
      precision: jax.lax.Precision = None,
      qkv_bias: bool = False,
      quant: Quant = None,
+     is_self_attention: bool = True,
+     mask_padding_tokens: bool = True,
      residual_checkpoint_name: str | None = None,
      enable_jax_named_scopes: bool = False,
   ):
@@ -750,11 +818,19 @@ def __init__(
     self.inner_dim = dim_head * heads
     scale = dim_head**-0.5
     self.qk_norm = qk_norm
-    self.enable_jax_named_scopes = enable_jax_named_scopes
+
     self.query_axis_names = query_axis_names
     self.key_axis_names = key_axis_names
     self.value_axis_names = value_axis_names
     self.out_axis_names = out_axis_names
+    self.enable_jax_named_scopes = enable_jax_named_scopes
+
+    if is_self_attention:
+      axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV)
+      axis_names_kv = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH, D_KV)
+    else:
+      axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV)
+      axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV)

     self.attention_op = NNXAttentionOp(
         mesh=mesh,
@@ -765,10 +841,13 @@ def __init__(
         use_memory_efficient_attention=use_memory_efficient_attention,
         split_head_dim=split_head_dim,
         float32_qk_product=False,
+        axis_names_q=axis_names_q,
+        axis_names_kv=axis_names_kv,
         flash_min_seq_length=flash_min_seq_length,
         flash_block_sizes=flash_block_sizes,
         dtype=dtype,
         quant=quant,
+        mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name=residual_checkpoint_name,
     )
     # None axes correspond to the stacked weights across all blocks
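Downstream, the chosen per-attention-type axis names are typically attached to the activations as logical sharding constraints; a sketch using the flax.linen API (the exact call sites in maxdiffusion may differ):

import flax.linen as nn

def shard_attention_inputs(query, key, value, axis_names_q, axis_names_kv):
  # Attach logical axis names; the mesh's logical axis rules pick the layout.
  query = nn.with_logical_constraint(query, axis_names_q)
  key = nn.with_logical_constraint(key, axis_names_kv)
  value = nn.with_logical_constraint(value, axis_names_kv)
  return query, key, value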
@@ -1579,4 +1658,4 @@ def setup(self):
   def __call__(self, hidden_states, deterministic=True):
     hidden_states = self.proj(hidden_states)
     hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
\ No newline at end of file
+    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
