
Commit b7f2f78

Fix

1 parent b3c7212 commit b7f2f78

1 file changed: 38 additions & 36 deletions


src/maxdiffusion/models/attention_flax.py

@@ -281,45 +281,47 @@ def wrap_flash_attention(query, key, value):
     block_q = max(*block_q_sizes)
     block_kv = max(*block_kv_sizes)

+    # FIX: Always pad data. The kernel requires seq_len % block_size == 0.
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+
     if attention_kernel == "tokamax_flash":
-      # OPTIMIZATION: Skip padding and segment_ids for the optimized kernel
-      kv_size = key.shape[-1]
-      query_seq_len = query.shape[2]
-      segment_ids = None
+      # OPTIMIZATION: We pad the data (required), but we skip
+      # calculating 'segment_ids' (overhead), relying on the kernel's
+      # internal masking for the padded regions.
+      segment_ids = None
+
+      mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
+      splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
+          mask=mask,
+          q_seq_shards=1,
+          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+          save_residuals=False,
+      )

-      mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
-      splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
-          mask=mask,
-          q_seq_shards=1,
-          config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-          save_residuals=False,  # Ring attention not typically used in this path
-      )
     else:
-      # STANDARD PATH: Explicit padding (Slower)
-      query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
-      key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
-      value, _, _ = _pad_data_for_flash(value, heads, block_kv)
-
-      mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-      multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-
-      q_padded_len = query.shape[2]
-      q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
-      q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
-
-      kv_padded_len = key.shape[2]
-      kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
-      kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
-      segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
-
-      splash_kernel = splash_attention_kernel.make_splash_mha(
-          mask=multi_head_mask,
-          head_shards=1,
-          q_seq_shards=1,
-          block_sizes=block_sizes,
-          save_residuals=True if attention_kernel == "ring" else False,
-          residual_checkpoint_name=residual_checkpoint_name
-      )
+      # STANDARD PATH: Explicit Padding + Segment IDs
+      mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+      multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+      q_padded_len = query.shape[2]
+      q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+      q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+      kv_padded_len = key.shape[2]
+      kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+      kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+      segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+
+      splash_kernel = splash_attention_kernel.make_splash_mha(
+          mask=multi_head_mask,
+          head_shards=1,
+          q_seq_shards=1,
+          block_sizes=block_sizes,
+          save_residuals=True if attention_kernel == "ring" else False,
+          residual_checkpoint_name=residual_checkpoint_name
+      )
     vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))

     if not mask_padding_tokens:
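
Note on the change: the new path always calls _pad_data_for_flash so that every sequence length is a multiple of the kernel block size (seq_len % block_size == 0), which the splash kernels require. The helper's body is not part of this diff; the sketch below is an assumption about its behaviour inferred from the call sites (which unpack a padded tensor, the original head dimension, and the original sequence length), not the repository's actual implementation.

    # Hypothetical sketch of the padding helper used above (NOT the actual
    # maxdiffusion implementation). Assumes inputs shaped
    # (batch, heads, seq_len, head_dim) and zero-pads seq_len up to the
    # next multiple of the kernel block size.
    import jax.numpy as jnp

    def _pad_data_for_flash_sketch(tensor, heads, block_size):
      # 'heads' is kept for signature parity with the call sites; the real
      # helper may also use it (e.g. to pad the head axis), which this
      # sketch does not attempt.
      kv_size = tensor.shape[-1]   # original head/feature dimension
      seq_len = tensor.shape[2]    # original (unpadded) sequence length
      padded_len = ((seq_len + block_size - 1) // block_size) * block_size
      if padded_len != seq_len:
        pad_width = ((0, 0), (0, 0), (0, padded_len - seq_len), (0, 0))
        tensor = jnp.pad(tensor, pad_width)  # zero-pad the sequence axis only
      return tensor, kv_size, seq_len

With padding applied unconditionally, the standard path's segment IDs still compare position indices against the original lengths: real tokens get segment id 1 and padded positions get 0, so the splash kernel only attends within matching segments and never across the valid/padding boundary. The tokamax path skips building these IDs and relies on the kernel's own handling of the padded region, as the in-diff comment states.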
