@@ -25,14 +25,14 @@
     --use_tunix \

 Or without Tunix using the MaxText vLLM integration:
-  python3 -m maxtext.vllm_decode \
-    --model_name qwen3-30b-a3b \
-    --hf_model_name Qwen/Qwen3-30B-A3B \
-    --hf_config_path src/MaxText/integration/vllm/maxtext_vllm_adapter \
-    --load_parameters_path <your_checkpoint_path> \
-    --ici_tensor_parallelism 4 \
-    --gpu_memory_utilization 0.5 \
-    --prompt "Suggest some famous landmarks in London."
+  python3 -m maxtext.vllm_decode maxtext/configs/base.yml \
+    model_name=qwen3-30b-a3b \
+    tokenizer_path=Qwen/Qwen3-30B-A3B \
+    vllm_hf_config_path=src/MaxText/integration/vllm/maxtext_vllm_adapter \
+    load_parameters_path=<your_checkpoint_path> \
+    ici_tensor_parallelism=4 \
+    hbm_utilization_vllm=0.5 \
+    prompt="Suggest some famous landmarks in London."
 """

 import os
@@ -62,121 +62,53 @@
 # --- DEFINE FLAGS GLOBALLY ---
 FLAGS = flags.FLAGS

-# Parallelism
-flags.DEFINE_integer("ici_data_parallelism", 1, "Size of the data parallelism dimension.")
-flags.DEFINE_integer("ici_tensor_parallelism", 1, "Size of the non-expert tensor parallelism dimension.")
-flags.DEFINE_integer("ici_expert_parallelism", 1, "Size of the MoE expert parallelism dimension.")
-flags.DEFINE_bool("enable_dp_attention", False, "Enable attention DP parallelism")
-flags.DEFINE_bool("debug_sharding", False, "Debug Shardings")
-
-# Model
-flags.DEFINE_string("model_name", None, "Model name for MaxText.")
-flags.DEFINE_string("hf_model_name", None, "Path to the Hugging Face model.")
-flags.DEFINE_string("hf_config_path", None, "Path to the local Hugging Face model config.")
-flags.DEFINE_string("hf_access_token", None, "Hugging Face access token for private models.")
-flags.DEFINE_string("tokenizer_path", None, "Path to the tokenizer. If None, use hf_model_name.")
-flags.DEFINE_string("load_parameters_path", None, "Path to load model parameters from.")
-
-# Length/Throughput
-flags.DEFINE_integer("max_target_length", 1024, "Maximum total context length (MCL).")
-flags.DEFINE_float("gpu_memory_utilization", 0.72, "Fraction of GPU memory to be used for the model executor.")
-
-# Decoding
 flags.DEFINE_bool("use_tunix", False, "Whether to use Tunix for vLLM decoding.")
-flags.DEFINE_bool("use_chat_template", False, "Whether to format the prompt using chat template.")
-flags.DEFINE_string("prompt", "Suggest some famous landmarks in London.", "The prompt to decode.")
-flags.DEFINE_float("decode_sampling_temperature", 0.0, "Temperature for sampling.")
-flags.DEFINE_float("decode_sampling_nucleus_p", 1.0, "Nucleus sampling probability.")
-flags.DEFINE_integer("decode_sampling_top_k", 1, "Top-k sampling probability.")
 flags.DEFINE_integer("seed", 42, "Random seed for sampling.")

-# Set mandatory flags
-flags.mark_flag_as_required("model_name")
-flags.mark_flag_as_required("hf_model_name")
-
-
-def decode_with_vllm(
-    model_name: str,
-    hf_model_name: str,
-    hf_config_path: str | None,
-    prompt: str,
-    vllm_config_path: str | None = None,
-    ici_data_parallelism: int = 1,
-    ici_tensor_parallelism: int = 1,
-    ici_expert_parallelism: int = 1,
-    enable_dp_attention: bool = False,
-    max_target_length: int = 1024,
-    gpu_memory_utilization: float = 0.72,
-    use_chat_template: bool = False,
-    decode_sampling_temperature: float = 0.0,
-    decode_sampling_nucleus_p: float = 1.0,
-    decode_sampling_top_k: int = 1,
-    hf_access_token: str | None = None,
-    tokenizer_path: str | None = None,
-    load_parameters_path: str | None = None,
-    debug_sharding: bool = False,
-    seed: int = 42,
-) -> None:
+
+def decode_with_vllm(config: Config) -> None:
   """Decode using vLLM with a MaxText model implementation.

   Args:
-    model_name: Name of the model for MaxText.
-    hf_model_name: Path to the Hugging Face model.
-    hf_config_path: Path to the local Hugging Face model config.
-    prompt: The prompt to decode.
-    ici_data_parallelism: Size of the data parallelism dimension.
-    ici_tensor_parallelism: Size of the non-expert tensor parallelism dimension.
-    ici_expert_parallelism: Size of the MoE expert parallelism dimension.
-    enable_dp_attention: Enable attention DP parallelism.
-    max_target_length: Maximum total context length (MCL).
-    gpu_memory_utilization: Fraction of GPU memory to be used for the model executor.
-    use_chat_template: Whether to format the prompt using chat template.
-    decode_sampling_temperature: Temperature for sampling.
-    decode_sampling_nucleus_p: Nucleus sampling probability.
-    decode_sampling_top_k: Top-k sampling probability.
-    vllm_config_path: Path to vLLM config file. Defaults to MAXTEXT_PKG_DIR/configs/vllm.yml.
-    hf_access_token: Hugging Face access token for private models.
-    tokenizer_path: Path to the tokenizer. If None, use hf_model_name.
-    load_parameters_path: Path to load model parameters from.
-    debug_sharding: Whether to debug shardings.
-    seed: Random seed for sampling.
+    config: MaxText config.
   """
   # Prepare vLLM Arguments
   vllm_args = {
-      "model": hf_model_name,
-      "max_model_len": max_target_length,
-      "tensor_parallel_size": ici_tensor_parallelism,
-      "data_parallel_size": ici_data_parallelism,
-      "hf_config_path": hf_config_path,
-      "gpu_memory_utilization": gpu_memory_utilization,
+      "model": config.tokenizer_path,
+      "max_model_len": config.max_target_length,
+      "tensor_parallel_size": config.ici_tensor_parallelism,
+      "data_parallel_size": config.ici_data_parallelism,
+      "hf_config_path": config.vllm_hf_config_path,
+      "hf_overrides": config.vllm_hf_overrides,
+      "gpu_memory_utilization": config.hbm_utilization_vllm,
       "additional_config": {
           "maxtext_config": {
-              "model_name": model_name,
+              "model_name": config.model_name,
               "weight_dtype": "bfloat16",
               "allow_split_physical_axes": True,
-              "debug_sharding": debug_sharding,
+              "debug_sharding": config.debug_sharding,
           },
           "sharding": {
               "sharding_strategy": {
-                  "enable_dp_attention": enable_dp_attention,
+                  "enable_dp_attention": config.enable_dp_attention,
               },
           },
       },
   }

-  if load_parameters_path:
-    vllm_args["additional_config"]["maxtext_config"]["load_parameters_path"] = load_parameters_path
+  if config.load_parameters_path:
+    vllm_args["additional_config"]["maxtext_config"]["load_parameters_path"] = config.load_parameters_path
   else:
     vllm_args["load_format"] = "dummy"

-  enable_expert_parallel = ici_expert_parallelism > 1
+  enable_expert_parallel = config.ici_expert_parallelism > 1
   if enable_expert_parallel:
-    vllm_args["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = ici_expert_parallelism
+    vllm_args["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = config.ici_expert_parallelism
   vllm_args["enable_expert_parallel"] = enable_expert_parallel

   max_logging.log(
-      f"Initializing LLM with DP={ici_data_parallelism}, TP={ici_tensor_parallelism} "
-      f"and EP={ici_expert_parallelism if enable_expert_parallel else 0}..."
+      f"Initializing LLM with DP={config.ici_data_parallelism}, TP={config.ici_tensor_parallelism} "
+      f"and EP={config.ici_expert_parallelism if enable_expert_parallel else 1}..."
   )

   vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
@@ -188,15 +120,15 @@ def decode_with_vllm(

   max_logging.log("Generating output...")
   tokenizer = transformers.AutoTokenizer.from_pretrained(
-      tokenizer_path if tokenizer_path is not None else hf_model_name,
-      token=hf_access_token,
+      config.tokenizer_path,
+      token=config.hf_access_token,
   )

-  prompts = [prompt]
-  if use_chat_template:
+  prompts = [config.prompt]
+  if config.use_chat_template:
     # Format the prompt using chat template if specified
     messages = [
-        {"role": "user", "content": prompt},
+        {"role": "user", "content": config.prompt},
     ]
     input_with_chat_template = tokenizer.apply_chat_template(
         messages,
@@ -207,18 +139,18 @@ def decode_with_vllm(
     prompts = [input_with_chat_template]

   max_prompt_length = max(len(tokenizer.encode(p)) for p in prompts)
-  max_tokens_to_generate = max_target_length - max_prompt_length
+  max_tokens_to_generate = config.max_target_length - max_prompt_length
   if max_tokens_to_generate <= 0:
     raise ValueError(
-        f"max_target_length ({max_target_length}) must be greater than max_prompt_length ({max_prompt_length})"
+        f"max_target_length ({config.max_target_length}) must be greater than max_prompt_length ({max_prompt_length})"
     )

   sampling_params = SamplingParams(
-      temperature=decode_sampling_temperature,
+      temperature=config.decode_sampling_temperature,
       max_tokens=max_tokens_to_generate,
-      top_k=decode_sampling_top_k,
-      top_p=decode_sampling_nucleus_p,
-      seed=seed,
+      top_k=config.decode_sampling_top_k,
+      top_p=config.decode_sampling_nucleus_p,
+      seed=FLAGS.seed,
   )

   outputs = llm.generate(prompts, sampling_params)
@@ -312,32 +244,13 @@ def main(argv: Sequence[str]) -> None:
       os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
   )

+  config = pyconfig.initialize(argv)
+
   if FLAGS.use_tunix:
-    config = pyconfig.initialize(argv)
     maxtext_model, mesh = model_creation_utils.create_nnx_model(config)
     decode_with_tunix(config, model=maxtext_model, mesh=mesh)
   else:
-    decode_with_vllm(
-        model_name=FLAGS.model_name,
-        hf_model_name=FLAGS.hf_model_name,
-        hf_config_path=FLAGS.hf_config_path,
-        hf_access_token=FLAGS.hf_access_token,
-        tokenizer_path=FLAGS.tokenizer_path,
-        load_parameters_path=FLAGS.load_parameters_path,
-        ici_data_parallelism=FLAGS.ici_data_parallelism,
-        ici_tensor_parallelism=FLAGS.ici_tensor_parallelism,
-        ici_expert_parallelism=FLAGS.ici_expert_parallelism,
-        enable_dp_attention=FLAGS.enable_dp_attention,
-        max_target_length=FLAGS.max_target_length,
-        gpu_memory_utilization=FLAGS.gpu_memory_utilization,
-        prompt=FLAGS.prompt,
-        use_chat_template=FLAGS.use_chat_template,
-        decode_sampling_temperature=FLAGS.decode_sampling_temperature,
-        decode_sampling_nucleus_p=FLAGS.decode_sampling_nucleus_p,
-        decode_sampling_top_k=FLAGS.decode_sampling_top_k,
-        debug_sharding=FLAGS.debug_sharding,
-        seed=FLAGS.seed,
-    )
+    decode_with_vllm(config)


 if __name__ == "__main__":
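
For context on the change above: the rewritten entry point resolves its settings through MaxText's pyconfig instead of absl flags, so argv[1] names a base YAML file and any later key=value pairs override it. Below is a minimal sketch of that pattern (not part of the commit). It assumes the "from MaxText import pyconfig" import used by other MaxText entry points; the override values are illustrative, and a real run may need further keys such as a checkpoint path.

from MaxText import pyconfig  # assumed import path, as in other MaxText entry points

# argv mirrors the docstring example: program name, base YAML, key=value overrides.
argv = [
    "vllm_decode",
    "maxtext/configs/base.yml",
    "model_name=qwen3-30b-a3b",
    "ici_tensor_parallelism=4",
    "hbm_utilization_vllm=0.5",
]
config = pyconfig.initialize(argv)

# Each key resolved from the YAML or the overrides becomes an attribute on the config.
print(config.model_name, config.ici_tensor_parallelism, config.hbm_utilization_vllm)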