# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base functionality for Sparse Flash Attention."""

import functools
from typing import Final, NamedTuple, TypeAlias
import jax
import jax.numpy as jnp
import numpy as np
from . import splash_attention_mask_info as mask_info_lib


MaskInfo = mask_info_lib.MaskInfo


DEFAULT_MASK_VALUE: Final[float] = -0.7 * float(
    np.finfo(np.dtype("float32")).max
)


class SegmentIds(NamedTuple):
  """SegmentIds for the Q and KV sequences.

  Segment ids are a mechanism to ensure that there is no cross-attention
  between segments (fractions of a sequence) that have been concatenated
  together into one sequence. Each array is a list of ids (integers). Only
  tokens with the same id are allowed to attend to each other.

  The static mask (e.g. causal) is "and-ed" with the segment id mask to form
  the actual attention mask. It is important that the latter does not have any
  all-zero rows (along the kv dimension); otherwise it would result in an
  invalid softmax (the denominator would be 0). This condition holds for
  causal self-attention, because there the segment ids form a block-diagonal
  matrix, so at least one element in each row is set. It is easy to break this
  condition with non-self-attention configurations.

  Attributes:
    q: segment ids along the Q sequence
    kv: segment ids along the KV sequence
  """

  q: jax.Array | jax.sharding.PartitionSpec  # [q_seq_len]
  kv: jax.Array | jax.sharding.PartitionSpec  # [kv_seq_len]
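

# Illustrative example (a sketch; not used by the code in this module): two
# segments of lengths 2 and 3 concatenated into a single sequence of length 5.
# Only positions with matching ids may attend to each other, so the segment
# mask is block-diagonal:
#
#   ids = jnp.array([0, 0, 1, 1, 1])
#   segment_ids = SegmentIds(q=ids, kv=ids)
#   segment_mask = segment_ids.q[:, None] == segment_ids.kv[None, :]
#   # segment_mask is then "and-ed" with the static (e.g. causal) mask.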


# Return type of the SplashAttention function that implements the custom vjp
# rule.
SplashCustomReturnType: TypeAlias = (
    jax.Array | tuple[jax.Array, dict[str, jax.Array]]
)

SplashResidualsType = tuple[
    jax.Array,  # q
    jax.Array,  # k
    jax.Array,  # v
    SegmentIds | None,  # segment_ids
    jax.Array | None,  # sinks
    jax.Array,  # out
    jax.Array,  # logsumexp
    MaskInfo | None,  # dkv_mask_info
]


def _attention_reference_impl(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    mask: jax.Array,
    segment_ids: SegmentIds | None,
    sinks: jax.Array | None,
    mask_value: float,
    save_residuals: bool,
    attn_logits_soft_cap: float | None,
) -> SplashCustomReturnType:
  logits = jnp.einsum("sd,td->st", q.astype(jnp.float32), k.astype(jnp.float32))

  if segment_ids is not None:
    mask = jnp.logical_and(
        mask, segment_ids.q[:, None] == segment_ids.kv[None, :]
    )

  if attn_logits_soft_cap is not None:
    logits = jnp.tanh(logits / attn_logits_soft_cap)
    logits = logits * attn_logits_soft_cap

  if sinks is not None:
    assert sinks.shape == ()  # should already be vmapped
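
  # In equation form, the lines below compute a numerically stable softmax in
  # which the (optional) sink acts as one extra logit that contributes to the
  # normalizer but not to the output:
  #   m = max(max_t logits_t, sink)
  #   l = sum_t exp(logits_t - m) + exp(sink - m)
  #   p_t = exp(logits_t - m) / l
  #   logsumexp = m + log(l)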
  logits = jnp.where(mask, logits, mask_value)
  m = logits.max(axis=-1)
  sinks = None if sinks is None else sinks.astype(logits.dtype)
  m = m if sinks is None else jnp.maximum(m, sinks)
  s = jnp.exp(logits - m[..., None])
  l = s.sum(axis=-1) + (0 if sinks is None else jnp.exp(sinks - m))
  p = s / l[..., None]

  o = jnp.einsum("st,td->sd", p, v.astype(jnp.float32))

  if save_residuals:
    logsumexp = m + jnp.log(l)
    return o, {"logsumexp": logsumexp, "max_logits": m}
  return o


def _attention_reference_custom_bwd(
    do,
    q,
    k,
    v,
    mask,
    segment_ids,
    sinks,
    o,
    logsumexp,
    mask_value: float = DEFAULT_MASK_VALUE,
    backward_impl: str = "vanilla",
    attn_logits_soft_cap: float | None = None,
) -> tuple[jax.Array, jax.Array, jax.Array, None, None, jax.Array | None]:
  uncapped_logits = jnp.einsum(
      "qc,kc->qk", q, k, preferred_element_type=jnp.float32
  )

  if attn_logits_soft_cap is not None:
    logits = jnp.tanh(uncapped_logits / attn_logits_soft_cap)
    logits = logits * attn_logits_soft_cap
  else:
    logits = uncapped_logits

  if segment_ids is not None:
    mask = jnp.logical_and(
        mask, segment_ids.q[:, None] == segment_ids.kv[None, :]
    )
  logits = jnp.where(mask, logits, mask_value)

  p = jnp.exp(logits - logsumexp[..., None])
  do = do.astype(jnp.float32)  # pytype: disable=attribute-error
  dv = jnp.einsum("pt,pd->td", p, do).astype(v.dtype)
  dp = jnp.einsum("pd,td->pt", do, v.astype(jnp.float32))

  # These two ways of computing ds are mathematically equivalent. The first
  # involves reducing over the head_dim dimension and the second involves
  # reducing over a sequence dimension. They tend to produce slightly
  # different numerics.
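  # (The two reductions for `di` agree because o = P @ V and dp = do @ V^T, so
  #   sum_d o_d * do_d = sum_d (sum_t p_t * v_{t,d}) * do_d
  #                    = sum_t p_t * (sum_d v_{t,d} * do_d) = sum_t p_t * dp_t.)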
  if backward_impl == "flash":
    di = jnp.sum(o.astype(jnp.float32) * do, axis=-1)[..., None]
  else:
    di = jnp.einsum("st,st->s", dp, p)[:, None]
  ds = (dp - di) * p
  if attn_logits_soft_cap is not None:
    # Chain rule through the soft cap: d/dx [cap * tanh(x / cap)] = 1 - tanh^2,
    # and g + g * d == ds * (1 - d) * (1 + d) == ds * (1 - d**2).
    normalized = uncapped_logits / attn_logits_soft_cap
    d = jnp.tanh(normalized)
    g = ds * (1 - d)
    ds = g + g * d
  dk = jnp.einsum("sd,st->td", q.astype(jnp.float32), ds).astype(k.dtype)
  dq = jnp.einsum("st,td->sd", ds, k.astype(jnp.float32)).astype(q.dtype)
  dsinks = None
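  # The sink logit only enters the forward pass through the softmax
  # normalizer, so d o_{s,d} / d sink = -exp(sink - logsumexp_s) * o_{s,d};
  # summing do * that quantity over all positions gives the sink gradient.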
  if sinks is not None:
    sinks_exp = -jnp.exp(
        sinks[..., None, None].astype(jnp.float32)
        - logsumexp[..., None].astype(jnp.float32)
    )
    dsinks = jnp.sum(sinks_exp.astype(o.dtype) * o * do, axis=(-1, -2))
  return dq, dk, dv, None, None, dsinks


@functools.partial(
    jax.jit,
    static_argnames=[
        "mask_value",
        "save_residuals",
        "attn_logits_soft_cap",
        "is_mqa",
    ],
)
def attention_reference(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    mask: jax.Array,
    segment_ids: SegmentIds | None = None,
    sinks: jax.Array | None = None,
    *,
    is_mqa: bool,
    mask_value: float = DEFAULT_MASK_VALUE,
    save_residuals: bool = False,
    attn_logits_soft_cap: float | None = None,
):
  """A JIT-compiled reference implementation of attention that handles MQA and MHA."""
  attn_impl = functools.partial(
      _attention_reference_impl,
      mask_value=mask_value,
      save_residuals=save_residuals,
      attn_logits_soft_cap=attn_logits_soft_cap,
  )

  if is_mqa:
    func = jax.vmap(attn_impl, in_axes=(0, None, None, None, None, 0))
  else:
    # In grouped attention (1 < num_kv_heads < num_q_heads) we interleave the
    # KV heads across the Q heads. For example, with 8 Q heads and 4 KV heads:
    #   Q heads [0, 1] see KV head 0
    #   Q heads [2, 3] see KV head 1
    #   Q heads [4, 5] see KV head 2
    #   Q heads [6, 7] see KV head 3
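    # With jnp.repeat below this amounts to expanding, e.g.,
    #   [kv0, kv1, kv2, kv3] -> [kv0, kv0, kv1, kv1, kv2, kv2, kv3, kv3]
    # so that KV head i lines up with Q heads [2*i, 2*i + 1] under the
    # head-wise vmap.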

    kv_heads, q_heads = k.shape[0], q.shape[0]
    assert q_heads % kv_heads == 0

    if kv_heads < q_heads:
      # Repeat the K and V heads to match the number of Q heads.
      q_heads_per_kv = q_heads // kv_heads
      k = jnp.repeat(k, repeats=q_heads_per_kv, axis=0)
      v = jnp.repeat(v, repeats=q_heads_per_kv, axis=0)

    func = jax.vmap(attn_impl, in_axes=(0, 0, 0, None, None, 0))

  out = func(q, k, v, mask, segment_ids, sinks)
  return out
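

# Illustrative shapes for attention_reference (inferred from the vmap axes
# above; a sketch, not a documented contract):
#   q:     [num_q_heads, q_seq_len, head_dim]
#   k, v:  [kv_seq_len, head_dim] when is_mqa, else
#          [num_kv_heads, kv_seq_len, head_dim]
#   mask:  [q_seq_len, kv_seq_len] (boolean)
#   sinks: [num_q_heads] or None
#
#   out = attention_reference(q, k, v, mask, is_mqa=False)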


@functools.partial(
    jax.jit, static_argnames=["is_mqa", "backward_impl", "attn_logits_soft_cap"]
)
def attention_reference_vjp(
    do,
    q,
    k,
    v,
    mask,
    segment_ids,
    sinks,
    o,
    logsumexp,
    *,
    is_mqa: bool,
    backward_impl: str = "vanilla",
    attn_logits_soft_cap: float | None = None,
):
  """Wrapper for the backward reference that handles GQA/MQA broadcasting and reduction."""
  bwd = functools.partial(
      _attention_reference_custom_bwd,
      backward_impl=backward_impl,
      attn_logits_soft_cap=attn_logits_soft_cap,
  )

  num_q_heads = q.shape[0]
  num_kv_heads = 1 if is_mqa else k.shape[0]

  is_grouped = not is_mqa and num_kv_heads < num_q_heads
  assert num_q_heads % num_kv_heads == 0
  head_multiplier = num_q_heads // num_kv_heads
  if is_mqa:
    bwd = jax.vmap(bwd, in_axes=(0, 0, None, None, None, None, 0, 0, 0))
  else:
    bwd = jax.vmap(bwd, in_axes=(0, 0, 0, 0, None, None, 0, 0, 0))
    # Interleave the KV heads to match the corresponding Q heads.
    if is_grouped:
      k = jnp.repeat(k, head_multiplier, axis=0)
      v = jnp.repeat(v, head_multiplier, axis=0)

  dq, dk, dv, _, _, dsinks = bwd(
      do, q, k, v, mask, segment_ids, sinks, o, logsumexp
  )

  if is_mqa:
    dk, dv = dk.sum(axis=0), dv.sum(axis=0)
  elif is_grouped:
    # Sum only across the head_multiplier dimension, so that the output still
    # has num_kv_heads heads.
    dk = dk.reshape(num_kv_heads, head_multiplier, *dk.shape[1:])
    dv = dv.reshape(num_kv_heads, head_multiplier, *dv.shape[1:])
    dk, dv = dk.sum(axis=1), dv.sum(axis=1)

  return dq, dk, dv, dsinks