1414
1515"""vLLM adapter for MaxText models."""
1616
17- import jax
18- import jax .numpy as jnp
1917import os
18+ import jax
2019
2120from flax import nnx
2221import flax .linen as nn
@@ -72,37 +71,38 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
7271 return maxtext_config
7372
7473
75- class MaxTextDecoderModel (nnx .Module ):
76- """A vLLM-compatible decoder model wrapper for MaxText.
74+ class MaxTextForCausalLM (nnx .Module ):
75+ """A vLLM-compatible causal language model wrapper for MaxText.
7776
78- This class adapts a MaxText model for use within the vLLM framework,
79- handling configuration generation, model initialization, and execution
77+ This class serves as the primary interface for integrating MaxText models
78+ into the vLLM serving framework, specifically for causal language modeling
79+ tasks. It handles configuration generation, model initialization, and execution
8080 of the decoding step.
8181 """
8282
83- def __init__ (self , vllm_config : VllmConfig , rng_key : jax .Array , mesh : Mesh ) -> None :
84- """Initializes the MaxTextDecoderModel .
83+ def __init__ (self , vllm_config : VllmConfig , rng_key : jax .Array , mesh : Mesh ):
84+ """Initializes the MaxTextForCausalLM model .
8585
8686 Args:
8787 vllm_config: The vLLM configuration object.
8888 rng_key: A JAX random key for model initialization.
8989 mesh: The JAX mesh device for model sharding.
9090 """
9191 self .vllm_config = vllm_config
92+ self .cfg = vllm_config .model_config
9293 self .maxtext_config = generate_maxtext_config (vllm_config )
9394
9495 # Model configuration
9596 self .mesh = mesh
9697 self .model_mode = MODEL_MODE_AUTOREGRESSIVE
98+ self .is_text_generation_model = True
9799
98100 # Model creation
99101 self .model : nnx .Module | None = None
100- self .logits : jax .Array | None = None
101102
102103 # Handle dummy weight loading during initialization
103104 if vllm_config .load_config .load_format == "dummy" :
104- with self .mesh :
105- self .load_weights (rng_key )
105+ self .load_weights (rng_key )
106106
107107 elif self .maxtext_config .load_parameters_path is None :
108108 max_logging .log ("Warning: No load_parameters_path provided. The model will be initialized with random weights." )
@@ -115,7 +115,7 @@ def __call__(
115115 * args ,
116116 ** kwargs ,
117117 ) -> tuple [list [jax .Array ], jax .Array , list [jax .Array ]]:
118- """Performs a forward pass through the decoder model.
118+ """Performs a forward pass through the causal language model.
119119
120120 Args:
121121 kv_caches: A list of JAX arrays representing the KV caches.
@@ -127,7 +127,7 @@ def __call__(
127127 Returns:
128128 A tuple containing:
129129 - updated_kv_caches: A list of updated KV caches.
130- - hidden: The hidden states (Q, d_model) .
130+ - hidden: The hidden states.
131131 - aux_hidden_states: A list of auxiliary hidden states.
132132
133133 Raises:
@@ -137,15 +137,15 @@ def __call__(
137137 raise ValueError ("Model must be an instance of type nnx.Module." )
138138
139139 if input_ids .ndim < 2 :
140- input_ids = jnp . expand_dims ( input_ids , axis = 0 )
140+ input_ids = input_ids [ None , :]
141141
142142 input_positions = attention_metadata .input_positions
143143 if input_positions .ndim < 2 :
144- input_positions = jnp . expand_dims ( input_positions , axis = 0 )
144+ input_positions = input_positions [ None , :]
145145
146- with nn .logical_axis_rules (self .maxtext_config .logical_axis_rules ):
146+ with self . mesh , nn .logical_axis_rules (self .maxtext_config .logical_axis_rules ):
147147 aux_hidden_states = []
148- logits , hidden , kv_caches = self .model (
148+ hidden , kv_caches = self .model (
149149 decoder_input_tokens = input_ids ,
150150 decoder_positions = input_positions ,
151151 kv_caches = kv_caches ,
@@ -154,88 +154,9 @@ def __call__(
154154 ** kwargs ,
155155 )
156156
157- if hidden .ndim > 1 :
158- hidden = jnp .squeeze (hidden , axis = 0 )
159- logits = jnp .squeeze (logits , axis = 0 )
160-
161- self .logits = nnx .data (logits ) # cache logits for compute_logits call
162-
163- return kv_caches , hidden , aux_hidden_states
164-
165- def compute_logits (self , hidden_states : jax .Array ) -> jax .Array :
166- """Computes the logits from the hidden states.
167-
168- Args:
169- hidden_states: A JAX array of hidden states.
170-
171- Returns:
172- A JAX array of logits (Q, vocab_size).
173- """
174- if self .logits is not None :
175- return self .logits
176-
177- with nn .logical_axis_rules (self .maxtext_config .logical_axis_rules ):
178- embeddings = self .model .token_embedder
179- return self .model .decoder .apply_output_head (embeddings , hidden_states , True , self .model_mode )
180-
181- def load_weights (self , rng_key : jax .Array ) -> None :
182- """Loads model parameters on the provided mesh.
183-
184- Args:
185- rng_key: A JAX random key for model initialization.
186- """
187- if self .model is not None :
188- return
189-
190- with nn .logical_axis_rules ("" ):
191- model , _ = model_creation_utils .create_nnx_model (
192- self .maxtext_config , mesh = self .mesh , model_mode = self .model_mode , rng_key = rng_key
193- )
194- self .model = nnx .data (model )
195-
196-
197- class MaxTextForCausalLM (nnx .Module ):
198- """A vLLM-compatible causal language model wrapper for MaxText.
199-
200- This class serves as the primary interface for integrating MaxText models
201- into the vLLM serving framework, specifically for causal language modeling
202- tasks. It wraps the `MaxTextDecoderModel` and exposes methods expected
203- by vLLM.
204- """
205-
206- def __init__ (self , vllm_config : VllmConfig , rng_key : jax .Array , mesh : Mesh ):
207- """Initializes the MaxTextForCausalLM model.
157+ if hidden .ndim > 1 :
158+ hidden = hidden .squeeze (0 )
208159
209- Args:
210- vllm_config: The vLLM configuration object.
211- rng_key: A JAX random key for model initialization.
212- mesh: The JAX mesh device for model sharding.
213- """
214- self .cfg = vllm_config .model_config
215- self .mesh = mesh
216- self .model = MaxTextDecoderModel (vllm_config , rng_key , mesh )
217- self .is_text_generation_model = True
218-
219- def __call__ (
220- self , kv_caches : list [jax .Array ], input_ids : jax .Array , attention_metadata : AttentionMetadata , * args , ** kwargs
221- ) -> tuple [list [jax .Array ], jax .Array ]:
222- """Performs a forward pass through the causal language model.
223-
224- Args:
225- kv_caches: A list of JAX arrays representing the KV caches.
226- input_ids: A JAX array of input token IDs.
227- attention_metadata: Attention metadata for the decoding process.
228- *args: Variable length argument list.
229- **kwargs: Arbitrary keyword arguments.
230-
231- Returns:
232- A tuple containing:
233- - updated_kv_caches: A list of updated KV caches.
234- - hidden: The hidden states.
235- - aux_hidden_states: A list of auxiliary hidden states.
236- """
237- with self .mesh :
238- kv_caches , hidden , aux_hidden_states = self .model (kv_caches , input_ids , attention_metadata , * args , ** kwargs )
239160 return kv_caches , hidden , aux_hidden_states
240161
241162 def forward (self , * args , ** kwargs ):
@@ -256,8 +177,11 @@ def get_input_embeddings(self) -> jax.Array:
256177 Returns:
257178 A JAX array representing the input embeddings.
258179 """
259- with self .mesh :
260- return self .model .model .token_embedder .embedding
180+ if not isinstance (self .model , nnx .Module ):
181+ raise ValueError ("Model is not initialized." )
182+
183+ with self .mesh , nn .logical_axis_rules (self .maxtext_config .logical_axis_rules ):
184+ return self .model .token_embedder .embedding
261185
262186 def embed_input_ids (self , input_ids : jax .Array ) -> jax .Array :
263187 """Embeds the input token IDs using the model's token embedder.
@@ -268,8 +192,11 @@ def embed_input_ids(self, input_ids: jax.Array) -> jax.Array:
268192 Returns:
269193 A JAX array of embedded input tokens.
270194 """
271- with self .mesh :
272- return self .model .model .token_embedder (input_ids )
195+ if not isinstance (self .model , nnx .Module ):
196+ raise ValueError ("Model is not initialized." )
197+
198+ with self .mesh , nn .logical_axis_rules (self .maxtext_config .logical_axis_rules ):
199+ return self .model .token_embedder (input_ids )
273200
274201 def compute_logits (self , hidden_states : jax .Array ) -> jax .Array :
275202 """Computes the logits from the hidden states using the underlying decoder model.
@@ -280,14 +207,24 @@ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
280207 Returns:
281208 A JAX array of logits.
282209 """
283- with self .mesh :
284- return self .model .compute_logits (hidden_states )
210+ if not isinstance (self .model , nnx .Module ):
211+ raise ValueError ("Model is not initialized." )
212+
213+ with self .mesh , nn .logical_axis_rules (self .maxtext_config .logical_axis_rules ):
214+ embeddings = self .model .token_embedder
215+ return self .model .decoder .apply_output_head (embeddings , hidden_states , True , self .model_mode )
285216
286217 def load_weights (self , rng_key : jax .Array ) -> None :
287218 """Loads model weights using the underlying decoder model.
288219
289220 Args:
290221 rng_key: A JAX random key for model initialization.
291222 """
292- with self .mesh :
293- self .model .load_weights (rng_key )
223+ if self .model is not None :
224+ return
225+
226+ with self .mesh , nn .logical_axis_rules ("" ):
227+ model , _ = model_creation_utils .create_nnx_model (
228+ self .maxtext_config , mesh = self .mesh , model_mode = self .model_mode , rng_key = rng_key
229+ )
230+ self .model = nnx .data (model )
0 commit comments