
Commit ed47e5f

moving kernel init outside the sharding map
1 parent 70ce989 commit ed47e5f

1 file changed: src/maxdiffusion/models/attention_flax.py (42 additions & 13 deletions)
@@ -28,6 +28,7 @@
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
 from tokamax._src.ops.experimental.tpu.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
+from tokamax._src.ops.experimental.tpu.splash_attention import base as tokamax_base
 from einops import rearrange
 from .. import common_types, max_logging
 
@@ -279,17 +280,49 @@ def _tpu_flash_attention(
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
+
+  # Pre-padding and Ring Kernel creation outside shard_map
+  if attention_kernel == "tokamax_ring":
+    block_q = max(block_sizes.block_q, block_sizes.block_q_dkv)
+    block_kv = max(block_sizes.block_kv, block_sizes.block_kv_dkv)
+
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q, num_shards=num_fsdp_shards)
+    key, _, _ = _pad_data_for_flash(key, heads, block_kv, num_shards=num_fsdp_shards)
+    value, _, _ = _pad_data_for_flash(value, heads, block_kv, num_shards=num_fsdp_shards)
+
+    mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    ring_kernel = tokamax_ring_attention_kernel.make_ring_attention(
+        mask=mask,
+        is_mqa=False,
+        config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+        save_residuals=True,
+        ring_axis="fsdp",
+        q_seq_shards=num_fsdp_shards,
+        kv_seq_shards=num_fsdp_shards,
+    )
+    kernel_spec = ring_kernel.manual_sharding_spec()
+  else:
+    # Logic for other kernels remains unchanged regarding local padding
+    ring_kernel = None
+    kernel_spec = None
+
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
   @functools.partial(
       shard_map.shard_map,
       mesh=mesh,
-      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names, kernel_spec),
       out_specs=q_axis_names,
       check_rep=False,
   )
-  def wrap_flash_attention(query, key, value):
+  def wrap_flash_attention(query, key, value, ring_kernel):
+
+    if attention_kernel == "tokamax_ring":
+      # For bidirectional attention, segment_ids can be None to hit the performance shortcut
+      segment_ids = None
+      vmapped_splash = jax.vmap(ring_kernel, in_axes=(0, 0, 0, None))
+      return vmapped_splash(query, key, value, segment_ids)
 
     uses_fused_kernel = block_sizes.use_fused_bwd_kernel
     block_q_sizes = (
@@ -324,6 +357,7 @@ def wrap_flash_attention(query, key, value):
     kv_padded_len = key.shape[2]
     kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
     kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+
     segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
 
     # make_splash_mha is wrapped around shardmap and seq and head is already
@@ -336,15 +370,6 @@ def wrap_flash_attention(query, key, value):
           config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
           save_residuals=True if "ring" in attention_kernel else False,
       )
-    elif attention_kernel == "tokamax_ring":
-      mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
-      splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
-          mask=mask,
-          is_mqa=False,
-          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-          save_residuals=True,
-          ring_axis="fsdp",
-      )
     else:
       splash_kernel = splash_attention_kernel.make_splash_mha(
           mask=multi_head_mask,
@@ -360,7 +385,7 @@ def wrap_flash_attention(query, key, value):
 
     if not mask_padding_tokens:
       segment_ids = None
-    if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]:
+    if attention_kernel in ["flash", "tokamax_flash"]:
       attention_output = vmapped_splash(query, key, value, segment_ids)
     else:
       if num_fsdp_shards > 1:
@@ -412,7 +437,11 @@ def ring_scan_body(carry, _):
       "Warning, batch dimension should be shardable among the devices in data and fsdp"
       f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
   )
-  x = wrap_flash_attention(query, key, value)
+  x = wrap_flash_attention(query, key, value, ring_kernel)
+
+  if attention_kernel == "tokamax_ring":
+    x = x[:, :, :query_seq_len, :kv_size].astype(query.dtype)
+
   x = _reshape_heads_to_head_dim(x)
 
   return x
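
Note on the pre-padding: because the ring kernel's block sizes (and the number of fsdp shards) constrain the sequence lengths it accepts, the inputs are now padded before shard_map and the output is sliced back to the original length afterwards (the x[:, :, :query_seq_len, ...] line). A rough sketch of that pad-then-slice pattern; the helper below is illustrative only and assumes _pad_data_for_flash pads the sequence axis up to a multiple of block_size * num_shards, which is not necessarily the repo's exact implementation:

import jax.numpy as jnp

def pad_seq_to_multiple(x, multiple):
  # Pads the sequence axis (axis 2) of a [batch, heads, seq, head_dim] array
  # up to the next multiple of `multiple`, returning the padded array and the
  # original sequence length so the output can be sliced back later.
  seq_len = x.shape[2]
  padded_len = -(-seq_len // multiple) * multiple  # ceil to the next multiple
  x = jnp.pad(x, ((0, 0), (0, 0), (0, padded_len - seq_len), (0, 0)))
  return x, seq_len

block_q, num_fsdp_shards = 512, 4
q = jnp.ones((2, 8, 1000, 128))
q_padded, q_seq_len = pad_seq_to_multiple(q, block_q * num_fsdp_shards)  # seq 1000 -> 2048
out = q_padded                  # stand-in for the attention output
out = out[:, :, :q_seq_len, :]  # slice the padding away, as done after wrap_flash_attention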
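
Note on the in-shard call: the ring path inside wrap_flash_attention now invokes the pre-built kernel through jax.vmap with in_axes=(0, 0, 0, None), so the kernel sees one example of shape [heads, seq, head_dim] per call while segment_ids (None for the bidirectional shortcut) is broadcast across the batch. A minimal sketch of that calling convention, using a stand-in dot-product attention in place of the tokamax ring kernel:

import jax
import jax.numpy as jnp

def single_example_attention(q, k, v, segment_ids):
  # Stand-in for the ring kernel: plain bidirectional attention over one
  # example of shape [heads, seq, head_dim]; segment_ids=None means no
  # padding mask is applied.
  del segment_ids
  logits = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(q.shape[-1])
  return jnp.einsum("hqk,hkd->hqd", jax.nn.softmax(logits, axis=-1), v)

vmapped = jax.vmap(single_example_attention, in_axes=(0, 0, 0, None))
q = k = v = jnp.ones((2, 8, 256, 64))  # [batch, heads, seq, head_dim]
out = vmapped(q, k, v, None)           # -> [2, 8, 256, 64]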
