 import transformers
 
 from maxtext.utils import model_creation_utils
+from maxtext.utils import max_logging
 from MaxText import pyconfig
 from MaxText.common_types import Config
 from MaxText.globals import MAXTEXT_CONFIGS_DIR
 flags.DEFINE_bool("debug_sharding", False, "Debug Shardings")
 
 # Model
-flags.DEFINE_string("model_name", "qwen3-30b-a3b", "Model name for MaxText.")
-flags.DEFINE_string("hf_model_name", "Qwen/Qwen3-30B-A3B", "Path to the Hugging Face model.")
+flags.DEFINE_string("model_name", None, "Model name for MaxText.")
+flags.DEFINE_string("hf_model_name", None, "Path to the Hugging Face model.")
 flags.DEFINE_string("hf_config_path", None, "Path to the local Hugging Face model config.")
+flags.DEFINE_string("hf_access_token", None, "Hugging Face access token for private models.")
+flags.DEFINE_string("tokenizer_path", None, "Path to the tokenizer. If None, use hf_model_name.")
 flags.DEFINE_string("load_parameters_path", None, "Path to load model parameters from.")
-flags.DEFINE_bool("enable_expert_parallel", False, "Whether to enable expert parallelism.")
 
 # Length/Throughput
 flags.DEFINE_integer("max_target_length", 1024, "Maximum total context length (MCL).")
-flags.DEFINE_integer("max_prefill_length", 512, "Maximum prefill length.")
 flags.DEFINE_float("gpu_memory_utilization", 0.72, "Fraction of GPU memory to be used for the model executor.")
 
 # Decoding
 flags.DEFINE_bool("use_tunix", False, "Whether to use Tunix for vLLM decoding.")
+flags.DEFINE_bool("use_chat_template", False, "Whether to format the prompt using chat template.")
 flags.DEFINE_string("prompt", "Suggest some famous landmarks in London.", "The prompt to decode.")
-flags.DEFINE_integer("decode_sampling_temperature", 0, "Temperature for sampling.")
-flags.DEFINE_integer("decode_sampling_nucleus_p", 1, "Nucleus sampling probability.")
+flags.DEFINE_float("decode_sampling_temperature", 0.0, "Temperature for sampling.")
+flags.DEFINE_float("decode_sampling_nucleus_p", 1.0, "Nucleus sampling probability.")
 flags.DEFINE_integer("decode_sampling_top_k", 1, "Top-k sampling probability.")
+flags.DEFINE_integer("seed", 42, "Random seed for sampling.")
 
-# Mark required flags
-flags.mark_flag_as_required("hf_config_path")
+# Set mandatory flags
+flags.mark_flag_as_required("model_name")
+flags.mark_flag_as_required("hf_model_name")
 
 
 def decode_with_vllm(
     model_name: str,
     hf_model_name: str,
-    hf_config_path: str,
-    load_parameters_path: str,
-    ici_data_parallelism: int,
-    ici_tensor_parallelism: int,
-    ici_expert_parallelism: int,
-    enable_dp_attention: bool,
-    max_prefill_length: int,
-    max_target_length: int,
-    gpu_memory_utilization: float,
-    enable_expert_parallel: bool,
+    hf_config_path: str | None,
     prompt: str,
-    decode_sampling_temperature: float,
-    decode_sampling_nucleus_p: float,
-    decode_sampling_top_k: float,
-    debug_sharding: bool,
     vllm_config_path: str | None = None,
+    ici_data_parallelism: int = 1,
+    ici_tensor_parallelism: int = 1,
+    ici_expert_parallelism: int = 1,
+    enable_dp_attention: bool = False,
+    max_target_length: int = 1024,
+    gpu_memory_utilization: float = 0.72,
+    use_chat_template: bool = False,
+    decode_sampling_temperature: float = 0.0,
+    decode_sampling_nucleus_p: float = 1.0,
+    decode_sampling_top_k: int = 1,
+    hf_access_token: str | None = None,
+    tokenizer_path: str | None = None,
+    load_parameters_path: str | None = None,
+    debug_sharding: bool = False,
+    seed: int = 42,
 ) -> None:
   """Decode using vLLM with a MaxText model implementation.
 
   Args:
     model_name: Name of the model for MaxText.
     hf_model_name: Path to the Hugging Face model.
     hf_config_path: Path to the local Hugging Face model config.
-    load_parameters_path: Path to load model parameters from.
+    prompt: The prompt to decode.
     ici_data_parallelism: Size of the data parallelism dimension.
     ici_tensor_parallelism: Size of the non-expert tensor parallelism dimension.
-    ici_expert_parallelism: Size of the MoE expert parallelism dimension
-    enable_dp_attention: Enable DP attention
-    max_prefill_length: Maximum prefill length.
+    ici_expert_parallelism: Size of the MoE expert parallelism dimension.
+    enable_dp_attention: Enable attention DP parallelism.
     max_target_length: Maximum total context length (MCL).
     gpu_memory_utilization: Fraction of GPU memory to be used for the model executor.
-    enable_expert_parallel: Whether to enable expert parallelism.
-    prompt: The prompt to decode.
+    use_chat_template: Whether to format the prompt using chat template.
     decode_sampling_temperature: Temperature for sampling.
     decode_sampling_nucleus_p: Nucleus sampling probability.
     decode_sampling_top_k: Top-k sampling probability.
     vllm_config_path: Path to vLLM config file. Defaults to MAXTEXT_PKG_DIR/configs/vllm.yml.
+    hf_access_token: Hugging Face access token for private models.
+    tokenizer_path: Path to the tokenizer. If None, use hf_model_name.
+    load_parameters_path: Path to load model parameters from.
+    debug_sharding: Whether to debug shardings.
+    seed: Random seed for sampling.
   """
-
   # Prepare vLLM Arguments
-  vllm_args = {}
-  vllm_args["additional_config"] = {}
-
-  # Core vLLM Arguments
-  vllm_args["model"] = hf_model_name
-  vllm_args["max_model_len"] = max_target_length
-  vllm_args["tensor_parallel_size"] = ici_tensor_parallelism
-  vllm_args["data_parallel_size"] = ici_data_parallelism
-  vllm_args["enable_expert_parallel"] = enable_expert_parallel
-  vllm_args["hf_config_path"] = hf_config_path
-  vllm_args["gpu_memory_utilization"] = gpu_memory_utilization
-
-  # Prepare MaxText and sharding configs (Parallelism is dynamic)
-  vllm_args["additional_config"]["maxtext_config"] = {
-      "model_name": model_name,
-      "max_target_length": max_target_length,
-      "weight_dtype": "bfloat16",
-      "allow_split_physical_axes": True,
-      "debug_sharding": debug_sharding,
+  vllm_args = {
+      "model": hf_model_name,
+      "max_model_len": max_target_length,
+      "tensor_parallel_size": ici_tensor_parallelism,
+      "data_parallel_size": ici_data_parallelism,
+      "hf_config_path": hf_config_path,
+      "gpu_memory_utilization": gpu_memory_utilization,
+      "additional_config": {
+          "maxtext_config": {
+              "model_name": model_name,
+              "weight_dtype": "bfloat16",
+              "allow_split_physical_axes": True,
+              "debug_sharding": debug_sharding,
+          },
+          "sharding": {
+              "sharding_strategy": {
+                  "enable_dp_attention": enable_dp_attention,
+              },
+          },
+      },
   }
-  if load_parameters_path is not None:
+
+  if load_parameters_path:
     vllm_args["additional_config"]["maxtext_config"]["load_parameters_path"] = load_parameters_path
   else:
     vllm_args["load_format"] = "dummy"
 
-  sharding_strategy = {
-      "enable_dp_attention": enable_dp_attention,
-  }
+  enable_expert_parallel = ici_expert_parallelism > 1
   if enable_expert_parallel:
-    sharding_strategy["expert_parallelism"] = ici_expert_parallelism
-  vllm_args["additional_config"]["sharding"] = {
-      "sharding_strategy": sharding_strategy,
-  }
-
-  if enable_expert_parallel:
-    vllm_args["additional_config"]["sharding"]["sharding_strategy"].update({"expert_parallelism": ici_expert_parallelism})
-
-  # Initialize and Run LLM
-  max_tokens = max_target_length - max_prefill_length
-  sampling_params = SamplingParams(
-      temperature=decode_sampling_temperature,
-      max_tokens=max_tokens,
-      top_k=decode_sampling_top_k,
-      top_p=decode_sampling_nucleus_p,
-  )
+    vllm_args["additional_config"]["sharding"]["sharding_strategy"]["expert_parallelism"] = ici_expert_parallelism
+  vllm_args["enable_expert_parallel"] = enable_expert_parallel
 
-  print(
-      f"Initializing LLM with DP={vllm_args['data_parallel_size']}, TP={vllm_args['tensor_parallel_size']} "
+  max_logging.log(
+      f"Initializing LLM with DP={ici_data_parallelism}, TP={ici_tensor_parallelism} "
       f"and EP={ici_expert_parallelism if enable_expert_parallel else 0}..."
   )
 
@@ -192,22 +186,62 @@ def decode_with_vllm(
   with nn_partitioning.axis_rules(vllm_config.logical_axis_rules):
     llm = LLM(**vllm_args)
 
-  print("Generating output...")
-  outputs = llm.generate([prompt], sampling_params)
+  max_logging.log("Generating output...")
+  tokenizer = transformers.AutoTokenizer.from_pretrained(
+      tokenizer_path if tokenizer_path is not None else hf_model_name,
+      token=hf_access_token,
+  )
 
-  # Print Outputs
+  prompts = [prompt]
+  if use_chat_template:
+    # Format the prompt using chat template if specified
+    messages = [
+        {"role": "user", "content": prompt},
+    ]
+    input_with_chat_template = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,  # Set to False to get the string
+        add_generation_prompt=True,
+        add_special_tokens=False,  # Prevent adding special tokens
+    )
+    prompts = [input_with_chat_template]
+
+  max_prompt_length = max(len(tokenizer.encode(p)) for p in prompts)
+  max_tokens_to_generate = max_target_length - max_prompt_length
+  if max_tokens_to_generate <= 0:
+    raise ValueError(
+        f"max_target_length ({max_target_length}) must be greater than max_prompt_length ({max_prompt_length})"
+    )
+
+  sampling_params = SamplingParams(
+      temperature=decode_sampling_temperature,
+      max_tokens=max_tokens_to_generate,
+      top_k=decode_sampling_top_k,
+      top_p=decode_sampling_nucleus_p,
+      seed=seed,
+  )
+
+  outputs = llm.generate(prompts, sampling_params)
+
+  # Log outputs
   for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    max_logging.log(f"Prompt: {prompt}, Generated text: {generated_text}")
 
 
 def decode_with_tunix(
     config: Config,
     model: Any,
     mesh: jax.sharding.Mesh,
 ) -> None:
-  """Decode using vLLM with a MaxText model."""
+  """Decode using vLLM with a MaxText model via Tunix adapter.
+
+  Args:
+    config: MaxText config.
+    model: The MaxText model instance.
+    mesh: The JAX mesh for parallelism.
+  """
   # Wrap the model for Tunix
   tunix_model = TunixMaxTextAdapter(base_model=model)
 
@@ -235,6 +269,10 @@ def decode_with_tunix(
 
   max_prompt_length = max(len(tokenizer.encode(p)) for p in prompts)
   max_tokens_to_generate = config.max_target_length - max_prompt_length
+  if max_tokens_to_generate <= 0:
+    raise ValueError(
+        f"max_target_length ({config.max_target_length}) must be greater than max_prompt_length ({max_prompt_length})"
+    )
 
   # Create vLLM rollout for inference
   rollout_config = base_rollout.RolloutConfig(
@@ -262,8 +300,8 @@ def decode_with_tunix(
 
   # Generate text
   output = vllm_rollout.generate(prompts, rollout_config)
-  print(f"Prompt: {config.prompt}")
-  print(f"Output: {output.text[0]}")
+  max_logging.log(f"Prompt: {config.prompt}")
+  max_logging.log(f"Output: {output.text[0]}")
 
 
 def main(argv: Sequence[str]) -> None:
@@ -283,20 +321,22 @@ def main(argv: Sequence[str]) -> None:
         model_name=FLAGS.model_name,
         hf_model_name=FLAGS.hf_model_name,
         hf_config_path=FLAGS.hf_config_path,
+        hf_access_token=FLAGS.hf_access_token,
+        tokenizer_path=FLAGS.tokenizer_path,
         load_parameters_path=FLAGS.load_parameters_path,
         ici_data_parallelism=FLAGS.ici_data_parallelism,
         ici_tensor_parallelism=FLAGS.ici_tensor_parallelism,
         ici_expert_parallelism=FLAGS.ici_expert_parallelism,
         enable_dp_attention=FLAGS.enable_dp_attention,
         max_target_length=FLAGS.max_target_length,
-        max_prefill_length=FLAGS.max_prefill_length,
         gpu_memory_utilization=FLAGS.gpu_memory_utilization,
-        enable_expert_parallel=FLAGS.enable_expert_parallel,
         prompt=FLAGS.prompt,
+        use_chat_template=FLAGS.use_chat_template,
         decode_sampling_temperature=FLAGS.decode_sampling_temperature,
         decode_sampling_nucleus_p=FLAGS.decode_sampling_nucleus_p,
         decode_sampling_top_k=FLAGS.decode_sampling_top_k,
         debug_sharding=FLAGS.debug_sharding,
+        seed=FLAGS.seed,
     )
 
 
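For reference, a minimal sketch of how the reworked `decode_with_vllm` could be called directly now that the parallelism, sampling, and loading arguments carry keyword defaults. The import path and model identifiers below are illustrative assumptions, not values taken from this change; only the argument names and defaults mirror the new signature.

# Sketch only: module path and model ids are hypothetical placeholders.
from maxtext_vllm_decode import decode_with_vllm  # assumed import path

decode_with_vllm(
    model_name="llama3-8b",  # placeholder MaxText model name
    hf_model_name="meta-llama/Meta-Llama-3-8B",  # placeholder Hugging Face repo id
    hf_config_path=None,  # still a required argument; None mirrors the flag default
    prompt="Suggest some famous landmarks in London.",
    ici_tensor_parallelism=4,  # parallelism dims now default to 1 and are overridden per call
    use_chat_template=True,  # wrap the prompt with the tokenizer's chat template
    seed=42,
    # load_parameters_path is omitted, so the script falls back to load_format="dummy"
    # and no real checkpoint weights are loaded.
)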