@@ -1146,6 +1146,9 @@ def copy(path, partial_cache, full_cache, annotations):
11461146 "cached_prefill_value_scale" ,
11471147 ]:
11481148 full_cache = jax .lax .dynamic_update_index_in_dim (full_cache , partial_cache , slot , batch_idx )
1149+ elif path_key in ["recurrent_state" , "conv_state" ]:
1150+ # Direct update for fixed-size linear attention states
1151+ full_cache = jax .lax .dynamic_update_index_in_dim (full_cache , partial_cache , slot , batch_idx )
11491152 else :
11501153 raise ValueError (f"We don't have a strategy for inserting { path_key } " )
11511154
@@ -1258,6 +1261,10 @@ def copy(path, partial_cache, full_cache, annotations):
12581261 "cached_prefill_value_scale" ,
12591262 ]:
12601263 return jax .lax .dynamic_update_index_in_dim (full_cache , partial_cache , slot , batch_idx )
1264+ elif path_key in ["recurrent_state" , "conv_state" ]:
1265+ # For linear attention, the state is fixed size. We simply copy the result
1266+ # from the prefill step (partial_cache) into the decode state (full_cache).
1267+ return jax .lax .dynamic_update_index_in_dim (full_cache , partial_cache , slot , batch_idx )
12611268 else :
12621269 raise ValueError (f"We don't have a strategy for inserting { path_key } " )
12631270
@@ -1447,6 +1454,15 @@ def copy(path, partial_cache, full_cache, annotations):
14471454 partial_cache = jax .lax .dynamic_slice (partial_cache , start_indices , slice_size )
14481455
14491456 return jax .lax .dynamic_update_index_in_dim (full_cache , partial_cache , slot , batch_idx )
1457+ elif path_key in ["recurrent_state" , "conv_state" ]:
1458+ # SSM states are the "final state" after prefill, so we just overwrite the slot.
1459+ # We don't need to slice by sequence length like we do for KV cache.
1460+ if num_prompts > 1 :
1461+ raise NotImplementedError (
1462+ "Packed prefill is currently incompatible with linear attention states (GDN). "
1463+ "Prompt memory will bleed into adjacent prompts. Please disable packed prefill."
1464+ )
1465+ return jax .lax .dynamic_update_index_in_dim (full_cache , partial_cache , slot , batch_idx )
14501466 else :
14511467 raise ValueError (f"We don't have a strategy for inserting { path_key } " )
14521468
@@ -1660,7 +1676,13 @@ def initialize():
16601676 def is_lp (k ):
16611677 return isinstance (k , flax .linen .spmd .LogicallyPartitioned )
16621678
1663- self .kv_cache_annotations_named = jax .tree_util .tree_map (lambda x : tuple (x .names ), cache , is_leaf = is_lp )
1679+ self .kv_cache_annotations_named = jax .tree_util .tree_map (
1680+ lambda x : tuple (x .logical_axes )
1681+ if hasattr (x , "logical_axes" )
1682+ else (tuple (x .names ) if hasattr (x , "names" ) else ()),
1683+ cache ,
1684+ is_leaf = is_lp ,
1685+ )
16641686 zeroed = max_utils .unbox_logicallypartioned (init_state )
16651687 return zeroed
16661688
0 commit comments