
Commit ec61456

scanned ring attn.
1 parent c7edfb0 commit ec61456

1 file changed: src/maxdiffusion/models/attention_flax.py

Lines changed: 45 additions & 5 deletions
@@ -195,7 +195,7 @@ def _tpu_flash_attention(
       block_q_dq=min(q_max_block_size, query.shape[2]),
       block_kv_dq=min(kv_max_block_size, query.shape[2]),
   )
-
+  num_fsdp_shards = mesh.shape["fsdp"]
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
@@ -218,9 +218,7 @@ def wrap_flash_attention(query, key, value):
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
     q_segment_ids = jnp.where(jnp.arange(query.shape[2]) < query_seq_len, 1, 0)
-    q_segment_ids = jnp.broadcast_to(q_segment_ids, (query.shape[0], q_segment_ids.shape[0]))
     kv_segment_ids = jnp.where(jnp.arange(key.shape[2]) < key_seq_len, 1, 0)
-    kv_segment_ids = jnp.broadcast_to(kv_segment_ids, (query.shape[0], kv_segment_ids.shape[0]))
     segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
 
     # make_splash_mha is wrapped around shardmap and seq and head is already
@@ -230,9 +228,51 @@ def wrap_flash_attention(query, key, value):
         head_shards=1, # the sizes of the axis is sharding over heads
         q_seq_shards=1, # the sizes of the axis is sharding over seq_len
         block_sizes=block_sizes,
+        save_residuals=True
     )
-    attention_output = jax.vmap(splash_kernel)(query, key, value, segment_ids=segment_ids)
-    return attention_output[:,:,:query_seq_len,:kv_size]
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+
+    def ring_scan_body(carry, _):
+      m, l, o, k_current, v_current = carry
+      perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
+      k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
+      v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
+
+      out_chunk, (lse_chunk,) = vmapped_splash(
+          query, k_current, v_current, segment_ids
+      )
+
+      m_chunk = lse_chunk.astype(jnp.float32)
+      m_old = m
+      m = jnp.maximum(m_old, m_chunk)
+
+      exp_m_diff = jnp.exp(m_old - m)
+      exp_m_chunk_diff = jnp.exp(m_chunk - m)
+
+      l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+      o = o * exp_m_diff[..., None]
+      o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
+
+      # Return the updated state for the next iteration
+      return (m, l, o, k_next, v_next), None
+
+    lse_shape = query.shape[:-1]
+    m_init = jnp.full(lse_shape, -jnp.inf, dtype=jnp.float32)
+    l_init = jnp.zeros(lse_shape, dtype=jnp.float32)
+    o_init = jnp.zeros_like(query, dtype=jnp.float32)
+
+    initial_carry = (m_init, l_init, o_init, key, value)
+
+    (m_final, l_final, o_final, _, _), _ = jax.lax.scan(
+        ring_scan_body,
+        initial_carry,
+        None,
+        length=num_fsdp_shards
+    )
+
+    attention_output = o_final / l_final[..., None]
+
+    return attention_output[:,:,:query_seq_len,:kv_size].astype(query.dtype)
 
     devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
     # This warning might show up when doing model eval for example, when calculating model flops
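
Note on the K/V rotation: ring_scan_body passes each shard's key/value block to its neighbor on the "fsdp" mesh axis with jax.lax.ppermute, so after num_fsdp_shards scan steps every shard has attended against every block. Below is a minimal standalone sketch of that rotation pattern; the 4-device CPU simulation, the pmap wrapper, and the toy scalar "blocks" are illustrative assumptions, not code from this commit.

# Standalone sketch (assumptions: simulated CPU devices, pmap instead of the
# commit's shard_map, scalar blocks instead of K/V tensors).
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"  # must be set before importing jax

from functools import partial
import jax
import jax.numpy as jnp

num_fsdp_shards = jax.device_count()
# Same permutation as in ring_scan_body: shard j hands its block to shard (j + 1) % n.
perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]

@partial(jax.pmap, axis_name="fsdp")
def one_ring_step(kv_block):
  return jax.lax.ppermute(kv_block, axis_name="fsdp", perm=perm)

kv = jnp.arange(num_fsdp_shards)  # one toy "block" per shard: [0, 1, 2, 3]
print(one_ring_step(kv))          # [3, 0, 1, 2]: each shard now holds its neighbor's block

Repeating one_ring_step num_fsdp_shards times cycles every block past every shard, which is what the lax.scan with length=num_fsdp_shards does in the diff while accumulating the attention output.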

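Note on the accumulator updates: the (m, l, o) carry implements the usual online-softmax merge. Each step's splash output comes back with its log-sum-exp residual (save_residuals=True), the running statistics are rescaled against the new maximum, and the final division o_final / l_final produces the normalized result. The following is a standalone sketch of the same merge on a single device, with illustrative shapes and plain dense chunks standing in for the splash kernel, checked against full attention.

# Standalone sketch (assumptions: toy shapes, dense per-chunk logits instead of
# the splash kernel; no sharding or scan).
import jax
import jax.numpy as jnp

def chunked_attention(q, k, v, num_chunks):
  """Computes softmax(q k^T) v one K/V chunk at a time with a running rescale."""
  seq_q, d = q.shape
  m = jnp.full((seq_q,), -jnp.inf)   # running row-wise max of the logits
  l = jnp.zeros((seq_q,))            # running softmax normalizer
  o = jnp.zeros((seq_q, d))          # running unnormalized output
  for k_chunk, v_chunk in zip(jnp.split(k, num_chunks), jnp.split(v, num_chunks)):
    s = q @ k_chunk.T                # logits for this chunk
    m_new = jnp.maximum(m, s.max(axis=-1))
    p = jnp.exp(s - m_new[:, None])  # chunk weights, referenced to the new max
    correction = jnp.exp(m - m_new)  # rescale factor for the old accumulators
    l = l * correction + p.sum(axis=-1)
    o = o * correction[:, None] + p @ v_chunk
    m = m_new
  return o / l[:, None]              # final normalization, as in o_final / l_final

q = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
k = jax.random.normal(jax.random.PRNGKey(1), (32, 16))
v = jax.random.normal(jax.random.PRNGKey(2), (32, 16))
reference = jax.nn.softmax(q @ k.T, axis=-1) @ v
assert jnp.allclose(chunked_attention(q, k, v, num_chunks=4), reference, atol=1e-5)

The correction factor exp(m - m_new) plays the same role as exp_m_diff in the commit; the only difference is that the commit uses each chunk's log-sum-exp as its max statistic, which is an equally valid reference point for the rescaling.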