Skip to content

Commit 61abdbc

Browse files
committed
Simplifying the MaxText vLLM adapter implementation.
Updating the vllm_decode example.
1 parent c2574ab commit 61abdbc

4 files changed

Lines changed: 58 additions & 120 deletions

File tree

src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 43 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414

1515
"""vLLM adapter for MaxText models."""
1616

17-
import jax
18-
import jax.numpy as jnp
1917
import os
18+
import jax
2019

2120
from flax import nnx
2221
import flax.linen as nn
@@ -72,37 +71,38 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
7271
return maxtext_config
7372

7473

75-
class MaxTextDecoderModel(nnx.Module):
76-
"""A vLLM-compatible decoder model wrapper for MaxText.
74+
class MaxTextForCausalLM(nnx.Module):
75+
"""A vLLM-compatible causal language model wrapper for MaxText.
7776
78-
This class adapts a MaxText model for use within the vLLM framework,
79-
handling configuration generation, model initialization, and execution
77+
This class serves as the primary interface for integrating MaxText models
78+
into the vLLM serving framework, specifically for causal language modeling
79+
tasks. It handles configuration generation, model initialization, and execution
8080
of the decoding step.
8181
"""
8282

83-
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh) -> None:
84-
"""Initializes the MaxTextDecoderModel.
83+
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
84+
"""Initializes the MaxTextForCausalLM model.
8585
8686
Args:
8787
vllm_config: The vLLM configuration object.
8888
rng_key: A JAX random key for model initialization.
8989
mesh: The JAX mesh device for model sharding.
9090
"""
9191
self.vllm_config = vllm_config
92+
self.cfg = vllm_config.model_config
9293
self.maxtext_config = generate_maxtext_config(vllm_config)
9394

9495
# Model configuration
9596
self.mesh = mesh
9697
self.model_mode = MODEL_MODE_AUTOREGRESSIVE
98+
self.is_text_generation_model = True
9799

98100
# Model creation
99101
self.model: nnx.Module | None = None
100-
self.logits: jax.Array | None = None
101102

102103
# Handle dummy weight loading during initialization
103104
if vllm_config.load_config.load_format == "dummy":
104-
with self.mesh:
105-
self.load_weights(rng_key)
105+
self.load_weights(rng_key)
106106

