Commit b3c7212

Commit message: Fix
Parent: c033b9a

2 files changed: 1128 additions & 591 deletions

File shown in this diff:

src/maxdiffusion/models/attention_flax.py (26 additions & 22 deletions)
@@ -279,39 +279,43 @@ def wrap_flash_attention(query, key, value):
       block_kv_sizes += (block_sizes.block_kv_dq,)
 
     block_q = max(*block_q_sizes)
-    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
-
     block_kv = max(*block_kv_sizes)
-    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
-    value, _, _ = _pad_data_for_flash(value, heads, block_kv)
-
-    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-
-    q_padded_len = query.shape[2]
-    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
-    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
 
-    kv_padded_len = key.shape[2]
-    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
-    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
-    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
-
-    # make_splash_mha is wrapped around shardmap and seq and head is already
-    # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
     if attention_kernel == "tokamax_flash":
+      # OPTIMIZATION: Skip padding and segment_ids for the optimized kernel
+      kv_size = key.shape[-1]
+      query_seq_len = query.shape[2]
+      segment_ids = None
+
       mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
       splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
           mask=mask,
-          q_seq_shards=1, # the sizes of the axis is sharding over seq_len
+          q_seq_shards=1,
           config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
-          save_residuals=True if attention_kernel == "ring" else False,
+          save_residuals=False, # Ring attention not typically used in this path
       )
     else:
+      # STANDARD PATH: Explicit padding (Slower)
+      query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q)
+      key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv)
+      value, _, _ = _pad_data_for_flash(value, heads, block_kv)
+
+      mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+      multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+      q_padded_len = query.shape[2]
+      q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+      q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+      kv_padded_len = key.shape[2]
+      kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+      kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+      segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+
       splash_kernel = splash_attention_kernel.make_splash_mha(
           mask=multi_head_mask,
-          head_shards=1, # the sizes of the axis is sharding over heads
-          q_seq_shards=1, # the sizes of the axis is sharding over seq_len
+          head_shards=1,
+          q_seq_shards=1,
           block_sizes=block_sizes,
           save_residuals=True if attention_kernel == "ring" else False,
           residual_checkpoint_name=residual_checkpoint_name
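
Illustrative note (not part of the commit): the standard path above pads query/key/value out to the next multiple of the block size and then builds segment ids from broadcasted_iota so the splash kernel ignores the padded positions. The sketch below reproduces that mechanism with a stand-in pad_to_block_multiple helper; the helper name, shapes, and block sizes are assumptions for illustration, not the repository's _pad_data_for_flash implementation.

# Minimal sketch, assuming a hypothetical pad_to_block_multiple helper that
# mirrors what _pad_data_for_flash appears to do: pad the sequence axis and
# return the original length so padding can be masked out via segment ids.
import jax
import jax.numpy as jnp

def pad_to_block_multiple(x, seq_axis, block):
  # Pad the sequence axis up to the next multiple of `block`.
  seq_len = x.shape[seq_axis]
  padded_len = ((seq_len + block - 1) // block) * block
  pad = [(0, 0)] * x.ndim
  pad[seq_axis] = (0, padded_len - seq_len)
  return jnp.pad(x, pad), seq_len

block_q, block_kv = 512, 512
query = jnp.ones((1, 8, 300, 64))  # (batch, heads, q_seq_len, head_dim) -- illustrative shapes
key = jnp.ones((1, 8, 700, 64))

query, query_seq_len = pad_to_block_multiple(query, seq_axis=2, block=block_q)
key, key_seq_len = pad_to_block_multiple(key, seq_axis=2, block=block_kv)

# Real tokens get segment id 1, padded positions get 0, so the kernel never
# attends across the real/padding boundary.
q_indices = jax.lax.broadcasted_iota(jnp.int32, (query.shape[2],), 0)
q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
kv_indices = jax.lax.broadcasted_iota(jnp.int32, (key.shape[2],), 0)
kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)

print(query.shape, int(q_segment_ids.sum()))  # (1, 8, 512, 64) 300
print(key.shape, int(kv_segment_ids.sum()))   # (1, 8, 1024, 64) 700

In the tokamax_flash branch the commit skips this padding entirely and passes segment_ids = None, leaving the tokamax splash kernel to work on the unpadded sequence lengths.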
