
Commit 342a72d

Merge pull request #2907 from AI-Hypercomputer:nicogrande/vllm-adapter-compute-logits-fix
PiperOrigin-RevId: 851431043
2 parents: bca71b4 + 733cbcd

2 files changed: 21 additions & 12 deletions

src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 16 additions & 12 deletions
@@ -19,6 +19,7 @@
 from flax import nnx
 import flax.linen as nn
+from jax import numpy as jnp
 from jax.sharding import Mesh
 from MaxText import model_creation_utils
 from MaxText import max_logging
@@ -136,16 +137,13 @@ def __call__(
     if not isinstance(self.model, nnx.Module):
       raise ValueError("Model must be an instance of type nnx.Module.")
 
-    if input_ids.ndim < 2:
-      input_ids = input_ids[None, :]
-
-    input_positions = attention_metadata.input_positions
-    if input_positions.ndim < 2:
-      input_positions = input_positions[None, :]
+    # Ensure inputs are at least 2D with a batch dimension
+    input_ids = jnp.atleast_2d(input_ids)
+    input_positions = jnp.atleast_2d(attention_metadata.input_positions)
 
     with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
       aux_hidden_states = []
-      hidden, kv_caches = self.model(
+      hidden, updated_kv_caches = self.model(
          decoder_input_tokens=input_ids,
          decoder_positions=input_positions,
          kv_caches=kv_caches,
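
Here jnp.atleast_2d reproduces the removed manual ndim checks in one call: a 1-D array gains a leading batch axis, while an already-batched 2-D array passes through unchanged. A minimal sketch of that behavior, using toy token ids:

import jax.numpy as jnp

# 1-D input gains a leading batch axis, matching the removed
# `if x.ndim < 2: x = x[None, :]` pattern.
ids_1d = jnp.array([5, 7, 11])
assert jnp.atleast_2d(ids_1d).shape == (1, 3)

# 2-D input is returned unchanged.
ids_2d = jnp.array([[5, 7, 11]])
assert jnp.atleast_2d(ids_2d).shape == (1, 3)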
@@ -154,10 +152,10 @@ def __call__(
          **kwargs,
      )
 
-    if hidden.ndim > 1:
-      hidden = hidden.squeeze(0)
+    # To be compatible with vLLM, we reshape to (batch * seq, dim).
+    hidden = hidden.reshape((-1, hidden.shape[-1]))
 
-    return kv_caches, hidden, aux_hidden_states
+    return updated_kv_caches, hidden, aux_hidden_states
 
   def forward(self, *args, **kwargs):
     """Alias for __call__ for compatibility.
@@ -211,8 +209,14 @@ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
       raise ValueError("Model is not initialized.")
 
     with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
-      embeddings = self.model.token_embedder
-      return self.model.decoder.apply_output_head(embeddings, hidden_states, True, self.model_mode)
+      # Reshape to (num_tokens, 1, hidden_dim) for decoder output head
+      y = hidden_states[:, jnp.newaxis, :]
+
+      # Compute logits using the MaxText decoder's output head
+      logits = self.model.decoder.apply_output_head(self.model.token_embedder, y, True, self.model_mode)
+
+      # Reshape back to (num_tokens, vocab_size)
+      return logits.squeeze(1)
 
   def load_weights(self, rng_key: jax.Array) -> None:
     """Loads model weights using the underlying decoder model.

src/MaxText/layers/decoders.py

Lines changed: 5 additions & 0 deletions
@@ -902,6 +902,11 @@ def __call__(
     # After the final transformer layer, `y` holds the raw, un-normalized hidden state.
     hidden_state = y
 
+    # When initializing with vLLM RPA attention, we need to run the output head to
+    # initialize any parameters associated with it.
+    if self.is_initializing() and cfg.attention == "vllm_rpa":
+      _ = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)
+
     # When invoking from vLLM with RPA attention, logit computation is deferred to a later stage.
     if cfg.attention == "vllm_rpa":
       logits = None
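
This guard relies on a standard Flax Linen behavior: parameters are created lazily the first time a submodule is called, so a branch that is skipped at runtime (here, the deferred logit computation) must still be exercised under init() or its parameters never come into existence. A toy illustration of the pattern, with a hypothetical module rather than MaxText's decoder:

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyDecoder(nn.Module):
  @nn.compact
  def __call__(self, x, defer_logits: bool):
    head = nn.Dense(16, name="output_head")
    # Touch the head at init time even when logits are deferred,
    # mirroring the is_initializing() guard above.
    if self.is_initializing() and defer_logits:
      _ = head(x)
    return None if defer_logits else head(x)

variables = TinyDecoder().init(jax.random.PRNGKey(0), jnp.zeros((1, 8)), True)
assert "output_head" in variables["params"]  # head params exist despite deferral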
