Commit c7edfb0

reduce padding by computing it inside sharded qkvs.

1 parent: f93d261
1 file changed: src/maxdiffusion/models/attention_flax.py (34 additions, 25 deletions)
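In outline: before this commit, Q/K/V were padded to flash-friendly shapes on the global arrays, before entering shard_map, which forced a with_sharding_constraint on the padded tensors and made the padding depend on the shard count; after it, only the rank-normalizing reshape happens outside, and each shard pads its own local slice inside the wrapped kernel. A minimal sketch of the two flows (pad_to_multiple is a hypothetical helper, not code from the file):

import jax.numpy as jnp

def pad_to_multiple(x, axis, multiple):
  # Zero-pad `axis` of `x` up to the next multiple of `multiple`.
  pad = (-x.shape[axis]) % multiple
  npad = [(0, 0)] * x.ndim
  npad[axis] = (0, pad)
  return jnp.pad(x, npad)

# Before (global): pad the full tensor, rounding so the block count also
# divides the fsdp shard count, then re-constrain the padded array's sharding:
#   q = pad_to_multiple(q, axis=2, multiple=block_q)  # plus shard-count rounding
#   q = jax.lax.with_sharding_constraint(q, PartitionSpec("data", "tensor", "fsdp", None))
# After (local): shard_map splits q first; each shard pads only its own slice:
#   q_local = pad_to_multiple(q_local, axis=2, multiple=block_q)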
@@ -112,15 +112,23 @@ def _unflatten_heads(tensor, heads):
   tensor = jnp.transpose(tensor, (0, 2, 1, 3))
   return tensor
 
-
-def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
+def _reshape_data_for_flash(tensor, heads):
   """
   Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
   Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
   blocks is divisible by the number of shards.
   """
   if tensor.ndim != 4:
     tensor = _unflatten_heads(tensor, heads)
+  return tensor
+
+def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
+  """
+  Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
+  Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
+  blocks is divisible by the number of shards.
+  """
+  tensor = _reshape_data_for_flash(tensor, heads)
 
   # Pad head_dim to 128 if less than that.
   kv_size = tensor.shape[-1]
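After this hunk, _reshape_data_for_flash only normalizes the tensor to (batch, heads, seq_len, head_dim), while _pad_data_for_flash reshapes and then pads. The pad amounts it computes further down (outside this hunk's context) follow the docstring; a hedged sketch of that arithmetic, with _pad_amounts as an illustrative stand-in:

def _pad_amounts(seq_len, head_dim, flash_block_size):
  # Round seq_len up to a whole number of flash blocks.
  seq_len_pad = (-seq_len) % flash_block_size
  # The pallas kernel wants head_dim of at least 128; pad up if smaller.
  head_dim_pad = max(0, 128 - head_dim)
  return seq_len_pad, head_dim_pad

assert _pad_amounts(1000, 64, 256) == (24, 64)  # seq 1000 -> 1024, head_dim 64 -> 128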
@@ -148,8 +156,7 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1
 
   if kv_size < 128 or seq_len_pad != 0:
     npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
-    padded_tensor = jnp.pad(tensor, npad)
-    tensor = jax.lax.with_sharding_constraint(padded_tensor, PartitionSpec("data", "tensor", "fsdp", None))
+    tensor = jnp.pad(tensor, npad)
 
   return tensor, kv_size, seq_len
 
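With padding now applied to per-device arrays inside shard_map, the explicit PartitionSpec constraint is dropped: jnp.pad runs on the shard's local slice, so there is no global layout to re-assert. A standalone shape check of the npad tuple above (illustrative values, not file code):

import jax.numpy as jnp

x = jnp.ones((1, 2, 100, 64))              # (batch, heads, seq_len, head_dim)
npad = ((0, 0), (0, 0), (0, 28), (0, 64))  # seq_len 100 -> 128, head_dim 64 -> 128
assert jnp.pad(x, npad).shape == (1, 2, 128, 128)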
@@ -166,11 +173,13 @@ def _tpu_flash_attention(
     dtype: jnp.dtype = jnp.float32,
 ) -> jax.Array:
   """TPU Flash Attention"""
+
   q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
-  # Cross-attention where kv dims are much smaller due to encoder_hidden_states.
-  # If kv seq_len is padded too much, it causes issues in attention calculations.
+  # This is the case for cross-attn.
   if key.shape[1] != query.shape[1]:
+    assert key.shape[1] % 128 == 0
     kv_max_block_size = key.shape[1]
+    #q_max_block_size = kv_max_block_size
   else:
     kv_max_block_size = q_max_block_size
   if flash_block_sizes:
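The new assert guards the cross-attention path: the whole (short) kv sequence becomes a single block, so it must meet the kernel's 128-lane alignment. The branch, isolated as a hedged sketch (pick_block_sizes is illustrative, not a function in the file):

import jax.numpy as jnp

def pick_block_sizes(q_seq_len, kv_seq_len, dtype):
  q_max = 1024 if dtype == jnp.bfloat16 else 512  # bf16 fits larger q blocks
  if kv_seq_len != q_seq_len:                     # cross-attention
    assert kv_seq_len % 128 == 0                  # guard added by this commit
    kv_max = kv_seq_len                           # one block spans the whole kv seq
  else:
    kv_max = q_max
  return q_max, kv_max

assert pick_block_sizes(1024, 128, jnp.float32) == (512, 128)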
@@ -186,43 +195,44 @@
       block_q_dq=min(q_max_block_size, query.shape[2]),
       block_kv_dq=min(kv_max_block_size, query.shape[2]),
   )
-
-  num_fsdp_shards = mesh.shape["fsdp"]
-  query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
-  key, _, key_seq_len = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
-  value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
+
+  query = _reshape_data_for_flash(query, heads)
+  key = _reshape_data_for_flash(key, heads)
+  value = _reshape_data_for_flash(value, heads)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
-  # To only attend to non-padded tokens.
-  segment_axis_names_q = nn.logical_to_mesh_axes((BATCH, LENGTH))
-  segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH, KV_LENGTH))
-  q_segment_ids = jnp.where(jnp.arange(query.shape[2]) < query_seq_len, 1, 0)
-  q_segment_ids = jnp.broadcast_to(q_segment_ids, (query.shape[0], q_segment_ids.shape[0]))
-  kv_segment_ids = jnp.where(jnp.arange(key.shape[2]) < key_seq_len, 1, 0)
-  kv_segment_ids = jnp.broadcast_to(kv_segment_ids, (query.shape[0], kv_segment_ids.shape[0]))
-
   @functools.partial(
       shard_map.shard_map,
       mesh=mesh,
-      in_specs=(q_axis_names, kv_axis_names, kv_axis_names, segment_axis_names_q, segment_axis_names_kv),
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
       out_specs=q_axis_names,
       check_rep=False,
   )
-  def wrap_flash_attention(query, key, value, q_segment_ids, kv_segment_ids):
+  def wrap_flash_attention(query, key, value):
+
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_sizes.block_q)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv_compute)
+    value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv_compute)
+
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+    q_segment_ids = jnp.where(jnp.arange(query.shape[2]) < query_seq_len, 1, 0)
+    q_segment_ids = jnp.broadcast_to(q_segment_ids, (query.shape[0], q_segment_ids.shape[0]))
+    kv_segment_ids = jnp.where(jnp.arange(key.shape[2]) < key_seq_len, 1, 0)
+    kv_segment_ids = jnp.broadcast_to(kv_segment_ids, (query.shape[0], kv_segment_ids.shape[0]))
+    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+
     # make_splash_mha is wrapped around shardmap and seq and head is already
     # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
-    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
         head_shards=1,  # the sizes of the axis is sharding over heads
        q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
         block_sizes=block_sizes,
     )
     attention_output = jax.vmap(splash_kernel)(query, key, value, segment_ids=segment_ids)
-    return attention_output
+    return attention_output[:,:,:query_seq_len,:kv_size]
 
   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
   # This warning might show up when doing model eval for example, when calculating model flops
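The segment ids, now built per shard from the local padded and unpadded lengths, encode a standard padding mask: splash attention only lets a query position attend to a kv position with a matching segment id, so real tokens (id 1) never see padding (id 0). Pad-to-pad matches (0 == 0) still occur, but those rows and columns are cut away by the final slice. A hedged standalone illustration of the equivalent dense mask:

import jax.numpy as jnp

padded_len, real_len = 128, 100
ids = jnp.where(jnp.arange(padded_len) < real_len, 1, 0)
# Dense equivalent: position i may attend position j only when ids match.
mask = ids[:, None] == ids[None, :]
assert bool(mask[0, 99]) and not bool(mask[0, 100])  # real->real ok, real->pad blocked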
@@ -232,8 +242,7 @@ def wrap_flash_attention(query, key, value, q_segment_ids, kv_segment_ids):
         "Warning, batch dimension should be shardable among the devices in data and fsdp"
         f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
     )
-  x = wrap_flash_attention(query, key, value, q_segment_ids, kv_segment_ids)
-  x = x[:, :, :query_seq_len, :kv_size]
+  x = wrap_flash_attention(query, key, value)
   x = _reshape_heads_to_head_dim(x)
 
   return x
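Since padding and un-padding both live inside wrap_flash_attention now, the call site no longer threads segment ids or the unpadded lengths through the shard_map boundary. A hedged sketch of the resulting caller flow (the lambdas stand in for the file's functions):

import jax.numpy as jnp

wrap_flash_attention = lambda q, k, v: q  # stand-in: pad -> splash mha -> un-pad
reshape_heads_to_head_dim = lambda x: (   # stand-in: (b, h, s, d) -> (b, s, h*d)
    x.transpose(0, 2, 1, 3).reshape(x.shape[0], x.shape[2], -1))

q = k = v = jnp.zeros((1, 2, 100, 64))
x = wrap_flash_attention(q, k, v)   # no slicing needed here anymore
x = reshape_heads_to_head_dim(x)
assert x.shape == (1, 100, 128)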
