
Commit e6976ba

Adding support for attention sinks in vLLM.

1 parent c32eb92

1 file changed: src/MaxText/layers/attentions.py (7 additions, 5 deletions)
```diff
@@ -915,17 +915,19 @@ def forward_serve_vllm(
          "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
      ) from e
 
-    if self.config.attention_sink:
-      raise NotImplementedError("Attention sink is not supported in MaxText vLLM RPA attention.")
-
     if rpa_kv_cache is None or rpa_metadata is None:
       raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
 
     query = query.reshape(-1, query.shape[2], query.shape[3])
     key = key.reshape(-1, key.shape[2], key.shape[3])
     value = value.reshape(-1, value.shape[2], value.shape[3])
 
-    attention_chunk_size = self.config.chunk_attn_window_size if self.config.chunk_attn_window_size > 0 else None
+    if self.config.sliding_window_size > 0:
+      attention_chunk_size = self.config.sliding_window_size
+    else:
+      # Chunked attention currently not used in vLLM RPA.
+      attention_chunk_size = None
+
     q_scale, k_scale, v_scale = None, None, None
 
     md = rpa_metadata
```
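The first hunk drops the old `NotImplementedError` guard for attention sinks and changes which config knob drives `attention_chunk_size`. Condensed, the before/after selection logic looks like this; a minimal sketch where `config` stands in for `self.config`, not a verbatim excerpt from attentions.py:

```python
# Sketch of the chunk-size selection before and after this commit.
def chunk_size_before(config):
  # Old behavior: driven by the chunked-attention window.
  return config.chunk_attn_window_size if config.chunk_attn_window_size > 0 else None

def chunk_size_after(config):
  # New behavior: driven by the sliding window. Per the in-diff comment,
  # chunked attention itself is currently not used in the vLLM RPA path.
  return config.sliding_window_size if config.sliding_window_size > 0 else None
```

Sliding-window models now forward their window size through the kernel's chunk parameter; models without a sliding window pass `None`.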
```diff
@@ -940,7 +942,7 @@ def forward_serve_vllm(
         md.block_tables,
         md.query_start_loc,
         md.request_distribution,
-        None,
+        self.sinks.astype(jnp.float32) if self.sinks is not None else None,
         1.0,
         attention_chunk_size,
         q_scale,
```
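The second hunk forwards `self.sinks` (cast to float32) into the RPA kernel call where `None` was previously hard-coded. The kernel itself lives in vllm-tpu, but the attention-sink mechanism it enables can be sketched in a few lines of JAX: a learned per-head logit joins the softmax normalization and soaks up probability mass without contributing a value vector. The function name and shapes below are illustrative assumptions, not the vllm-tpu API:

```python
import jax
import jax.numpy as jnp

def attention_with_sinks(q, k, v, sinks):
  """Sketch only (causal masking omitted).

  q: [T, H, D]; k, v: [S, H, D]; sinks: [H] learned per-head logits.
  """
  logits = jnp.einsum("thd,shd->hts", q, k) / jnp.sqrt(q.shape[-1])
  # One extra "sink" column per head joins the softmax normalization...
  sink_col = jnp.broadcast_to(sinks[:, None, None], (*logits.shape[:2], 1))
  probs = jax.nn.softmax(jnp.concatenate([logits, sink_col], axis=-1), axis=-1)
  # ...but carries no value vector, so it is dropped after normalization;
  # each attention row then sums to at most 1.
  return jnp.einsum("hts,shd->thd", probs[..., :-1], v)

# Tiny smoke test: 4 query tokens, 2 heads, head dim 8, 6 KV tokens.
q = jnp.ones((4, 2, 8))
k = v = jnp.ones((6, 2, 8))
out = attention_with_sinks(q, k, v, sinks=jnp.zeros(2))  # shape (4, 2, 8)
```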
