Skip to content

Commit 58cff0b

Browse files
committed
Add new sharding and performance args for vLLM.
1 parent 941d46a commit 58cff0b

3 files changed

Lines changed: 11 additions & 0 deletions

File tree

src/MaxText/configs/rl.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ swap_space_vllm_gb: 2
144144
decode_sampling_temperature: 0.9
145145
decode_sampling_top_k: 50
146146
decode_sampling_nucleus_p: 1.0
147+
# Optional sharding configuration for samplers
148+
enable_dp_attention: False
149+
# Performance tuning for samplers
150+
max_num_batched_tokens: null
151+
max_num_seqs: null
147152

148153
# ====== Checkpoint Configuration ======
149154
enable_checkpointing: True

src/MaxText/configs/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1484,6 +1484,9 @@ class VLLM(BaseModel):
14841484
kv_cache_buffer: int = Field(256, description="Buffer for KV cache.")
14851485
hbm_utilization_vllm: float = Field(0.72, description="Target HBM utilization for vLLM.")
14861486
swap_space_vllm_gb: int = Field(2, description="Swap space in GB for vLLM.")
1487+
enable_dp_attention: bool = Field(False, description="Enable the attn_dp mesh axis in vLLM.")
1488+
max_num_batched_tokens: Optional[int] = Field(None, description="Max number of batched tokens in vLLM.")
1489+
max_num_seqs: Optional[int] = Field(None, description="Max number of sequences in vLLM.")
14871490
vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
14881491
vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
14891492

src/MaxText/rl/train_rl.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,9 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
426426
rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path,
427427
rollout_vllm_additional_config=rollout_additional_config,
428428
rollout_vllm_init_with_random_weights=True,
429+
rollout_vllm_enable_dp_attention=trainer_config.enable_dp_attention,
430+
rollout_vllm_max_num_batched_tokens=trainer_config.max_num_batched_tokens,
431+
rollout_vllm_max_num_seqs=trainer_config.max_num_seqs,
429432
**get_rollout_kwargs_for_data_parallelism(sampler_config, len(sampler_devices)),
430433
),
431434
)

0 commit comments

Comments (0)