Commit f93d261

adds segment ids for masking.

1 parent aad9839

1 file changed: src/maxdiffusion/models/attention_flax.py (14 additions, 9 deletions)
@@ -189,34 +189,39 @@ def _tpu_flash_attention(
 
   num_fsdp_shards = mesh.shape["fsdp"]
   query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
-  key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
+  key, _, key_seq_len = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
+  # To only attend to non-padded tokens.
+  segment_axis_names_q = nn.logical_to_mesh_axes((BATCH, LENGTH))
+  segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH, KV_LENGTH))
+  q_segment_ids = jnp.where(jnp.arange(query.shape[2]) < query_seq_len, 1, 0)
+  q_segment_ids = jnp.broadcast_to(q_segment_ids, (query.shape[0], q_segment_ids.shape[0]))
+  kv_segment_ids = jnp.where(jnp.arange(key.shape[2]) < key_seq_len, 1, 0)
+  kv_segment_ids = jnp.broadcast_to(kv_segment_ids, (query.shape[0], kv_segment_ids.shape[0]))
+
   @functools.partial(
       shard_map.shard_map,
       mesh=mesh,
-      in_specs=(
-          q_axis_names,
-          kv_axis_names,
-          kv_axis_names,
-      ),
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names, segment_axis_names_q, segment_axis_names_kv),
       out_specs=q_axis_names,
       check_rep=False,
   )
-  def wrap_flash_attention(query, key, value):
+  def wrap_flash_attention(query, key, value, q_segment_ids, kv_segment_ids):
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
     # make_splash_mha is wrapped around shardmap and seq and head is already
     # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
+    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
         head_shards=1,  # the sizes of the axis is sharding over heads
         q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
         block_sizes=block_sizes,
     )
-    attention_output = jax.vmap(splash_kernel)(query, key, value)
+    attention_output = jax.vmap(splash_kernel)(query, key, value, segment_ids=segment_ids)
     return attention_output
 
   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
@@ -227,7 +232,7 @@ def wrap_flash_attention(query, key, value):
       "Warning, batch dimension should be shardable among the devices in data and fsdp"
      f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
   )
-  x = wrap_flash_attention(query, key, value)
+  x = wrap_flash_attention(query, key, value, q_segment_ids, kv_segment_ids)
   x = x[:, :, :query_seq_len, :kv_size]
   x = _reshape_heads_to_head_dim(x)
 
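For reference, here is a minimal plain-JAX sketch of what the segment ids added in this commit express (illustrative only, not the TPU splash kernel itself): positions past the unpadded sequence length get id 0, real tokens get id 1, and a query position may only attend to key positions with a matching id, so real tokens never attend to padding. All sizes and variable names below are hypothetical.

import jax
import jax.numpy as jnp

# Hypothetical sizes: sequences padded to 8 tokens, of which only 5 are real.
batch, padded_len, real_len = 2, 8, 5

# Same construction as in the diff: 1 for real tokens, 0 for padding.
seg = jnp.where(jnp.arange(padded_len) < real_len, 1, 0)           # (padded_len,)
q_segment_ids = jnp.broadcast_to(seg, (batch, padded_len))         # (batch, padded_len)
kv_segment_ids = jnp.broadcast_to(seg, (batch, padded_len))        # (batch, padded_len)

# Dense mask implied by the segment ids: a query position may attend to a
# key position only when their ids match, so real tokens (1) never attend
# to padding (0) and vice versa.
allowed = q_segment_ids[:, :, None] == kv_segment_ids[:, None, :]  # (batch, q, kv)

# Apply the mask to dummy attention scores for illustration.
scores = jnp.zeros((batch, padded_len, padded_len))
weights = jax.nn.softmax(jnp.where(allowed, scores, -jnp.inf), axis=-1)
print(weights[0, 0])  # uniform 0.2 over the 5 real positions, 0.0 on padding

In the diff above, the same pair of id arrays is carried by splash_attention_kernel.SegmentIds(q=..., kv=...), and the masking is applied blockwise inside the splash attention kernel rather than through a dense mask like this sketch.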
