Skip to content

Commit af14e43

Browse files
Merge pull request #3020 from AI-Hypercomputer:nicogrande/update-train-rl-args
PiperOrigin-RevId: 862978391
2 parents 9cf4088 + 58cff0b commit af14e43

3 files changed

Lines changed: 11 additions & 0 deletions

File tree

src/MaxText/configs/rl.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ swap_space_vllm_gb: 2
144144
decode_sampling_temperature: 0.9
145145
decode_sampling_top_k: 50
146146
decode_sampling_nucleus_p: 1.0
147+
# Optional sharding configuration for samplers
148+
enable_dp_attention: False
149+
# Performance tuning for samplers
150+
max_num_batched_tokens: null
151+
max_num_seqs: null
147152

148153
# ====== Checkpoint Configuration ======
149154
enable_checkpointing: True

src/MaxText/configs/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1485,6 +1485,9 @@ class VLLM(BaseModel):
14851485
kv_cache_buffer: int = Field(256, description="Buffer for KV cache.")
14861486
hbm_utilization_vllm: float = Field(0.72, description="Target HBM utilization for vLLM.")
14871487
swap_space_vllm_gb: int = Field(2, description="Swap space in GB for vLLM.")
1488+
enable_dp_attention: bool = Field(False, description="Enable the attn_dp mesh axis in vLLM.")
1489+
max_num_batched_tokens: Optional[int] = Field(None, description="Max number of batched tokens in vLLM.")
1490+
max_num_seqs: Optional[int] = Field(None, description="Max number of sequences in vLLM.")
14881491
vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
14891492
vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
14901493

src/MaxText/rl/train_rl.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,9 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
437437
rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path,
438438
rollout_vllm_additional_config=rollout_additional_config,
439439
rollout_vllm_init_with_random_weights=True,
440+
rollout_vllm_enable_dp_attention=trainer_config.enable_dp_attention,
441+
rollout_vllm_max_num_batched_tokens=trainer_config.max_num_batched_tokens,
442+
rollout_vllm_max_num_seqs=trainer_config.max_num_seqs,
440443
**get_rollout_kwargs_for_data_parallelism(sampler_config, len(sampler_devices)),
441444
),
442445
)

0 commit comments

Comments (0)