
Commit 0e60bbb

Merge remote-tracking branch 'origin/kunjanp-ring-attention' into elisatsai_ring_attention

2 parents: c236d56 + 9bcd458

15 files changed: 7,182 additions & 3 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions

@@ -61,6 +61,7 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
+attention: 'tokamax_ring'
 flash_min_seq_length: 0

 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
src/maxdiffusion/kernels/__init__.py

Whitespace-only changes.
Lines changed: 285 additions & 0 deletions
@@ -0,0 +1,285 @@
# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base functionality for Sparse Flash Attention."""

import functools
from typing import Final, NamedTuple, TypeAlias

import jax
import jax.numpy as jnp
import numpy as np

from . import splash_attention_mask_info as mask_info_lib


MaskInfo = mask_info_lib.MaskInfo


DEFAULT_MASK_VALUE: Final[float] = -0.7 * float(
    np.finfo(np.dtype("float32")).max
)


class SegmentIds(NamedTuple):
  """SegmentIds for Q and KV sequences.

  SegmentIds are a mechanism to ensure that there is no cross-attention between
  segments (fractions of a sequence) that have been concatenated together into
  a sequence. Each array is a list of ids (integers). Only tokens with the same
  id are allowed to attend to each other.

  The static mask (e.g. causal) is "and-ed" with the segment id mask to form
  the actual attention mask. It is important that the latter does not have any
  all-zero rows (along the kv dimension); otherwise it would result in an
  invalid softmax (the denominator would be 0). This condition holds for causal
  self-attention because in this case the segment ids form a block-diagonal
  matrix, so at least one element in each row is set. It is easy to break this
  condition with non-self-attention configurations.

  Attributes:
    q: segment ids along the Q sequence
    kv: segment ids along the KV sequence
  """

  q: jax.Array | jax.sharding.PartitionSpec  # [q_seq_len]
  kv: jax.Array | jax.sharding.PartitionSpec  # [kv_seq_len]
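For readers of the diff, a minimal sketch (not part of the commit, toy values) of how segment ids combine with a static mask; it mirrors the segment_ids.q[:, None] == segment_ids.kv[None, :] check used by the reference implementations below:

# Illustrative sketch, not part of the commit: two packed segments in one sequence.
import jax.numpy as jnp

seg = SegmentIds(
    q=jnp.array([0, 0, 0, 1, 1]),   # 3 tokens of segment 0, 2 tokens of segment 1
    kv=jnp.array([0, 0, 0, 1, 1]),
)
causal = jnp.tril(jnp.ones((5, 5), dtype=bool))
# Cross-segment attention is masked out; the static (causal) mask is and-ed in.
segment_mask = seg.q[:, None] == seg.kv[None, :]
mask = jnp.logical_and(causal, segment_mask)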


# Return type of SplashAttention function that implements the custom vjp rule.
SplashCustomReturnType: TypeAlias = (
    jax.Array | tuple[jax.Array, dict[str, jax.Array]]
)

SplashResidualsType = tuple[
    jax.Array,  # q
    jax.Array,  # k
    jax.Array,  # v
    SegmentIds | None,  # segment_ids
    jax.Array | None,  # sinks
    jax.Array,  # out
    jax.Array,  # logsumexp
    MaskInfo | None,  # dkv_mask_info
]


def _attention_reference_impl(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    mask: jax.Array,
    segment_ids: SegmentIds | None,
    sinks: jax.Array | None,
    mask_value: float,
    save_residuals: bool,
    attn_logits_soft_cap: float | None,
) -> SplashCustomReturnType:
  logits = jnp.einsum("sd,td->st", q.astype(jnp.float32), k.astype(jnp.float32))

  if segment_ids is not None:
    mask = jnp.logical_and(
        mask, segment_ids.q[:, None] == segment_ids.kv[None, :]
    )

  if attn_logits_soft_cap is not None:
    logits = jnp.tanh(logits / attn_logits_soft_cap)
    logits = logits * attn_logits_soft_cap

  if sinks is not None:
    assert sinks.shape == ()  # should already be vmapped

  logits = jnp.where(mask, logits, mask_value)
  m = logits.max(axis=-1)
  sinks = None if sinks is None else sinks.astype(logits.dtype)
  m = m if sinks is None else jnp.maximum(m, sinks)
  s = jnp.exp(logits - m[..., None])
  l = s.sum(axis=-1) + (0 if sinks is None else jnp.exp(sinks - m))
  p = s / l[..., None]

  o = jnp.einsum("st,td->sd", p, v.astype(jnp.float32))

  if save_residuals:
    logsumexp = m + jnp.log(l)
    return o, {"logsumexp": logsumexp, "max_logits": m}
  return o
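As an aside for reviewers: in this reference the optional sink logit enters only through the softmax statistics; it raises the running max and adds one term to the denominator, but contributes no value row. A minimal sketch of that behaviour (not part of the commit, toy numbers):

# Illustrative sketch, not part of the commit: the sink only affects the denominator.
import jax.numpy as jnp

logits = jnp.array([1.0, 2.0, 0.5])   # per-query logits over kv positions
sink = jnp.array(1.5)                 # scalar sink logit for this head

m = jnp.maximum(logits.max(), sink)
denom = jnp.exp(logits - m).sum() + jnp.exp(sink - m)
p = jnp.exp(logits - m) / denom       # probabilities over kv; they sum to < 1
# The "missing" mass jnp.exp(sink - m) / denom is absorbed by the sink.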


def _attention_reference_custom_bwd(
    do,
    q,
    k,
    v,
    mask,
    segment_ids,
    sinks,
    o,
    logsumexp,
    mask_value: float = DEFAULT_MASK_VALUE,
    backward_impl: str = "vanilla",
    attn_logits_soft_cap: float | None = None,
) -> tuple[jax.Array, jax.Array, jax.Array, None, None, jax.Array | None]:
  uncapped_logits = jnp.einsum(
      "qc,kc->qk", q, k, preferred_element_type=jnp.float32
  )

  if attn_logits_soft_cap is not None:
    logits = jnp.tanh(uncapped_logits / attn_logits_soft_cap)
    logits = logits * attn_logits_soft_cap
  else:
    logits = uncapped_logits

  if segment_ids is not None:
    mask = jnp.logical_and(
        mask, segment_ids.q[:, None] == segment_ids.kv[None, :]
    )
  logits = jnp.where(mask, logits, mask_value)

  p = jnp.exp(logits - logsumexp[..., None])
  do = do.astype(jnp.float32)  # pytype: disable=attribute-error
  dv = jnp.einsum("pt,pd->td", p, do).astype(v.dtype)
  dp = jnp.einsum("pd,td->pt", do, v.astype(jnp.float32))

  # These two ways of computing ds are mathematically equivalent. The first
  # involves reducing over the head_dim dimension and the second involves
  # reducing over a sequence dimension. They tend to produce slightly different
  # numerics.
  if backward_impl == "flash":
    di = jnp.sum(o.astype(jnp.float32) * do, axis=-1)[..., None]
  else:
    di = jnp.einsum("st,st->s", dp, p)[:, None]
  ds = (dp - di) * p
  if attn_logits_soft_cap is not None:
    normalized = uncapped_logits / attn_logits_soft_cap
    d = jnp.tanh(normalized)
    g = ds * (1 - d)
    ds = g + g * d
  dk = jnp.einsum("sd,st->td", q.astype(jnp.float32), ds).astype(k.dtype)
  dq = jnp.einsum("st,td->sd", ds, k.astype(jnp.float32)).astype(q.dtype)
  dsinks = None
  if sinks is not None:
    sinks_exp = -jnp.exp(
        sinks[..., None, None].astype(jnp.float32)
        - logsumexp[..., None].astype(jnp.float32)
    )
    dsinks = jnp.sum(sinks_exp.astype(o.dtype) * o * do, axis=(-1, -2))
  return dq, dk, dv, None, None, dsinks
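One way to sanity-check this hand-written backward is to compare it against jax.vjp applied to the forward reference. A hedged sketch (not part of the commit; shapes and tolerances are hypothetical, no mask/sinks/soft cap):

# Illustrative sketch, not part of the commit: cross-check against jax.vjp.
import jax
import jax.numpy as jnp

kq, kk, kv_key = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (4, 8), jnp.float32)   # [q_seq, head_dim]
k = jax.random.normal(kk, (6, 8), jnp.float32)   # [kv_seq, head_dim]
v = jax.random.normal(kv_key, (6, 8), jnp.float32)
mask = jnp.ones((4, 6), dtype=bool)
do = jnp.ones((4, 8), jnp.float32)

fwd = lambda q, k, v: _attention_reference_impl(
    q, k, v, mask, None, None,
    mask_value=DEFAULT_MASK_VALUE, save_residuals=False,
    attn_logits_soft_cap=None,
)
o, res = _attention_reference_impl(
    q, k, v, mask, None, None,
    mask_value=DEFAULT_MASK_VALUE, save_residuals=True,
    attn_logits_soft_cap=None,
)
dq, dk, dv, _, _, _ = _attention_reference_custom_bwd(
    do, q, k, v, mask, None, None, o, res["logsumexp"]
)
_, vjp_fn = jax.vjp(fwd, q, k, v)
dq_ref, dk_ref, dv_ref = vjp_fn(do)
assert jnp.allclose(dq, dq_ref, atol=1e-4)
assert jnp.allclose(dk, dk_ref, atol=1e-4)
assert jnp.allclose(dv, dv_ref, atol=1e-4)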


@functools.partial(
    jax.jit,
    static_argnames=[
        "mask_value",
        "save_residuals",
        "attn_logits_soft_cap",
        "is_mqa",
    ],
)
def attention_reference(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    mask: jax.Array,
    segment_ids: SegmentIds | None = None,
    sinks: jax.Array | None = None,
    *,
    is_mqa: bool,
    mask_value: float = DEFAULT_MASK_VALUE,
    save_residuals: bool = False,
    attn_logits_soft_cap: float | None = None,
):
  """A JIT-compiled reference implementation of attention; handles MQA and MHA."""
  attn_impl = functools.partial(
      _attention_reference_impl,
      mask_value=mask_value,
      save_residuals=save_residuals,
      attn_logits_soft_cap=attn_logits_soft_cap,
  )

  if is_mqa:
    func = jax.vmap(attn_impl, in_axes=(0, None, None, None, None, 0))
  else:
    # In grouped attention (1 < num_kv_heads < num_q_heads) we interleave the
    # KV heads across the Q heads. For example, for 8 Q heads and 4 KV heads:
    #   Q heads [0, 1] see KV head 0
    #   Q heads [2, 3] see KV head 1
    #   Q heads [4, 5] see KV head 2
    #   Q heads [6, 7] see KV head 3

    kv_heads, q_heads = k.shape[0], q.shape[0]
    assert q_heads % kv_heads == 0

    if kv_heads < q_heads:
      # Repeat K and V heads to match the number of Q heads.
      q_heads_per_kv = q_heads // kv_heads
      k = jnp.repeat(k, repeats=q_heads_per_kv, axis=0)
      v = jnp.repeat(v, repeats=q_heads_per_kv, axis=0)

    func = jax.vmap(attn_impl, in_axes=(0, 0, 0, None, None, 0))

  out = func(q, k, v, mask, segment_ids, sinks)
  return out
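A short usage sketch of the public wrapper with grouped-query shapes (not part of the commit; shapes and the per-head sink array are hypothetical):

# Illustrative sketch, not part of the commit: grouped-query call through the wrapper.
import jax.numpy as jnp

q = jnp.zeros((8, 128, 64))   # [num_q_heads, q_seq_len, head_dim]
k = jnp.zeros((4, 128, 64))   # [num_kv_heads, kv_seq_len, head_dim]
v = jnp.zeros((4, 128, 64))
mask = jnp.tril(jnp.ones((128, 128), dtype=bool))  # causal, shared across heads
sinks = jnp.zeros((8,))       # one sink logit per Q head (optional)

out = attention_reference(q, k, v, mask, sinks=sinks, is_mqa=False)
print(out.shape)              # (8, 128, 64)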


@functools.partial(
    jax.jit, static_argnames=["is_mqa", "backward_impl", "attn_logits_soft_cap"]
)
def attention_reference_vjp(
    do,
    q,
    k,
    v,
    mask,
    segment_ids,
    sinks,
    o,
    logsumexp,
    *,
    is_mqa: bool,
    backward_impl: str = "vanilla",
    attn_logits_soft_cap: float | None = None,
):
  """Wrapper for the backward reference; handles GQA/MQA broadcasting and reduction."""
  bwd = functools.partial(
      _attention_reference_custom_bwd,
      backward_impl=backward_impl,
      attn_logits_soft_cap=attn_logits_soft_cap,
  )

  num_q_heads = q.shape[0]
  num_kv_heads = 1 if is_mqa else k.shape[0]

  is_grouped = not is_mqa and num_kv_heads < num_q_heads
  assert num_q_heads % num_kv_heads == 0
  head_multiplier = num_q_heads // num_kv_heads
  if is_mqa:
    bwd = jax.vmap(bwd, in_axes=(0, 0, None, None, None, None, 0, 0, 0))
  else:
    bwd = jax.vmap(bwd, in_axes=(0, 0, 0, 0, None, None, 0, 0, 0))
  # Interleave the KV heads to match the corresponding Q heads.
  if is_grouped:
    k = jnp.repeat(k, head_multiplier, axis=0)
    v = jnp.repeat(v, head_multiplier, axis=0)

  dq, dk, dv, _, _, dsinks = bwd(
      do, q, k, v, mask, segment_ids, sinks, o, logsumexp
  )

  if is_mqa:
    dk, dv = dk.sum(axis=0), dv.sum(axis=0)
  elif is_grouped:
    # Perform the sum reduction across the head_multiplier dimension only, so
    # that the output still has num_kv_heads heads.
    dk = dk.reshape(num_kv_heads, head_multiplier, *dk.shape[1:])
    dv = dv.reshape(num_kv_heads, head_multiplier, *dv.shape[1:])
    dk, dv = dk.sum(axis=1), dv.sum(axis=1)

  return dq, dk, dv, dsinks
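And a sketch of how the forward and backward wrappers pair up (not part of the commit; same hypothetical GQA shapes as above, no segment ids or sinks):

# Illustrative sketch, not part of the commit: forward with residuals, then reference VJP.
import jax.numpy as jnp

q = jnp.zeros((8, 128, 64))
k = jnp.zeros((4, 128, 64))
v = jnp.zeros((4, 128, 64))
mask = jnp.tril(jnp.ones((128, 128), dtype=bool))
do = jnp.ones((8, 128, 64))

o, res = attention_reference(q, k, v, mask, is_mqa=False, save_residuals=True)
dq, dk, dv, dsinks = attention_reference_vjp(
    do, q, k, v, mask, None, None, o, res["logsumexp"], is_mqa=False
)
print(dq.shape, dk.shape, dv.shape)  # (8, 128, 64) (4, 128, 64) (4, 128, 64)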
Binary file not shown (170 KB).
