Commit 4b722f7

Merge pull request #2971 from AI-Hypercomputer:rbierneni-qwen3-next-caching
PiperOrigin-RevId: 879149561
2 parents: 72e96f5 + ecd0f7a

5 files changed: 296 additions & 29 deletions

src/maxtext/inference/kvcache.py

Lines changed: 76 additions & 1 deletion
@@ -230,7 +230,11 @@ def kv_cache_as_linen(
   )
 
 
-class KVCache(nnx.Module):
+class BaseCache(nnx.Module):
+  """Abstract base class for Caches."""
+
+
+class KVCache(BaseCache):
   """Implementation of the KVCache."""
 
   def __init__(

@@ -290,6 +294,7 @@ def __init__(
       use_chunked_prefill: Whether to use chunked prefill.
       rngs: The random number generators for initialization.
     """
+    super().__init__()
     self.max_prefill_length = max_prefill_length
     self.max_target_length = max_target_length
     self.batch = batch
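The new BaseCache gives the sequence-shaped KV cache and the fixed-size linear-attention cache introduced below a common ancestor, so engine code can hold mixed per-layer caches in one structure. A minimal sketch of the idea, using hypothetical toy classes (not the MaxText ones) alongside the real flax.nnx API:

import jax.numpy as jnp
from flax import nnx

class BaseCache(nnx.Module):
  """Common ancestor, mirroring the diff above."""

class ToyKVCache(BaseCache):  # hypothetical stand-in for KVCache
  def __init__(self, max_len: int, dim: int):
    self.key = nnx.Cache(jnp.zeros((max_len, dim)))  # shape grows with sequence length

class ToyRecurrentCache(BaseCache):  # hypothetical stand-in for GatedDeltaNetCache
  def __init__(self, dim: int):
    self.state = nnx.Cache(jnp.zeros((dim, dim)))  # fixed size, independent of length

layer_caches = [ToyKVCache(1024, 8), ToyRecurrentCache(8)]
print(all(isinstance(c, BaseCache) for c in layer_caches))  # True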
@@ -844,6 +849,76 @@ def __call__(
       raise ValueError(f"Model Mode isn't supported! {model_mode=}")
 
 
+class GatedDeltaNetCache(BaseCache):
+  """Cache for Linear Attention (Gated Delta Net).
+
+  Stores the fixed-size recurrent state and the sliding window state for convolution.
+  """
+
+  def __init__(
+      self,
+      batch: int,
+      num_heads: int,
+      k_head_dim: int,
+      v_head_dim: int,
+      conv_kernel_size: int,
+      conv_dim: int,
+      dtype: DType,
+      cache_batch_axis_name: str = CACHE_BATCH,
+      cache_heads_axis_name: str = CACHE_HEADS,
+  ):
+    super().__init__()
+    self.batch = batch
+    self.dtype = dtype
+
+    # 1. Recurrent State (S) for the Delta Rule
+    # Shape: [Batch, Heads, K_Dim, V_Dim]
+    # We maintain the running state matrix.
+    self.recurrent_state = nnx.Cache(
+        jnp.zeros((int(batch), num_heads, k_head_dim, v_head_dim), dtype=dtype),
+        # Sharding: Batch, Heads, None (K), None (V)
+        sharding=(cache_batch_axis_name, cache_heads_axis_name, None, None),
+    )
+
+    # 2. Convolution State for the 1D Conv
+    # Shape: [Batch, Kernel_Size - 1, Conv_Dim]
+    # We store the last (K-1) inputs to perform the sliding window conv during decoding.
+    self.conv_state = nnx.Cache(
+        jnp.zeros((int(batch), conv_kernel_size - 1, conv_dim), dtype=dtype),
+        # Sharding: Batch, None (Time), None (Dim)
+        sharding=(cache_batch_axis_name, None, None),
+    )
+
+  def __call__(self):
+    """Returns the cache variables for the layer to use."""
+    return self
+
+
+def gated_delta_net_cache_as_linen(
+    *,
+    batch: int,
+    num_heads: int,
+    head_dim: int,
+    conv_kernel_size: int,
+    conv_dim: int,
+    dtype: DType,
+    name: str | None = None,
+):
+  """Initializes the GatedDeltaNetCache and returns it as a Linen module."""
+  return nnx_wrappers.to_linen(
+      GatedDeltaNetCache,
+      batch=batch,
+      num_heads=num_heads,
+      head_dim=head_dim,
+      conv_kernel_size=conv_kernel_size,
+      conv_dim=conv_dim,
+      dtype=dtype,
+      metadata_fn=variable_to_logically_partitioned,
+      name=name,
+      abstract_init=False,
+  )
+
+
 def mla_kv_cache_as_linen(
     *,
     max_prefill_length: int,
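To make the two buffers concrete: during decode, each new token folds into the running state matrix via the gated delta rule while the convolution window shifts by one. Below is a minimal single-token sketch, assuming per-head gates alpha (decay) and beta (write strength) and single-batch shapes; it illustrates the recurrence, not the MaxText kernel (where the convolution window also feeds the Q/K/V projections):

import jax.numpy as jnp

def gated_delta_step(recurrent_state, conv_state, q, k, v, x_t, alpha, beta):
  """One-token decode step for a single batch element.

  recurrent_state: [heads, k_dim, v_dim]; conv_state: [kernel-1, conv_dim]
  q, k: [heads, k_dim]; v: [heads, v_dim]; x_t: [conv_dim]; alpha, beta: [heads]
  """
  # Roll the convolution window: append the newest input, drop the oldest.
  window = jnp.concatenate([conv_state, x_t[None, :]], axis=0)  # [kernel, conv_dim]
  new_conv_state = window[1:]                                   # last (kernel-1) inputs

  # Gated delta rule: decay the state, then write the prediction error.
  pred = jnp.einsum("hkv,hk->hv", recurrent_state, k)           # S^T k, what S predicts for v
  new_state = alpha[:, None, None] * recurrent_state + beta[:, None, None] * jnp.einsum(
      "hk,hv->hkv", k, v - pred                                 # outer-product write
  )
  out = jnp.einsum("hkv,hk->hv", new_state, q)                  # read the state with the query
  return out, new_state, new_conv_state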

src/maxtext/inference/maxengine/maxengine.py

Lines changed: 23 additions & 1 deletion
@@ -1146,6 +1146,9 @@ def copy(path, partial_cache, full_cache, annotations):
         "cached_prefill_value_scale",
     ]:
       full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
+    elif path_key in ["recurrent_state", "conv_state"]:
+      # Direct update for fixed-size linear attention states
+      full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
     else:
       raise ValueError(f"We don't have a strategy for inserting {path_key}")
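The branch works because these states have a fixed shape regardless of prompt length, so inserting a finished prefill into the batched decode state is a whole-slot overwrite. A toy run with made-up shapes:

import jax
import jax.numpy as jnp

full_state = jnp.zeros((8, 4, 16, 16))    # [slots, heads, k_dim, v_dim], decode-time state
partial_state = jnp.ones((1, 4, 16, 16))  # state produced by prefilling one prompt
slot, batch_idx = 3, 0                    # write into slot 3 along the batch axis
full_state = jax.lax.dynamic_update_index_in_dim(full_state, partial_state, slot, batch_idx)
print(full_state[3].min(), full_state[2].max())  # 1.0 0.0 -- only slot 3 was written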

@@ -1258,6 +1261,10 @@ def copy(path, partial_cache, full_cache, annotations):
         "cached_prefill_value_scale",
     ]:
       return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
+    elif path_key in ["recurrent_state", "conv_state"]:
+      # For linear attention, the state is fixed size. We simply copy the result
+      # from the prefill step (partial_cache) into the decode state (full_cache).
+      return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
     else:
       raise ValueError(f"We don't have a strategy for inserting {path_key}")

@@ -1447,6 +1454,15 @@ def copy(path, partial_cache, full_cache, annotations):
         partial_cache = jax.lax.dynamic_slice(partial_cache, start_indices, slice_size)
 
       return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
+    elif path_key in ["recurrent_state", "conv_state"]:
+      # SSM states are the "final state" after prefill, so we just overwrite the slot.
+      # We don't need to slice by sequence length like we do for KV cache.
+      if num_prompts > 1:
+        raise NotImplementedError(
+            "Packed prefill is currently incompatible with linear attention states (GDN). "
+            "Prompt memory will bleed into adjacent prompts. Please disable packed prefill."
+        )
+      return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
     else:
       raise ValueError(f"We don't have a strategy for inserting {path_key}")
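The guard exists because a recurrent state, unlike a per-token KV cache, cannot be un-packed after the fact. A toy contrast, using a plain additive state as a stand-in for the gated delta recurrence:

import jax.numpy as jnp

prompt_a = jnp.array([1.0, 2.0])      # tokens of prompt A
prompt_b = jnp.array([10.0, 20.0])    # tokens of prompt B, packed after A
packed = jnp.concatenate([prompt_a, prompt_b])

kv_cache = packed                      # per-token cache: prompt A is recoverable
print(kv_cache[:2])                    # [1. 2.] -- slice out prompt A's span

state = packed.sum()                   # fixed-size state: folds in the whole pack
print(state)                           # 33.0 -- prompt B has bled into the state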

@@ -1660,7 +1676,13 @@ def initialize():
     def is_lp(k):
       return isinstance(k, flax.linen.spmd.LogicallyPartitioned)
 
-    self.kv_cache_annotations_named = jax.tree_util.tree_map(lambda x: tuple(x.names), cache, is_leaf=is_lp)
+    self.kv_cache_annotations_named = jax.tree_util.tree_map(
+        lambda x: tuple(x.logical_axes)
+        if hasattr(x, "logical_axes")
+        else (tuple(x.names) if hasattr(x, "names") else ()),
+        cache,
+        is_leaf=is_lp,
+    )
     zeroed = max_utils.unbox_logicallypartioned(init_state)
     return zeroed
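The widened lambda has to cope with leaves that expose logical_axes (nnx-style metadata), names (Linen's LogicallyPartitioned), or neither. A toy pytree walk, with a hypothetical Box class standing in for the annotated leaves, shows the dispatch:

import jax

class Box:  # hypothetical stand-in for an annotated cache leaf
  def __init__(self, names=None, logical_axes=None):
    if names is not None:
      self.names = names
    if logical_axes is not None:
      self.logical_axes = logical_axes

def axes_of(x):
  if hasattr(x, "logical_axes"):
    return tuple(x.logical_axes)
  return tuple(x.names) if hasattr(x, "names") else ()

cache = {"kv": Box(names=("cache_batch", "cache_heads")),
         "gdn": Box(logical_axes=("cache_batch", None, None)),
         "plain": Box()}
print(jax.tree_util.tree_map(axes_of, cache, is_leaf=lambda k: isinstance(k, Box)))
# {'gdn': ('cache_batch', None, None), 'kv': ('cache_batch', 'cache_heads'), 'plain': ()}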

src/maxtext/layers/decoders.py

Lines changed: 15 additions & 4 deletions
@@ -980,15 +980,22 @@ def __call__(
       }
       if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT:
         layer_kwargs = {"layer_idx": lyr}
+      kv_cache = None
+      if kv_caches is not None and cfg.decoder_block != DecoderBlockType.QWEN3_NEXT:
+        kv_cache = kv_caches[lyr]
+      elif kv_caches is not None and cfg.decoder_block == DecoderBlockType.QWEN3_NEXT:
+        # For Qwen3Next, kv_caches is a dictionary of lists of caches.
+        if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
+          kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr])
+
       if cfg.decoder_block == DecoderBlockType.GPT_OSS:
         layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)}
       if cfg.decoder_block == DecoderBlockType.OLMO3:
         layer_kwargs = {"attention_type": olmo3.get_attention_type(layer_id=lyr)}
       layer = RemattedBlockLayer(
           config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs
       )
-      kv_cache = kv_caches[lyr] if kv_caches is not None else None
-      y, kv_cache = layer(
+      y, returned_cache = layer(
           y,
           decoder_segment_ids,
           decoder_positions,
@@ -1001,8 +1008,12 @@ def __call__(
           attention_metadata=attention_metadata,
           **layer_call_kwargs,
       )
-      if kv_caches is not None and kv_cache is not None:
-        kv_caches[lyr] = kv_cache
+      if kv_caches is not None and returned_cache is not None:
+        if cfg.decoder_block != DecoderBlockType.QWEN3_NEXT:
+          kv_caches[lyr] = returned_cache
+        elif (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
+          kv_caches["key_cache"][lyr] = returned_cache[0]
+          kv_caches["value_cache"][lyr] = returned_cache[1]
 
       if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds):
         visual_embeds = deepstack_visual_embeds[lyr]
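For reference, the cycle test (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0 picks out the full-attention layers that own entries in key_cache/value_cache; the interleaved Gated DeltaNet layers carry their state in recurrent_state/conv_state instead. A hypothetical interval of 4 (a 3:1 linear-to-full layout):

interval = 4  # stand-in for cfg.inhomogeneous_layer_cycle_interval
for lyr in range(8):
  is_full = (lyr + 1) % interval == 0
  print(f"layer {lyr}: {'full attention (KV cache)' if is_full else 'linear attention (GDN state)'}")
# layers 3 and 7 own KV-cache entries; layers 0-2 and 4-6 use GDN states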
