Commit 65e7f93

Revert "moving kernel init outside the sharding map"
This reverts commit ed47e5f.
1 parent ed47e5f commit 65e7f93

1 file changed

Lines changed: 13 additions & 42 deletions

File tree

src/maxdiffusion/models/attention_flax.py

@@ -28,7 +28,6 @@
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_attention_mask
 from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_attention_kernel
 from tokamax._src.ops.experimental.tpu.splash_attention import ring_attention_kernel as tokamax_ring_attention_kernel
-from tokamax._src.ops.experimental.tpu.splash_attention import base as tokamax_base
 from einops import rearrange
 from .. import common_types, max_logging
 
@@ -280,49 +279,17 @@ def _tpu_flash_attention(
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
-
-  # Pre-padding and Ring Kernel creation outside shard_map
-  if attention_kernel == "tokamax_ring":
-    block_q = max(block_sizes.block_q, block_sizes.block_q_dkv)
-    block_kv = max(block_sizes.block_kv, block_sizes.block_kv_dkv)
-
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q, num_shards=num_fsdp_shards)
-    key, _, _ = _pad_data_for_flash(key, heads, block_kv, num_shards=num_fsdp_shards)
-    value, _, _ = _pad_data_for_flash(value, heads, block_kv, num_shards=num_fsdp_shards)
-
-    mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    ring_kernel = tokamax_ring_attention_kernel.make_ring_attention(
-        mask=mask,
-        is_mqa=False,
-        config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-        save_residuals=True,
-        ring_axis="fsdp",
-        q_seq_shards=num_fsdp_shards,
-        kv_seq_shards=num_fsdp_shards,
-    )
-    kernel_spec = ring_kernel.manual_sharding_spec()
-  else:
-    # Logic for other kernels remains unchanged regarding local padding
-    ring_kernel = None
-    kernel_spec = None
-
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
   @functools.partial(
       shard_map.shard_map,
       mesh=mesh,
-      in_specs=(q_axis_names, kv_axis_names, kv_axis_names, kernel_spec),
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
       out_specs=q_axis_names,
       check_rep=False,
   )
-  def wrap_flash_attention(query, key, value, ring_kernel):
-
-    if attention_kernel == "tokamax_ring":
-      # For bidirectional attention, segment_ids can be None to hit the performance shortcut
-      segment_ids = None
-      vmapped_splash = jax.vmap(ring_kernel, in_axes=(0, 0, 0, None))
-      return vmapped_splash(query, key, value, segment_ids)
+  def wrap_flash_attention(query, key, value):
 
     uses_fused_kernel = block_sizes.use_fused_bwd_kernel
     block_q_sizes = (
@@ -357,7 +324,6 @@ def wrap_flash_attention(query, key, value, ring_kernel):
     kv_padded_len = key.shape[2]
     kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
     kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
-
     segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
 
     # make_splash_mha is wrapped around shardmap and seq and head is already
@@ -370,6 +336,15 @@ def wrap_flash_attention(query, key, value, ring_kernel):
           config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
           save_residuals=True if "ring" in attention_kernel else False,
       )
+    elif attention_kernel == "tokamax_ring":
+      mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
+      splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
+          mask=mask,
+          is_mqa=False,
+          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+          save_residuals=True,
+          ring_axis="fsdp",
+      )
     else:
       splash_kernel = splash_attention_kernel.make_splash_mha(
           mask=multi_head_mask,
@@ -385,7 +360,7 @@ def wrap_flash_attention(query, key, value, ring_kernel):
 
     if not mask_padding_tokens:
       segment_ids = None
-    if attention_kernel in ["flash", "tokamax_flash"]:
+    if attention_kernel in ["flash", "tokamax_flash", "tokamax_ring"]:
       attention_output = vmapped_splash(query, key, value, segment_ids)
     else:
       if num_fsdp_shards > 1:
@@ -437,11 +412,7 @@ def ring_scan_body(carry, _):
       "Warning, batch dimension should be shardable among the devices in data and fsdp"
       f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
   )
-  x = wrap_flash_attention(query, key, value, ring_kernel)
-
-  if attention_kernel == "tokamax_ring":
-    x = x[:, :, :query_seq_len, :kv_size].astype(query.dtype)
-
+  x = wrap_flash_attention(query, key, value)
   x = _reshape_heads_to_head_dim(x)
 
   return x
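
For context, here is a minimal, self-contained sketch of the pattern this revert restores: the state needed by the attention call is built inside the shard_map-wrapped body, so nothing extra has to be threaded through in_specs (the reverted commit had constructed the ring kernel outside and passed it in with a manual sharding spec). The sketch uses plain JAX only, not tokamax; the mesh axis name "fsdp", the toy shapes, and the naive attention body are illustrative assumptions, not the kernel used in attention_flax.py.

import functools

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("fsdp",))

# (batch, seq, head_dim) with the sequence dimension sharded over "fsdp".
q_spec = P(None, "fsdp", None)


@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=(q_spec, q_spec, q_spec),
    out_specs=q_spec,
    check_rep=False,
)
def wrap_attention(query, key, value):
  # Everything the local computation needs is created here, inside the
  # manual-sharding region, so only the arrays appear in in_specs
  # (the reverted commit had moved kernel construction outside this body).
  scale = 1.0 / jnp.sqrt(query.shape[-1])
  scores = jnp.einsum("bqd,bkd->bqk", query, key) * scale
  return jnp.einsum("bqk,bkd->bqd", jax.nn.softmax(scores, axis=-1), value)


batch, seq, head_dim = 2, 8 * jax.device_count(), 16
q = jnp.ones((batch, seq, head_dim))
out = wrap_attention(q, q, q)  # toy attention over each local sequence shard
print(out.shape)  # (2, seq, 16)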
