
Commit 9c04205

Update vllm_decode to use pyconfig.

1 parent be963a5

5 files changed
Lines changed: 48 additions & 189 deletions


src/maxtext/configs/base.yml
Lines changed: 3 additions & 0 deletions

@@ -1085,6 +1085,9 @@ use_jax_splash: false
 # vLLM Adapter Configurations
 # Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter)
 vllm_hf_config_path: ""
+# A JSON string of overrides to apply to the HuggingFace-style config for the vLLM adapter.
+# This can be used to override specific settings without modifying the original config file.
+vllm_hf_overrides: {}
 # JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
 vllm_additional_config: {}
 # When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH]
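Since vllm_hf_overrides is parsed from a JSON string, individual HuggingFace config fields can be patched per run without touching the adapter's config directory. A hypothetical invocation (the override key shown is illustrative, not part of this commit):

  python3 -m maxtext.vllm_decode maxtext/configs/base.yml \
    model_name=qwen3-30b-a3b \
    vllm_hf_config_path=src/MaxText/integration/vllm/maxtext_vllm_adapter \
    vllm_hf_overrides='{"max_position_embeddings": 8192}'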

src/maxtext/configs/types.py
Lines changed: 4 additions & 0 deletions

@@ -38,6 +38,7 @@
 from pydantic.main import BaseModel
 from pydantic.types import NonNegativeFloat, NonNegativeInt, PositiveInt
 
+
 class XProfTPUPowerTraceMode(enum.IntEnum):  # pylint: disable=invalid-name
   """Enum for XProfTPUPowerTraceMode."""
 
@@ -1557,6 +1558,9 @@ class VLLM(BaseModel):
   max_num_batched_tokens: Optional[int] = Field(None, description="Max number of batched tokens in vLLM.")
   max_num_seqs: Optional[int] = Field(None, description="Max number of sequences in vLLM.")
   vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
+  vllm_hf_overrides: dict[str, Any] = Field(
+      default_factory=dict, description="Overrides for HuggingFace model config for MaxText model."
+  )
   vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
 
 
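A minimal standalone sketch of how the new field behaves under pydantic validation (an illustrative model, not the real VLLM class from src/maxtext/configs/types.py; rope_theta is just an example key):

  from typing import Any
  from pydantic import BaseModel, Field

  class VLLM(BaseModel):
    # default_factory=dict gives each instance its own empty dict,
    # avoiding a shared mutable default across instances.
    vllm_hf_overrides: dict[str, Any] = Field(default_factory=dict)

  print(VLLM().vllm_hf_overrides)                                       # {}
  print(VLLM(vllm_hf_overrides={"rope_theta": 1e6}).vllm_hf_overrides)  # {'rope_theta': 1000000.0}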
src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
Lines changed: 0 additions & 3 deletions

@@ -69,9 +69,6 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
   )
   overrides["load_parameters_path"] = None
 
-  if vllm_config.model_config.hf_config_path is None:
-    raise ValueError("hf_config_path must be provided when using MaxTextForCausalLM.")
-
   # Add base config path to positional args
   base_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
   argv_list = ["", str(base_config_path)]
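The explicit hf_config_path check is dropped now that the path arrives through pyconfig as vllm_hf_config_path. For context, a sketch of the hand-off this hunk leads into, assuming the overrides dict is flattened to the key=value strings MaxText's CLI parser accepts (only base_config_path and argv_list appear in the hunk; the append loop is an assumption):

  # Positional args: empty program name, then the base YAML config.
  argv_list = ["", str(base_config_path)]
  # Assumed: overrides appended as key=value pairs for pyconfig to parse.
  argv_list += [f"{key}={value}" for key, value in overrides.items()]
  config = pyconfig.initialize(argv_list)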

src/maxtext/integration/vllm/maxtext_vllm_adapter/config.json
Lines changed: 0 additions & 58 deletions

This file was deleted.

src/maxtext/vllm_decode.py
Lines changed: 41 additions & 128 deletions
@@ -25,14 +25,14 @@
     --use_tunix \
 
 Or without Tunix using the MaxText vLLM integration:
-  python3 -m maxtext.vllm_decode \
-    --model_name qwen3-30b-a3b \
-    --hf_model_name Qwen/Qwen3-30B-A3B \
-    --hf_config_path src/MaxText/integration/vllm/maxtext_vllm_adapter \
-    --load_parameters_path <your_checkpoint_path> \
-    --ici_tensor_parallelism 4 \
-    --gpu_memory_utilization 0.5 \
-    --prompt "Suggest some famous landmarks in London."
+  python3 -m maxtext.vllm_decode maxtext/configs/base.yml \
+    model_name=qwen3-30b-a3b \
+    tokenizer_path=Qwen/Qwen3-30B-A3B \
+    vllm_hf_config_path=src/MaxText/integration/vllm/maxtext_vllm_adapter \
+    load_parameters_path=<your_checkpoint_path> \
+    ici_tensor_parallelism=4 \
+    hbm_utilization_vllm=0.5 \
+    prompt="Suggest some famous landmarks in London."
 """
 
 import os
@@ -62,121 +62,53 @@
 # --- DEFINE FLAGS GLOBALLY ---
 FLAGS = flags.FLAGS
 
-# Parallelism
-flags.DEFINE_integer("ici_data_parallelism", 1, "Size of the data parallelism dimension.")
-flags.DEFINE_integer("ici_tensor_parallelism", 1, "Size of the non-expert tensor parallelism dimension.")
-flags.DEFINE_integer("ici_expert_parallelism", 1, "Size of the MoE expert parallelism dimension.")
-flags.DEFINE_bool("enable_dp_attention", False, "Enable attention DP parallelism")
-flags.DEFINE_bool("debug_sharding", False, "Debug Shardings")
-
-# Model
-flags.DEFINE_string("model_name", None, "Model name for MaxText.")
-flags.DEFINE_string("hf_model_name", None, "Path to the Hugging Face model.")
-flags.DEFINE_string("hf_config_path", None, "Path to the local Hugging Face model config.")
-flags.DEFINE_string("hf_access_token", None, "Hugging Face access token for private models.")
-flags.DEFINE_string("tokenizer_path", None, "Path to the tokenizer. If None, use hf_model_name.")
-flags.DEFINE_string("load_parameters_path", None, "Path to load model parameters from.")
-
-# Length/Throughput
-flags.DEFINE_integer("max_target_length", 1024, "Maximum total context length (MCL).")
-flags.DEFINE_float("gpu_memory_utilization", 0.72, "Fraction of GPU memory to be used for the model executor.")
-
-# Decoding
 flags.DEFINE_bool("use_tunix", False, "Whether to use Tunix for vLLM decoding.")
-flags.DEFINE_bool("use_chat_template", False, "Whether to format the prompt using chat template.")
-flags.DEFINE_string("prompt", "Suggest some famous landmarks in London.", "The prompt to decode.")
-flags.DEFINE_float("decode_sampling_temperature", 0.0, "Temperature for sampling.")
-flags.DEFINE_float("decode_sampling_nucleus_p", 1.0, "Nucleus sampling probability.")
-flags.DEFINE_integer("decode_sampling_top_k", 1, "Top-k sampling probability.")
 flags.DEFINE_integer("seed", 42, "Random seed for sampling.")
 
-# Set mandatory flags
-flags.mark_flag_as_required("model_name")
-flags.mark_flag_as_required("hf_model_name")
-
-
-def decode_with_vllm(
-    model_name: str,
-    hf_model_name: str,
-    hf_config_path: str | None,
-    prompt: str,
-    vllm_config_path: str | None = None,
-    ici_data_parallelism: int = 1,
-    ici_tensor_parallelism: int = 1,
-    ici_expert_parallelism: int = 1,
-    enable_dp_attention: bool = False,
-    max_target_length: int = 1024,
-    gpu_memory_utilization: float = 0.72,
-    use_chat_template: bool = False,
-    decode_sampling_temperature: float = 0.0,
-    decode_sampling_nucleus_p: float = 1.0,
-    decode_sampling_top_k: int = 1,
-    hf_access_token: str | None = None,
-    tokenizer_path: str | None = None,
-    load_parameters_path: str | None = None,
-    debug_sharding: bool = False,
-    seed: int = 42,
-) -> None:
+
+def decode_with_vllm(config: Config) -> None:
   """Decode using vLLM with a MaxText model implementation.
 
   Args:
-    model_name: Name of the model for MaxText.
-    hf_model_name: Path to the Hugging Face model.
-    hf_config_path: Path to the local Hugging Face model config.
-    prompt: The prompt to decode.
-    ici_data_parallelism: Size of the data parallelism dimension.
-    ici_tensor_parallelism: Size of the non-expert tensor parallelism dimension.
-    ici_expert_parallelism: Size of the MoE expert parallelism dimension.
-    enable_dp_attention: Enable attention DP parallelism.
-    max_target_length: Maximum total context length (MCL).
-    gpu_memory_utilization: Fraction of GPU memory to be used for the model executor.
-    use_chat_template: Whether to format the prompt using chat template.
-    decode_sampling_temperature: Temperature for sampling.
-    decode_sampling_nucleus_p: Nucleus sampling probability.
-    decode_sampling_top_k: Top-k sampling probability.
-    vllm_config_path: Path to vLLM config file. Defaults to MAXTEXT_PKG_DIR/configs/vllm.yml.
-    hf_access_token: Hugging Face access token for private models.
-    tokenizer_path: Path to the tokenizer. If None, use hf_model_name.
-    load_parameters_path: Path to load model parameters from.
-    debug_sharding: Whether to debug shardings.
-    seed: Random seed for sampling.
+    config: MaxText config.
   """
   # Prepare vLLM Arguments
   vllm_args = {
-      "model": hf_model_name,
-      "max_model_len": max_target_length,
-      "tensor_parallel_size": ici_tensor_parallelism,
-      "data_parallel_size": ici_data_parallelism,
-      "hf_config_path": hf_config_path,
-      "gpu_memory_utilization": gpu_memory_utilization,
+      "model": config.tokenizer_path,
+      "max_model_len": config.max_target_length,
+      "tensor_parallel_size": config.ici_tensor_parallelism,
+      "data_parallel_size": config.ici_data_parallelism,
+      "hf_config_path": config.vllm_hf_config_path,
+      "hf_overrides": config.vllm_hf_overrides,
+      "gpu_memory_utilization": config.hbm_utilization_vllm,
       "additional_config": {
           "maxtext_config": {
-              "model_name": model_name,
+              "model_name": config.model_name,
               "weight_dtype": "bfloat16",
               "allow_split_physical_axes": True,
-              "debug_sharding": debug_sharding,
+              "debug_sharding": config.debug_sharding,
           },
           "sharding": {
               "sharding_strategy": {
-                  "enable_dp_attention": enable_dp_attention,
+                  "enable_dp_attention": config.enable_dp_attention,
               },
           },
       },
   }
 
-  if load_parameters_path:
-    vllm_args["additional_config"]["maxtext_config"]["load_parameters_path"] = load_parameters_path
+  if config.load_parameters_path:
+    vllm_args["additional_config"]["maxtext_config"]["load_parameters_path"] = config.load_parameters_path
   else:
     vllm_args["load_format"] = "dummy"
 
-  enable_expert_parallel = ici_expert_parallelism > 1
+  enable_expert_parallel = config.ici_expert_parallelism > 1
   if enable_expert_parallel:
-    vllm_args["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = ici_expert_parallelism
+    vllm_args["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = config.ici_expert_parallelism
   vllm_args["enable_expert_parallel"] = enable_expert_parallel
 
   max_logging.log(
-      f"Initializing LLM with DP={ici_data_parallelism}, TP={ici_tensor_parallelism} "
-      f"and EP={ici_expert_parallelism if enable_expert_parallel else 0}..."
+      f"Initializing LLM with DP={config.ici_data_parallelism}, TP={config.ici_tensor_parallelism} "
+      f"and EP={config.ici_expert_parallelism if enable_expert_parallel else 1}..."
   )
 
   vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
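For orientation, running the docstring invocation at the top of this file would assemble vllm_args roughly as below (a sketch derived from the hunk above; keys not set on the command line are assumed to keep their base.yml defaults):

  vllm_args = {
      "model": "Qwen/Qwen3-30B-A3B",  # config.tokenizer_path
      "max_model_len": 1024,  # config.max_target_length (assumed default)
      "tensor_parallel_size": 4,  # ici_tensor_parallelism=4
      "data_parallel_size": 1,
      "hf_config_path": "src/MaxText/integration/vllm/maxtext_vllm_adapter",
      "hf_overrides": {},  # vllm_hf_overrides not set
      "gpu_memory_utilization": 0.5,  # hbm_utilization_vllm=0.5
      "additional_config": {
          "maxtext_config": {
              "model_name": "qwen3-30b-a3b",
              "weight_dtype": "bfloat16",
              "allow_split_physical_axes": True,
              "debug_sharding": False,
              "load_parameters_path": "<your_checkpoint_path>",
          },
          "sharding": {"sharding_strategy": {"enable_dp_attention": False}},
      },
  }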
@@ -188,15 +120,15 @@ def decode_with_vllm(
 
   max_logging.log("Generating output...")
   tokenizer = transformers.AutoTokenizer.from_pretrained(
-      tokenizer_path if tokenizer_path is not None else hf_model_name,
-      token=hf_access_token,
+      config.tokenizer_path,
+      token=config.hf_access_token,
   )
 
-  prompts = [prompt]
-  if use_chat_template:
+  prompts = [config.prompt]
+  if config.use_chat_template:
     # Format the prompt using chat template if specified
     messages = [
-        {"role": "user", "content": prompt},
+        {"role": "user", "content": config.prompt},
     ]
     input_with_chat_template = tokenizer.apply_chat_template(
         messages,
@@ -207,18 +139,18 @@
     prompts = [input_with_chat_template]
 
   max_prompt_length = max(len(tokenizer.encode(p)) for p in prompts)
-  max_tokens_to_generate = max_target_length - max_prompt_length
+  max_tokens_to_generate = config.max_target_length - max_prompt_length
   if max_tokens_to_generate <= 0:
     raise ValueError(
-        f"max_target_length ({max_target_length}) must be greater than max_prompt_length ({max_prompt_length})"
+        f"max_target_length ({config.max_target_length}) must be greater than max_prompt_length ({max_prompt_length})"
     )
 
   sampling_params = SamplingParams(
-      temperature=decode_sampling_temperature,
+      temperature=config.decode_sampling_temperature,
       max_tokens=max_tokens_to_generate,
-      top_k=decode_sampling_top_k,
-      top_p=decode_sampling_nucleus_p,
-      seed=seed,
+      top_k=config.decode_sampling_top_k,
+      top_p=config.decode_sampling_nucleus_p,
+      seed=FLAGS.seed,
   )
 
   outputs = llm.generate(prompts, sampling_params)
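A quick worked example of the token budget above: with max_target_length=1024 and a prompt that encodes to 16 tokens, max_tokens_to_generate = 1024 - 16 = 1008; a prompt of 1024 or more tokens raises the ValueError. (Illustrative numbers, not from this commit.)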
@@ -312,32 +244,13 @@ def main(argv: Sequence[str]) -> None:
         os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
     )
 
+  config = pyconfig.initialize(argv)
+
   if FLAGS.use_tunix:
-    config = pyconfig.initialize(argv)
     maxtext_model, mesh = model_creation_utils.create_nnx_model(config)
     decode_with_tunix(config, model=maxtext_model, mesh=mesh)
   else:
-    decode_with_vllm(
-        model_name=FLAGS.model_name,
-        hf_model_name=FLAGS.hf_model_name,
-        hf_config_path=FLAGS.hf_config_path,
-        hf_access_token=FLAGS.hf_access_token,
-        tokenizer_path=FLAGS.tokenizer_path,
-        load_parameters_path=FLAGS.load_parameters_path,
-        ici_data_parallelism=FLAGS.ici_data_parallelism,
-        ici_tensor_parallelism=FLAGS.ici_tensor_parallelism,
-        ici_expert_parallelism=FLAGS.ici_expert_parallelism,
-        enable_dp_attention=FLAGS.enable_dp_attention,
-        max_target_length=FLAGS.max_target_length,
-        gpu_memory_utilization=FLAGS.gpu_memory_utilization,
-        prompt=FLAGS.prompt,
-        use_chat_template=FLAGS.use_chat_template,
-        decode_sampling_temperature=FLAGS.decode_sampling_temperature,
-        decode_sampling_nucleus_p=FLAGS.decode_sampling_nucleus_p,
-        decode_sampling_top_k=FLAGS.decode_sampling_top_k,
-        debug_sharding=FLAGS.debug_sharding,
-        seed=FLAGS.seed,
-    )
+    decode_with_vllm(config)
 
 
 if __name__ == "__main__":
