 from absl import logging as absl_logging
 from etils import epath
 from flax import nnx
-from flax.linen import partitioning as nn_partitioning
 from jax.sharding import Mesh
 from orbax import checkpoint as ocp
 from pprint import pprint
@@ -489,6 +488,11 @@ def _filter_long_prompts(x):
 
   max_logging.log(f"Parsed additional config: {rollout_additional_config}")
 
+  # We need to parse vLLM config to get the logical axis rules for the sampler config.
+  vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
+  argv_list = ["", str(vllm_config_path), "log_config=False"]
+  vllm_config = pyconfig.initialize(argv_list)
+
   # RL Cluster config
   # Note that we use vLLM as the rollout engine.
   # and we are using Tensor Parallelism for rollout
@@ -501,6 +505,7 @@ def _filter_long_prompts(x):
       role_to_logical_axis_rule={
           rl_cluster_lib.Role.ACTOR: trainer_config.logical_axis_rules,
           rl_cluster_lib.Role.REFERENCE: trainer_config.logical_axis_rules,
+          rl_cluster_lib.Role.ROLLOUT: vllm_config.logical_axis_rules,
       },
       rollout_engine="vllm",
       offload_to_cpu=False,
@@ -537,6 +542,9 @@ def _filter_long_prompts(x):
           rollout_vllm_enable_dp_attention=trainer_config.enable_dp_attention,
           rollout_vllm_max_num_batched_tokens=trainer_config.max_num_batched_tokens,
           rollout_vllm_max_num_seqs=trainer_config.max_num_seqs,
+          rollout_vllm_kwargs={
+              "hf_overrides": trainer_config.vllm_hf_overrides,
+          },
           **get_rollout_kwargs_for_data_parallelism(sampler_config, len(sampler_devices)),
       ),
   )
@@ -567,18 +575,13 @@ def _filter_long_prompts(x):
567575 "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
568576 )
569577
570- vllm_config_path = epath .Path (MAXTEXT_CONFIGS_DIR ) / "inference/vllm.yml"
571- argv_list = ["" , str (vllm_config_path ), "log_config=False" ]
572- vllm_config = pyconfig .initialize (argv_list )
573-
574- with nn_partitioning .axis_rules (vllm_config .logical_axis_rules ):
575- rl_cluster = rl_cluster_lib .RLCluster (
576- actor = actor_model ,
577- reference = reference_model ,
578- tokenizer = model_tokenizer ,
579- cluster_config = cluster_config ,
580- ** rl_cluster_kwargs ,
581- )
578+ rl_cluster = rl_cluster_lib .RLCluster (
579+ actor = actor_model ,
580+ reference = reference_model ,
581+ tokenizer = model_tokenizer ,
582+ cluster_config = cluster_config ,
583+ ** rl_cluster_kwargs ,
584+ )
582585
583586 # Create RL trainer
584587 max_logging .log ("Setting up RL trainer..." )
@@ -614,9 +617,7 @@ def _filter_long_prompts(x):
     ref_state_before = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
 
   max_logging.warning("Starting RL training...")
-
-  with reference_mesh, nn_partitioning.axis_rules(trainer_config.logical_axis_rules):
-    rl_trainer.train(train_dataset)
+  rl_trainer.train(train_dataset)
 
   if trainer_config.load_checkpoint_only_once:
     max_logging.log("Checking if reference model state changed during training.")