
Commit 05a4a53

Merge pull request #2923 from AI-Hypercomputer:nicogrande/enable-gpt-oss-attention-vllm
PiperOrigin-RevId: 855388150
2 parents f27ac67 + e6976ba commit 05a4a53

1 file changed

Lines changed: 7 additions & 5 deletions

src/MaxText/layers/attentions.py
@@ -916,17 +916,19 @@ def forward_serve_vllm(
           "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
       ) from e
 
-    if self.config.attention_sink:
-      raise NotImplementedError("Attention sink is not supported in MaxText vLLM RPA attention.")
     if rpa_kv_cache is None or rpa_metadata is None:
       raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")
 
     query = query.reshape(-1, query.shape[2], query.shape[3])
     key = key.reshape(-1, key.shape[2], key.shape[3])
     value = value.reshape(-1, value.shape[2], value.shape[3])
 
-    attention_chunk_size = self.config.chunk_attn_window_size if self.config.chunk_attn_window_size > 0 else None
+    if self.config.sliding_window_size > 0:
+      attention_chunk_size = self.config.sliding_window_size
+    else:
+      # Chunked attention currently not used in vLLM RPA.
+      attention_chunk_size = None
 
     q_scale, k_scale, v_scale = None, None, None
 
     md = rpa_metadata
@@ -941,7 +943,7 @@ def forward_serve_vllm(
       md.block_tables,
       md.query_start_loc,
       md.request_distribution,
-      None,
+      self.sinks.astype(jnp.float32) if self.sinks is not None else None,
       1.0,
       attention_chunk_size,
       q_scale,
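
For context on the new sink argument passed to the RPA kernel: GPT-OSS style attention sinks are learned per-head logits that join the softmax normalization without contributing a value vector, so a head can route probability mass to "nothing" instead of forcing it onto real tokens. The sketch below is a minimal, dense, single-request illustration of that idea; the helper name attention_with_sinks and its shapes are assumptions made for this example, not the paged vllm-tpu kernel invoked in the diff.

import jax
import jax.numpy as jnp

def attention_with_sinks(q, k, v, sinks=None):
  """Toy dense attention with optional per-head sink logits.

  q: [num_heads, q_len, head_dim]; k, v: [num_heads, kv_len, head_dim];
  sinks: [num_heads] learned logits, or None to recover plain attention.
  """
  scale = q.shape[-1] ** -0.5
  scores = jnp.einsum("hqd,hkd->hqk", q, k) * scale  # [heads, q_len, kv_len]
  if sinks is not None:
    # Append the sink logit as one extra "token" column per head.
    sink_col = jnp.broadcast_to(
        sinks.astype(jnp.float32)[:, None, None], (*scores.shape[:2], 1)
    )
    scores = jnp.concatenate([scores, sink_col], axis=-1)
  probs = jax.nn.softmax(scores, axis=-1)
  if sinks is not None:
    # The sink column only absorbs normalization mass; it has no value vector.
    probs = probs[..., :-1]
  return jnp.einsum("hqk,hkd->hqd", probs, v)

# Tiny usage example.
rng = jax.random.PRNGKey(0)
q = jax.random.normal(rng, (2, 3, 4))
k = jax.random.normal(rng, (2, 5, 4))
v = jax.random.normal(rng, (2, 5, 4))
print(attention_with_sinks(q, k, v, sinks=jnp.zeros(2)).shape)  # (2, 3, 4)

The float32 cast mirrors the diff's self.sinks.astype(jnp.float32); with sinks present, each head's attention weights over real tokens sum to less than one, which is the intended effect.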
