1919
2020from flax import nnx
2121import flax .linen as nn
22+ from jax import numpy as jnp
2223from jax .sharding import Mesh
2324from MaxText import model_creation_utils
2425from MaxText import max_logging
@@ -136,16 +137,13 @@ def __call__(
136137 if not isinstance (self .model , nnx .Module ):
137138 raise ValueError ("Model must be an instance of type nnx.Module." )
138139
139- if input_ids .ndim < 2 :
140- input_ids = input_ids [None , :]
141-
142- input_positions = attention_metadata .input_positions
143- if input_positions .ndim < 2 :
144- input_positions = input_positions [None , :]
140+ # Ensure inputs are at least 2D with a batch dimension
141+ input_ids = jnp .atleast_2d (input_ids )
142+ input_positions = jnp .atleast_2d (attention_metadata .input_positions )
145143
146144 with self .mesh , nn .logical_axis_rules (self .maxtext_config .logical_axis_rules ):
147145 aux_hidden_states = []
148- hidden , kv_caches = self .model (
146+ hidden , updated_kv_caches = self .model (
149147 decoder_input_tokens = input_ids ,
150148 decoder_positions = input_positions ,
151149 kv_caches = kv_caches ,
@@ -154,10 +152,10 @@ def __call__(
154152 ** kwargs ,
155153 )
156154
157- if hidden . ndim > 1 :
158- hidden = hidden .squeeze ( 0 )
155+ # To be compatible with vLLM, we reshape to (batch * seq, dim).
156+ hidden = hidden .reshape (( - 1 , hidden . shape [ - 1 ]) )
159157
160- return kv_caches , hidden , aux_hidden_states
158+ return updated_kv_caches , hidden , aux_hidden_states
161159
162160 def forward (self , * args , ** kwargs ):
163161 """Alias for __call__ for compatibility.
@@ -211,8 +209,14 @@ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
211209 raise ValueError ("Model is not initialized." )
212210
213211 with self .mesh , nn .logical_axis_rules (self .maxtext_config .logical_axis_rules ):
214- embeddings = self .model .token_embedder
215- return self .model .decoder .apply_output_head (embeddings , hidden_states , True , self .model_mode )
212+ # Reshape to (num_tokens, 1, hidden_dim) for decoder output head
213+ y = hidden_states [:, jnp .newaxis , :]
214+
215+ # Compute logits using the MaxText decoder's output head
216+ logits = self .model .decoder .apply_output_head (self .model .token_embedder , y , True , self .model_mode )
217+
218+ # Reshape back to (num_tokens, vocab_size)
219+ return logits .squeeze (1 )
216220
217221 def load_weights (self , rng_key : jax .Array ) -> None :
218222 """Loads model weights using the underlying decoder model.
0 commit comments