Commit 912d0c9

Adding chat template to vllm decode.
1 parent 496ed40 commit 912d0c9
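
For reference, the new --use_chat_template path wraps the raw --prompt in a single-turn message list and renders it through the tokenizer's chat template before handing it to vLLM. A minimal standalone sketch of that transformation (the checkpoint name is only an example; any Hugging Face model with a chat template works):

import transformers

# Illustrative checkpoint; in the script this comes from --tokenizer_path or --hf_model_name.
tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-30B-A3B")

prompt = "Suggest some famous landmarks in London."
messages = [{"role": "user", "content": prompt}]

# Render the template to a string (tokenize=False) and append the assistant
# turn header (add_generation_prompt=True) so the model starts its reply.
formatted = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
print(formatted)  # the user prompt wrapped in model-specific chat markup

Without the flag, the prompt is passed to llm.generate() verbatim, as before.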

1 file changed

Lines changed: 119 additions & 79 deletions

File tree

src/maxtext/vllm_decode.py

@@ -45,6 +45,7 @@
 import transformers

 from maxtext.utils import model_creation_utils
+from maxtext.utils import max_logging
 from MaxText import pyconfig
 from MaxText.common_types import Config
 from MaxText.globals import MAXTEXT_CONFIGS_DIR
@@ -69,119 +70,112 @@
 flags.DEFINE_bool("debug_sharding", False, "Debug Shardings")

 # Model
-flags.DEFINE_string("model_name", "qwen3-30b-a3b", "Model name for MaxText.")
-flags.DEFINE_string("hf_model_name", "Qwen/Qwen3-30B-A3B", "Path to the Hugging Face model.")
+flags.DEFINE_string("model_name", None, "Model name for MaxText.")
+flags.DEFINE_string("hf_model_name", None, "Path to the Hugging Face model.")
 flags.DEFINE_string("hf_config_path", None, "Path to the local Hugging Face model config.")
+flags.DEFINE_string("hf_access_token", None, "Hugging Face access token for private models.")
+flags.DEFINE_string("tokenizer_path", None, "Path to the tokenizer. If None, use hf_model_name.")
 flags.DEFINE_string("load_parameters_path", None, "Path to load model parameters from.")
-flags.DEFINE_bool("enable_expert_parallel", False, "Whether to enable expert parallelism.")

 # Length/Throughput
 flags.DEFINE_integer("max_target_length", 1024, "Maximum total context length (MCL).")
-flags.DEFINE_integer("max_prefill_length", 512, "Maximum prefill length.")
 flags.DEFINE_float("gpu_memory_utilization", 0.72, "Fraction of GPU memory to be used for the model executor.")

 # Decoding
 flags.DEFINE_bool("use_tunix", False, "Whether to use Tunix for vLLM decoding.")
+flags.DEFINE_bool("use_chat_template", False, "Whether to format the prompt using chat template.")
 flags.DEFINE_string("prompt", "Suggest some famous landmarks in London.", "The prompt to decode.")
-flags.DEFINE_integer("decode_sampling_temperature", 0, "Temperature for sampling.")
-flags.DEFINE_integer("decode_sampling_nucleus_p", 1, "Nucleus sampling probability.")
+flags.DEFINE_float("decode_sampling_temperature", 0.0, "Temperature for sampling.")
+flags.DEFINE_float("decode_sampling_nucleus_p", 1.0, "Nucleus sampling probability.")
 flags.DEFINE_integer("decode_sampling_top_k", 1, "Top-k sampling probability.")
+flags.DEFINE_integer("seed", 42, "Random seed for sampling.")

-# Mark required flags
-flags.mark_flag_as_required("hf_config_path")
+# Set mandatory flags
+flags.mark_flag_as_required("model_name")
+flags.mark_flag_as_required("hf_model_name")


 def decode_with_vllm(
     model_name: str,
     hf_model_name: str,
-    hf_config_path: str,
-    load_parameters_path: str,
-    ici_data_parallelism: int,
-    ici_tensor_parallelism: int,
-    ici_expert_parallelism: int,
-    enable_dp_attention: bool,
-    max_prefill_length: int,
-    max_target_length: int,
-    gpu_memory_utilization: float,
-    enable_expert_parallel: bool,
+    hf_config_path: str | None,
     prompt: str,
-    decode_sampling_temperature: float,
-    decode_sampling_nucleus_p: float,
-    decode_sampling_top_k: float,
-    debug_sharding: bool,
     vllm_config_path: str | None = None,
+    ici_data_parallelism: int = 1,
+    ici_tensor_parallelism: int = 1,
+    ici_expert_parallelism: int = 1,
+    enable_dp_attention: bool = False,
+    max_target_length: int = 1024,
+    gpu_memory_utilization: float = 0.72,
+    use_chat_template: bool = False,
+    decode_sampling_temperature: float = 0.0,
+    decode_sampling_nucleus_p: float = 1.0,
+    decode_sampling_top_k: int = 1,
+    hf_access_token: str | None = None,
+    tokenizer_path: str | None = None,
+    load_parameters_path: str | None = None,
+    debug_sharding: bool = False,
+    seed: int = 42,
 ) -> None:
   """Decode using vLLM with a MaxText model implementation.

   Args:
     model_name: Name of the model for MaxText.
     hf_model_name: Path to the Hugging Face model.
     hf_config_path: Path to the local Hugging Face model config.
-    load_parameters_path: Path to load model parameters from.
+    prompt: The prompt to decode.
     ici_data_parallelism: Size of the data parallelism dimension.
     ici_tensor_parallelism: Size of the non-expert tensor parallelism dimension.
-    ici_expert_parallelism: Size of the MoE expert parallelism dimension
-    enable_dp_attention: Enable DP attention
-    max_prefill_length: Maximum prefill length.
+    ici_expert_parallelism: Size of the MoE expert parallelism dimension.
+    enable_dp_attention: Enable attention DP parallelism.
     max_target_length: Maximum total context length (MCL).
     gpu_memory_utilization: Fraction of GPU memory to be used for the model executor.
-    enable_expert_parallel: Whether to enable expert parallelism.
-    prompt: The prompt to decode.
+    use_chat_template: Whether to format the prompt using chat template.
     decode_sampling_temperature: Temperature for sampling.
     decode_sampling_nucleus_p: Nucleus sampling probability.
     decode_sampling_top_k: Top-k sampling probability.
     vllm_config_path: Path to vLLM config file. Defaults to MAXTEXT_PKG_DIR/configs/vllm.yml.
+    hf_access_token: Hugging Face access token for private models.
+    tokenizer_path: Path to the tokenizer. If None, use hf_model_name.
+    load_parameters_path: Path to load model parameters from.
+    debug_sharding: Whether to debug shardings.
+    seed: Random seed for sampling.
   """
-
   # Prepare vLLM Arguments
-  vllm_args = {}
-  vllm_args["additional_config"] = {}
-
-  # Core vLLM Arguments
-  vllm_args["model"] = hf_model_name
-  vllm_args["max_model_len"] = max_target_length
-  vllm_args["tensor_parallel_size"] = ici_tensor_parallelism
-  vllm_args["data_parallel_size"] = ici_data_parallelism
-  vllm_args["enable_expert_parallel"] = enable_expert_parallel
-  vllm_args["hf_config_path"] = hf_config_path
-  vllm_args["gpu_memory_utilization"] = gpu_memory_utilization
-
-  # Prepare MaxText and sharding configs (Parallelism is dynamic)
-  vllm_args["additional_config"]["maxtext_config"] = {
-      "model_name": model_name,
-      "max_target_length": max_target_length,
-      "weight_dtype": "bfloat16",
-      "allow_split_physical_axes": True,
-      "debug_sharding": debug_sharding,
+  vllm_args = {
+      "model": hf_model_name,
+      "max_model_len": max_target_length,
+      "tensor_parallel_size": ici_tensor_parallelism,
+      "data_parallel_size": ici_data_parallelism,
+      "hf_config_path": hf_config_path,
+      "gpu_memory_utilization": gpu_memory_utilization,
+      "additional_config": {
+          "maxtext_config": {
+              "model_name": model_name,
+              "weight_dtype": "bfloat16",
+              "allow_split_physical_axes": True,
+              "debug_sharding": debug_sharding,
+          },
+          "sharding": {
+              "sharding_strategy": {
+                  "enable_dp_attention": enable_dp_attention,
+              },
+          },
+      },
   }
-  if load_parameters_path is not None:
+
+  if load_parameters_path:
     vllm_args["additional_config"]["maxtext_config"]["load_parameters_path"] = load_parameters_path
   else:
     vllm_args["load_format"] = "dummy"

-  sharding_strategy = {
-      "enable_dp_attention": enable_dp_attention,
-  }
+  enable_expert_parallel = ici_expert_parallelism > 1
   if enable_expert_parallel:
-    sharding_strategy["expert_parallelism"] = ici_expert_parallelism
-  vllm_args["additional_config"]["sharding"] = {
-      "sharding_strategy": sharding_strategy,
-  }
-
-  if enable_expert_parallel:
-    vllm_args["additional_config"]["sharding"]["sharding_strategy"].update({"expert_parallelism": ici_expert_parallelism})
-
-  # Initialize and Run LLM
-  max_tokens = max_target_length - max_prefill_length
-  sampling_params = SamplingParams(
-      temperature=decode_sampling_temperature,
-      max_tokens=max_tokens,
-      top_k=decode_sampling_top_k,
-      top_p=decode_sampling_nucleus_p,
-  )
+    vllm_args["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = ici_expert_parallelism
+  vllm_args["enable_expert_parallel"] = enable_expert_parallel

-  print(
-      f"Initializing LLM with DP={vllm_args['data_parallel_size']}, TP={vllm_args['tensor_parallel_size']} "
+  max_logging.log(
+      f"Initializing LLM with DP={ici_data_parallelism}, TP={ici_tensor_parallelism} "
       f"and EP={ici_expert_parallelism if enable_expert_parallel else 0}..."
   )
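
Taken together, the hunk above now assembles vllm_args as a single nested dictionary instead of incremental key assignments. A hypothetical example of the result for a run with ici_tensor_parallelism=4, ici_expert_parallelism=4 and no checkpoint path (all values illustrative, not defaults):

# Sketch of the assembled vllm_args for one hypothetical configuration.
vllm_args = {
    "model": "Qwen/Qwen3-30B-A3B",     # hf_model_name
    "max_model_len": 1024,             # max_target_length
    "tensor_parallel_size": 4,
    "data_parallel_size": 1,
    "hf_config_path": None,
    "gpu_memory_utilization": 0.72,
    "load_format": "dummy",            # set only when load_parameters_path is absent
    "enable_expert_parallel": True,    # derived from ici_expert_parallelism > 1
    "additional_config": {
        "maxtext_config": {
            "model_name": "qwen3-30b-a3b",
            "weight_dtype": "bfloat16",
            "allow_split_physical_axes": True,
            "debug_sharding": False,
        },
        "sharding": {
            "sharding_strategy": {
                "enable_dp_attention": False,
                "expert_parallelism": 4,   # added only when EP > 1
            },
        },
    },
}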

@@ -192,22 +186,62 @@ def decode_with_vllm(
   with nn_partitioning.axis_rules(vllm_config.logical_axis_rules):
     llm = LLM(**vllm_args)

-  print("Generating output...")
-  outputs = llm.generate([prompt], sampling_params)
+  max_logging.log("Generating output...")
+  tokenizer = transformers.AutoTokenizer.from_pretrained(
+      tokenizer_path if tokenizer_path is not None else hf_model_name,
+      token=hf_access_token,
+  )

-  # Print Outputs
+  prompts = [prompt]
+  if use_chat_template:
+    # Format the prompt using chat template if specified
+    messages = [
+        {"role": "user", "content": prompt},
+    ]
+    input_with_chat_template = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,  # Set to False to get the string
+        add_generation_prompt=True,
+        add_special_tokens=False,  # Prevent adding special tokens
+    )
+    prompts = [input_with_chat_template]
+
+  max_prompt_length = max(len(tokenizer.encode(p)) for p in prompts)
+  max_tokens_to_generate = max_target_length - max_prompt_length
+  if max_tokens_to_generate <= 0:
+    raise ValueError(
+        f"max_target_length ({max_target_length}) must be greater than max_prompt_length ({max_prompt_length})"
+    )
+
+  sampling_params = SamplingParams(
+      temperature=decode_sampling_temperature,
+      max_tokens=max_tokens_to_generate,
+      top_k=decode_sampling_top_k,
+      top_p=decode_sampling_nucleus_p,
+      seed=seed,
+  )
+
+  outputs = llm.generate(prompts, sampling_params)
+
+  # max_logging.log Outputs
   for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    max_logging.log(f"Prompt: {prompt}, Generated text: {generated_text}")


 def decode_with_tunix(
     config: Config,
     model: Any,
     mesh: jax.sharding.Mesh,
 ) -> None:
-  """Decode using vLLM with a MaxText model."""
+  """Decode using vLLM with a MaxText model via Tunix adapter.
+
+  Args:
+    config: MaxText config.
+    model: The MaxText model instance.
+    mesh: The JAX mesh for parallelism.
+  """
   # Wrap the model for Tunix
   tunix_model = TunixMaxTextAdapter(base_model=model)

@@ -235,6 +269,10 @@ def decode_with_tunix(

   max_prompt_length = max(len(tokenizer.encode(p)) for p in prompts)
   max_tokens_to_generate = config.max_target_length - max_prompt_length
+  if max_tokens_to_generate <= 0:
+    raise ValueError(
+        f"max_target_length ({config.max_target_length}) must be greater than max_prompt_length ({max_prompt_length})"
+    )

   # Create vLLM rollout for inference
   rollout_config = base_rollout.RolloutConfig(
@@ -262,8 +300,8 @@ def decode_with_tunix(

   # Generate text
   output = vllm_rollout.generate(prompts, rollout_config)
-  print(f"Prompt: {config.prompt}")
-  print(f"Output: {output.text[0]}")
+  max_logging.log(f"Prompt: {config.prompt}")
+  max_logging.log(f"Output: {output.text[0]}")


 def main(argv: Sequence[str]) -> None:
@@ -283,20 +321,22 @@ def main(argv: Sequence[str]) -> None:
         model_name=FLAGS.model_name,
         hf_model_name=FLAGS.hf_model_name,
         hf_config_path=FLAGS.hf_config_path,
+        hf_access_token=FLAGS.hf_access_token,
+        tokenizer_path=FLAGS.tokenizer_path,
         load_parameters_path=FLAGS.load_parameters_path,
         ici_data_parallelism=FLAGS.ici_data_parallelism,
         ici_tensor_parallelism=FLAGS.ici_tensor_parallelism,
         ici_expert_parallelism=FLAGS.ici_expert_parallelism,
         enable_dp_attention=FLAGS.enable_dp_attention,
         max_target_length=FLAGS.max_target_length,
-        max_prefill_length=FLAGS.max_prefill_length,
         gpu_memory_utilization=FLAGS.gpu_memory_utilization,
-        enable_expert_parallel=FLAGS.enable_expert_parallel,
         prompt=FLAGS.prompt,
+        use_chat_template=FLAGS.use_chat_template,
         decode_sampling_temperature=FLAGS.decode_sampling_temperature,
         decode_sampling_nucleus_p=FLAGS.decode_sampling_nucleus_p,
         decode_sampling_top_k=FLAGS.decode_sampling_top_k,
         debug_sharding=FLAGS.debug_sharding,
+        seed=FLAGS.seed,
     )

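
With the refactored signature, only model_name, hf_model_name, hf_config_path and prompt lack defaults, so decode_with_vllm is also straightforward to call programmatically rather than through the absl flags. A hedged sketch under the assumption that the module imports as maxtext.vllm_decode (inferred from the file path src/maxtext/vllm_decode.py); all values are placeholders:

# Hypothetical programmatic use of the refactored entry point.
from maxtext.vllm_decode import decode_with_vllm  # import path assumed from src/maxtext/vllm_decode.py

decode_with_vllm(
    model_name="qwen3-30b-a3b",
    hf_model_name="Qwen/Qwen3-30B-A3B",
    hf_config_path=None,
    prompt="Suggest some famous landmarks in London.",
    use_chat_template=True,           # new in this commit
    ici_tensor_parallelism=4,         # placeholder parallelism
    decode_sampling_temperature=0.7,
    decode_sampling_top_k=40,
    seed=42,                          # new in this commit
)

Without load_parameters_path the model is created with load_format="dummy", which exercises sharding and the decode path but will not produce meaningful text.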