107107
elif self.maxtext_config.load_parameters_path is None:
108108
max_logging.log("Warning: No load_parameters_path provided. The model will be initialized with random weights.")
@@ -115,7 +115,7 @@ def __call__(
115115
*args,
116116
**kwargs,
117117
) -> tuple[list[jax.Array], jax.Array, list[jax.Array]]:
118-
"""Performs a forward pass through the decoder model.
118+
"""Performs a forward pass through the causal language model.
119119
120120
Args:
121121
kv_caches: A list of JAX arrays representing the KV caches.
@@ -127,7 +127,7 @@ def __call__(
127127
Returns:
128128
A tuple containing:
129129
- updated_kv_caches: A list of updated KV caches.
130-
- hidden: The hidden states (Q, d_model).
130+
- hidden: The hidden states.
131131
- aux_hidden_states: A list of auxiliary hidden states.
132132
133133
Raises:
@@ -137,15 +137,15 @@ def __call__(
137137
raise ValueError("Model must be an instance of type nnx.Module.")
138138

139139
if input_ids.ndim < 2:
140-
input_ids = jnp.expand_dims(input_ids, axis=0)
140+
input_ids = input_ids[None, :]
141141

142142
input_positions = attention_metadata.input_positions
143143
if input_positions.ndim < 2:
144-
input_positions = jnp.expand_dims(input_positions, axis=0)
144+
input_positions = input_positions[None, :]
145145

146-
with nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
146+
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
147147
aux_hidden_states = []
148-
logits, hidden, kv_caches = self.model(
148+
hidden, kv_caches = self.model(
149149
decoder_input_tokens=input_ids,
150150
decoder_positions=input_positions,
151151
kv_caches=kv_caches,
@@ -154,88 +154,9 @@ def __call__(
154154
**kwargs,
155155
)
156156

157-
if hidden.ndim > 1:
158-
hidden = jnp.squeeze(hidden, axis=0)
159-
logits = jnp.squeeze(logits, axis=0)
160-
161-
self.logits = nnx.data(logits) # cache logits for compute_logits call
162-
163-
return kv_caches, hidden, aux_hidden_states
164-
165-
def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
166-
"""Computes the logits from the hidden states.
167-
168-
Args:
169-
hidden_states: A JAX array of hidden states.
170-
171-
Returns:
172-
A JAX array of logits (Q, vocab_size).
173-
"""
174-
if self.logits is not None:
175-
return self.logits
176-
177-
with nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
178-
embeddings = self.model.token_embedder
179-
return self.model.decoder.apply_output_head(embeddings, hidden_states, True, self.model_mode)
180-
181-
def load_weights(self, rng_key: jax.Array) -> None:
182-
"""Loads model parameters on the provided mesh.
183-
184-
Args:
185-
rng_key: A JAX random key for model initialization.
186-
"""
187-
if self.model is not None:
188-
return
189-
190-
with nn.logical_axis_rules(""):
191-
model, _ = model_creation_utils.create_nnx_model(
192-
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
193-
)
194-
self.model = nnx.data(model)
195-
196-
197-
class MaxTextForCausalLM(nnx.Module):
198-
"""A vLLM-compatible causal language model wrapper for MaxText.
199-
200-
This class serves as the primary interface for integrating MaxText models
201-
into the vLLM serving framework, specifically for causal language modeling
202-
tasks. It wraps the `MaxTextDecoderModel` and exposes methods expected
203-
by vLLM.
204-
"""
205-
206-
def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
207-
"""Initializes the MaxTextForCausalLM model.
157+
if hidden.ndim > 1:
158+
hidden = hidden.squeeze(0)
208159

209-
Args:
210-
vllm_config: The vLLM configuration object.
211-
rng_key: A JAX random key for model initialization.
212-
mesh: The JAX mesh device for model sharding.
213-
"""
214-
self.cfg = vllm_config.model_config
215-
self.mesh = mesh
216-
self.model = MaxTextDecoderModel(vllm_config, rng_key, mesh)
217-
self.is_text_generation_model = True
218-
219-
def __call__(
220-
self, kv_caches: list[jax.Array], input_ids: jax.Array, attention_metadata: AttentionMetadata, *args, **kwargs
221-
) -> tuple[list[jax.Array], jax.Array]:
222-
"""Performs a forward pass through the causal language model.
223-
224-
Args:
225-
kv_caches: A list of JAX arrays representing the KV caches.
226-
input_ids: A JAX array of input token IDs.
227-
attention_metadata: Attention metadata for the decoding process.
228-
*args: Variable length argument list.
229-
**kwargs: Arbitrary keyword arguments.
230-
231-
Returns:
232-
A tuple containing:
233-
- updated_kv_caches: A list of updated KV caches.
234-
- hidden: The hidden states.
235-
- aux_hidden_states: A list of auxiliary hidden states.
236-
"""
237-
with self.mesh:
238-
kv_caches, hidden, aux_hidden_states = self.model(kv_caches, input_ids, attention_metadata, *args, **kwargs)
239160
return kv_caches, hidden, aux_hidden_states
240161

241162
def forward(self, *args, **kwargs):
@@ -256,8 +177,11 @@ def get_input_embeddings(self) -> jax.Array:
256177
Returns:
257178
A JAX array representing the input embeddings.
258179
"""
259-
with self.mesh:
260-
return self.model.model.token_embedder.embedding
180+
if not isinstance(self.model, nnx.Module):
181+
raise ValueError("Model is not initialized.")
182+
183+
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
184+
return self.model.token_embedder.embedding
261185

262186
def embed_input_ids(self, input_ids: jax.Array) -> jax.Array:
263187
"""Embeds the input token IDs using the model's token embedder.
@@ -268,8 +192,11 @@ def embed_input_ids(self, input_ids: jax.Array) -> jax.Array:
268192
Returns:
269193
A JAX array of embedded input tokens.
270194
"""
271-
with self.mesh:
272-
return self.model.model.token_embedder(input_ids)
195+
if not isinstance(self.model, nnx.Module):
196+
raise ValueError("Model is not initialized.")
197+
198+
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
199+
return self.model.token_embedder(input_ids)
273200

274201
def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
275202
"""Computes the logits from the hidden states using the underlying decoder model.
@@ -280,14 +207,24 @@ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
280207
Returns:
281208
A JAX array of logits.
282209
"""
283-
with self.mesh:
284-
return self.model.compute_logits(hidden_states)
210+
if not isinstance(self.model, nnx.Module):
211+
raise ValueError("Model is not initialized.")
212+
213+
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
214+
embeddings = self.model.token_embedder
215+
return self.model.decoder.apply_output_head(embeddings, hidden_states, True, self.model_mode)
285216

286217
def load_weights(self, rng_key: jax.Array) -> None:
287218
"""Loads model weights using the underlying decoder model.
288219
289220
Args:
290221
rng_key: A JAX random key for model initialization.
291222
"""
292-
with self.mesh:
293-
self.model.load_weights(rng_key)
223+
if self.model is not None:
224+
return
225+
226+
with self.mesh, nn.logical_axis_rules(""):
227+
model, _ = model_creation_utils.create_nnx_model(
228+
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
229+
)
230+
self.model = nnx.data(model)

src/MaxText/layers/decoders.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -902,11 +902,16 @@ def __call__(
902902
# After the final transformer layer, `y` holds the raw, un-normalized hidden state.
903903
hidden_state = y
904904

905+
# When invoking from vLLM with RPA attention, logit computation is deferred to a later stage.
906+
if cfg.attention == "vllm_rpa":
907+
logits = None
908+
905909
# When vocab tiling is enabled in training mode, full logits won't generate to reduce memory
906910
# Instead, we keep track on the hidden states, which has smaller size compared to full logits
907-
if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
911+
elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
908912
logits = None
909913
self.sow("intermediates", "hidden_states", hidden_state)
914+
910915
else:
911916
logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)
912917

src/MaxText/layers/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def __call__(
211211

212212
if self.config.attention == "vllm_rpa":
213213
# In vLLM, logits are computed separately after updating the KV cache.
214-
return logits, hidden_state, kv_caches
214+
return hidden_state, kv_caches
215215

216216
return logits
217217

@@ -514,7 +514,7 @@ def __call__(
514514

515515
if self.config.attention == "vllm_rpa":
516516
# In vLLM, logits are computed separately after updating the KV cache.
517-
return logits, hidden_state, kv_caches
517+
return hidden_state, kv_caches
518518

519519
return logits
520520

src/MaxText/vllm_decode.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,13 @@
2626
2727
Or without Tunix using the MaxText vLLM integration:
2828
python3 -m MaxText.vllm_decode \
29-
--model-name qwen3-30b-a3b \
30-
--hf-model-name Qwen/Qwen3-30B-A3B \
31-
--hf-config-path src/MaxText/integration/vllm/maxtext_vllm_adapter \
32-
--load-parameters-path <your_checkpoint_path> \
33-
--ici_data_parallelism 1 \
34-
--ici-tensor-parallelism 4 \
35-
--ici-expert-parallelism 1 \
36-
--max-model-len 4096 \
37-
--max-num-batched-tokens 262144 \
38-
--gpu-memory-utilization 0.5 \
39-
--prompt "Suggest some famous landmarks in London." \
29+
--model_name qwen3-30b-a3b \
30+
--hf_model_name Qwen/Qwen3-30B-A3B \
31+
--hf_config_path src/MaxText/integration/vllm/maxtext_vllm_adapter \
32+
--load_parameters_path <your_checkpoint_path> \
33+
--ici_tensor_parallelism 4 \
34+
--gpu_memory_utilization 0.5 \
35+
--prompt "Suggest some famous landmarks in London."
4036
"""
4137

4238
import os

0 commit comments

Comments
 (0)