Commit d5f28aa

ring attention for inference.

1 parent 2809d4e

2 files changed: 46 additions & 5 deletions

File tree:
  src/maxdiffusion/configs/base_wan_14b.yml
  src/maxdiffusion/models/attention_flax.py

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ logical_axis_rules: [
   ['batch', 'data'],
   ['activation_batch', 'data'],
   ['activation_length', 'fsdp'],
-  […]
+  ['activation_kv_length', 'fsdp'],
   ['activation_heads', 'tensor'],
   ['mlp','tensor'],
   ['embed','fsdp'],
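
For context, a minimal sketch (not part of the commit) of how such a rule is consumed: flax.linen.logical_to_mesh_axes resolves logical axis names through these rules into a PartitionSpec, so the new entry shards the KV sequence dimension of activations over the 'fsdp' mesh axis.

import flax.linen as nn

rules = (
    ("batch", "data"),
    ("activation_batch", "data"),
    ("activation_length", "fsdp"),
    ("activation_kv_length", "fsdp"),  # rule added by this commit
    ("activation_heads", "tensor"),
)

# Logical axes of a KV activation: (batch, kv_length, heads).
spec = nn.logical_to_mesh_axes(
    ("activation_batch", "activation_kv_length", "activation_heads"),
    rules=rules,
)
print(spec)  # PartitionSpec('data', 'fsdp', 'tensor')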

src/maxdiffusion/models/attention_flax.py

Lines changed: 45 additions & 4 deletions
@@ -188,8 +188,8 @@ def _tpu_flash_attention(
   )

   num_fsdp_shards = mesh.shape["fsdp"]
-  query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
-  key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
+  query, kv_size, original_query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
+  key, _, original_key_seq_len = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
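
_reshape_data_for_flash isn't shown in this diff, but the renamed return values and the final slice in the last hunk suggest it pads each sequence to a block-size multiple and reports the pre-padding length so the padding can be dropped afterwards. A minimal sketch of that pad-then-slice pattern, with hypothetical names (pad_to_block_multiple is not the maxdiffusion helper):

import jax.numpy as jnp

def pad_to_block_multiple(x, block_size, seq_axis=2):
  # Pad the sequence axis up to the next multiple of block_size and return
  # the original length so callers can slice the padding back off.
  original_len = x.shape[seq_axis]
  padded_len = -(-original_len // block_size) * block_size  # ceil division
  pad_width = [(0, 0)] * x.ndim
  pad_width[seq_axis] = (0, padded_len - original_len)
  return jnp.pad(x, pad_width), original_len

q = jnp.ones((1, 8, 1000, 64))  # (batch, heads, seq_len, head_dim)
q_padded, original_query_seq_len = pad_to_block_multiple(q, block_size=512)
print(q_padded.shape)  # (1, 8, 1024, 64)

# ... run the flash kernel on q_padded ...
out = q_padded[:, :, :original_query_seq_len, :]  # drop the padding rows
print(out.shape)  # (1, 8, 1000, 64)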
@@ -206,6 +206,10 @@ def _tpu_flash_attention(
       check_rep=False,
   )
   def wrap_flash_attention(query, key, value):
+    jax.debug.print("query.shape: {x}", x=query.shape)
+    jax.debug.print("key.shape: {x}", x=key.shape)
+    jax.debug.print("value.shape: {x}", x=value.shape)
+
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
     # make_splash_mha is wrapped around shardmap and seq and head is already
@@ -215,8 +219,45 @@ def wrap_flash_attention(query, key, value):
       head_shards=1,  # number of shards over the heads axis
       q_seq_shards=1,  # number of shards over the query sequence axis
       block_sizes=block_sizes,
+      save_residuals=True,
     )
-    attention_output = jax.vmap(splash_kernel)(query, key, value)
+    out, (lse,) = jax.vmap(splash_kernel)(query, key, value)
+
+    # Running online-softmax statistics, seeded with the local shard's result:
+    # m is the running max of the log-sum-exp, l the rescaled normalizer,
+    # o the rescaled (unnormalized) output.
+    m = lse.astype(jnp.float32)
+    l = jnp.exp(lse.astype(jnp.float32) - m)
+    o = out.astype(jnp.float32) * l[..., None]
+
+    k_ring = key
+    v_ring = value
+
+    for i in range(1, num_fsdp_shards):
+      # Rotate the KV shards one step around the 'fsdp' ring.
+      k_ring = jax.lax.ppermute(k_ring, axis_name='fsdp', perm=[(j, (j+1) % num_fsdp_shards) for j in range(num_fsdp_shards)])
+      v_ring = jax.lax.ppermute(v_ring, axis_name='fsdp', perm=[(j, (j+1) % num_fsdp_shards) for j in range(num_fsdp_shards)])
+
+      out_chunk, (lse_chunk,) = jax.vmap(splash_kernel)(query, k_ring, v_ring)
+      m_chunk = lse_chunk.astype(jnp.float32)
+      p_chunk = jnp.exp(lse_chunk.astype(jnp.float32) - m_chunk)
+
+      # Fold the new chunk into the running statistics.
+      m_new = jnp.maximum(m, m_chunk)
+      l = l * jnp.exp(m - m_new)
+      p_chunk_rescaled = p_chunk * jnp.exp(m_chunk - m_new)
+      l_new = l + p_chunk_rescaled
+
+      o = o * jnp.exp(m - m_new)[..., None]
+      o += p_chunk_rescaled[..., None] * out_chunk
+
+      m = m_new
+      l = l_new
+      jax.debug.print("Loop {i}: max(m)={m_max}, max(l)={l_max}, max(o)={o_max}",
+                      i=i, m_max=m.max(), l_max=l.max(), o_max=o.max())
+
+    attention_output = o / l[..., None]
     return attention_output

   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
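
The loop above is an online-softmax (flash/ring attention) merge: with save_residuals=True each splash call returns its normalized output together with the log-sum-exp (lse) of its logits, and partial results from successive ring steps are folded into running (m, l, o) statistics. A self-contained sketch of the same merge rule, using plain jax.numpy attention as a stand-in for the kernel, which can be checked against full attention:

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def chunk_attention(q, k, v):
  # Stand-in for the splash kernel with save_residuals=True: returns the
  # chunk-normalized attention output and the log-sum-exp of its logits.
  logits = q @ k.T
  return jax.nn.softmax(logits, axis=-1) @ v, logsumexp(logits, axis=-1)

def merge_chunks(chunks):
  # The same running-statistics update as the ring loop in the hunk above.
  out, lse = chunks[0]
  m = lse                         # running max of the per-chunk lse
  l = jnp.exp(lse - m)            # running rescaled normalizer (all ones here)
  o = out * l[..., None]          # running rescaled, unnormalized output
  for out_c, lse_c in chunks[1:]:
    m_new = jnp.maximum(m, lse_c)
    l = l * jnp.exp(m - m_new)
    p_c = jnp.exp(lse_c - m_new)  # weight of the incoming chunk
    o = o * jnp.exp(m - m_new)[..., None] + p_c[..., None] * out_c
    l = l + p_c
    m = m_new
  return o / l[..., None]         # final softmax normalization

q = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (16, 8))
v = jax.random.normal(jax.random.PRNGKey(2), (16, 8))

full, _ = chunk_attention(q, k, v)  # attention over the full KV in one shot
parts = [chunk_attention(q, k[i:i + 4], v[i:i + 4]) for i in range(0, 16, 4)]
print(jnp.allclose(full, merge_chunks(parts), atol=1e-5))  # True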
@@ -228,7 +269,7 @@ def wrap_flash_attention(query, key, value):
       f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
   )
   x = wrap_flash_attention(query, key, value)
-  x = x[:, :, :query_seq_len, :kv_size]
+  x = x[:, :, :original_query_seq_len, :kv_size]
   x = _reshape_heads_to_head_dim(x)

   return x
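
The communication primitive behind the ring is jax.lax.ppermute: at each step every shard hands its current KV block to the next device, so after num_fsdp_shards - 1 rotations each query shard has attended against every KV shard. A standalone sketch of one rotation step (not from the commit; it assumes at least four devices, e.g. a CPU run with --xla_force_host_platform_device_count=4):

from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

n = 4
mesh = Mesh(jax.devices()[:n], axis_names=("fsdp",))

@partial(shard_map, mesh=mesh, in_specs=P("fsdp"), out_specs=P("fsdp"))
def rotate(kv):
  # Every shard sends its block to the next device in the ring.
  return jax.lax.ppermute(kv, axis_name="fsdp",
                          perm=[(j, (j + 1) % n) for j in range(n)])

kv = jnp.arange(float(n))  # one scalar "KV block" per shard
print(rotate(kv))          # [3. 0. 1. 2.]: each shard now holds its neighbor's block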
