Skip to content

Commit 46ec3af

Browse files
Merge pull request #3243 from AI-Hypercomputer:nicogrande/update-train-rl
PiperOrigin-RevId: 875242798
2 parents f3d9f5c + 8f54544 commit 46ec3af

1 file changed

Lines changed: 17 additions & 16 deletions

File tree

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
from absl import logging as absl_logging
6060
from etils import epath
6161
from flax import nnx
62-
from flax.linen import partitioning as nn_partitioning
6362
from jax.sharding import Mesh
6463
from orbax import checkpoint as ocp
6564
from pprint import pprint
@@ -489,6 +488,11 @@ def _filter_long_prompts(x):
489488

490489
max_logging.log(f"Parsed additional config: {rollout_additional_config}")
491490

491+
# We need to parse vLLM config to get the logical axis rules for the sampler config.
492+
vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
493+
argv_list = ["", str(vllm_config_path), "log_config=False"]
494+
vllm_config = pyconfig.initialize(argv_list)
495+
492496
# RL Cluster config
493497
# Note that we use vLLM as the rollout engine.
494498
# and we are using Tensor Parallelism for rollout
@@ -501,6 +505,7 @@ def _filter_long_prompts(x):
501505
role_to_logical_axis_rule={
502506
rl_cluster_lib.Role.ACTOR: trainer_config.logical_axis_rules,
503507
rl_cluster_lib.Role.REFERENCE: trainer_config.logical_axis_rules,
508+
rl_cluster_lib.Role.ROLLOUT: vllm_config.logical_axis_rules,
504509
},
505510
rollout_engine="vllm",
506511
offload_to_cpu=False,
@@ -537,6 +542,9 @@ def _filter_long_prompts(x):
537542
rollout_vllm_enable_dp_attention=trainer_config.enable_dp_attention,
538543
rollout_vllm_max_num_batched_tokens=trainer_config.max_num_batched_tokens,
539544
rollout_vllm_max_num_seqs=trainer_config.max_num_seqs,
545+
rollout_vllm_kwargs={
546+
"hf_overrides": trainer_config.vllm_hf_overrides,
547+
},
540548
**get_rollout_kwargs_for_data_parallelism(sampler_config, len(sampler_devices)),
541549
),
542550
)
@@ -567,18 +575,13 @@ def _filter_long_prompts(x):
567575
"enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
568576
)
569577

570-
vllm_config_path = epath.Path(MAXTEXT_CONFIGS_DIR) / "inference/vllm.yml"
571-
argv_list = ["", str(vllm_config_path), "log_config=False"]
572-
vllm_config = pyconfig.initialize(argv_list)
573-
574-
with nn_partitioning.axis_rules(vllm_config.logical_axis_rules):
575-
rl_cluster = rl_cluster_lib.RLCluster(
576-
actor=actor_model,
577-
reference=reference_model,
578-
tokenizer=model_tokenizer,
579-
cluster_config=cluster_config,
580-
**rl_cluster_kwargs,
581-
)
578+
rl_cluster = rl_cluster_lib.RLCluster(
579+
actor=actor_model,
580+
reference=reference_model,
581+
tokenizer=model_tokenizer,
582+
cluster_config=cluster_config,
583+
**rl_cluster_kwargs,
584+
)
582585

583586
# Create RL trainer
584587
max_logging.log("Setting up RL trainer...")
@@ -614,9 +617,7 @@ def _filter_long_prompts(x):
614617
ref_state_before = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
615618

616619
max_logging.warning("Starting RL training...")
617-
618-
with reference_mesh, nn_partitioning.axis_rules(trainer_config.logical_axis_rules):
619-
rl_trainer.train(train_dataset)
620+
rl_trainer.train(train_dataset)
620621

621622
if trainer_config.load_checkpoint_only_once:
622623
max_logging.log("Checking if reference model state changed during training.")

0 commit comments

Comments
 (0)