Commit 850a690

make mask selectively
1 parent 42cbb0e commit 850a690

2 files changed: 21 additions & 19 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@ venv/
 ENV/
 env.bak/
 venv.bak/
+.history
 
 # Spyder project settings
 .spyderproject

src/maxdiffusion/models/attention_flax.py

Lines changed: 20 additions & 19 deletions
@@ -226,25 +226,26 @@ def wrap_flash_attention(query, key, value):
   key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv)
   value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv)
 
-  mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-  multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-
-  q_padded_len = query.shape[2]
-  q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
-  q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
-
-  kv_padded_len = key.shape[2]
-  kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
-  kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
-  segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
-  splash_kernel = splash_attention_kernel.make_splash_mha(
-      mask=multi_head_mask,
-      head_shards=1,  # the size of the axis sharded over heads
-      q_seq_shards=1,  # the size of the axis sharded over seq_len
-      block_sizes=block_sizes,
-      save_residuals=True if attention_kernel == "ring" else False,
-  )
-  vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None), out_axes=0)
+  if attention_kernel == "flash":
+    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+    q_padded_len = query.shape[2]
+    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+    kv_padded_len = key.shape[2]
+    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+    splash_kernel = splash_attention_kernel.make_splash_mha(
+        mask=multi_head_mask,
+        head_shards=1,  # the size of the axis sharded over heads
+        q_seq_shards=1,  # the size of the axis sharded over seq_len
+        block_sizes=block_sizes,
+        save_residuals=True if attention_kernel == "ring" else False,
+    )
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None), out_axes=0)
 
   if attention_kernel == "flash":
     # attention_output = vmapped_splash(query, key, value, segment_ids)
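
With this change, the full mask, the iota-based segment IDs, and the splash kernel are built only when the "flash" path is actually taken, instead of unconditionally for every kernel. The segment IDs are what keep the padding added by _pad_data_for_flash (so the sequence lengths divide the flash block sizes) from contributing to the result. Below is a minimal plain-JAX sketch of that masking idea outside the Pallas splash kernel; the [heads, padded_len, head_dim] layout, the toy shapes, and the reference_padded_attention name are illustrative assumptions, not code from this repository, and query padding is handled by slicing the output rather than a q-side mask.

import jax
import jax.numpy as jnp


def reference_padded_attention(query, key, value, query_seq_len, key_seq_len):
  """Plain-JAX sketch of the padding mask that the segment IDs express.

  query/key/value: [heads, padded_len, head_dim]. Only the first
  query_seq_len (resp. key_seq_len) positions hold real tokens; the rest is
  padding added so the sequence lengths divide the flash block sizes.
  """
  kv_padded_len = key.shape[1]

  # Same iota construction as in the diff: True for real keys, False for padding.
  kv_valid = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0) < key_seq_len

  scores = jnp.einsum("hqd,hkd->hqk", query, key) / jnp.sqrt(query.shape[-1])
  # Padded keys must never receive attention weight.
  scores = jnp.where(kv_valid[None, None, :], scores, -jnp.inf)
  weights = jax.nn.softmax(scores, axis=-1)
  out = jnp.einsum("hqk,hkd->hqd", weights, value)

  # Rows computed for padded queries are meaningless; slice them off,
  # mirroring the un-padding done after the splash kernel returns.
  return out[:, :query_seq_len, :]


# Toy usage: 10 real queries padded to 16, 20 real keys padded to 32.
heads, head_dim = 2, 8
q = jnp.ones((heads, 16, head_dim))
k = jnp.ones((heads, 32, head_dim))
v = jnp.ones((heads, 32, head_dim))
print(reference_padded_attention(q, k, v, query_seq_len=10, key_seq_len=20).shape)  # (2, 10, 8)

In the splash kernel itself the same information travels as splash_attention_kernel.SegmentIds(q=..., kv=...): a query position may only attend to key positions with the same segment id, so real tokens (id 1) never attend to padding (id 0).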
