Skip to content

Commit fe3d200

Browse files
committed
Add the option to avoid loading the checkpoint twice
1 parent 570ee04 commit fe3d200

3 files changed

Lines changed: 26 additions & 5 deletions

File tree

src/MaxText/rl/train_rl.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -358,10 +358,18 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
358358
)
359359

360360
# TODO: @mazumdera: change this to use lora
361-
# TODO: @xfgu: instead of restoring a second time from GCS, can we just copy reference_model
362-
# Load policy model
363-
max_logging.log("Creating policy model with same config as reference model on trainer mesh")
364-
actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)
361+
if trainer_config.load_checkpoint_only_once:
362+
max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.")
363+
with reference_mesh:
364+
actor_base_model = nnx.clone(reference_model.base)
365+
use_no_op_mappings = "maxtext_config" in trainer_config.vllm_additional_config
366+
actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings)
367+
actor_model.config = None
368+
actor_mesh = reference_mesh
369+
else:
370+
max_logging.log("Creating policy model with same config as reference model on trainer mesh")
371+
actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)
372+
365373

366374
if trainer_config.debug.rl:
367375
max_logging.log("Policy Model initialized successfully")
@@ -530,11 +538,23 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
530538

531539
# Start training
532540

541+
if trainer_config.load_checkpoint_only_once:
542+
max_logging.log("Capturing reference model state before training.")
543+
ref_state_before = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
544+
533545
max_logging.warning("Starting RL training...")
534546

535547
with reference_mesh, nn_partitioning.axis_rules(trainer_config.logical_axis_rules):
536548
rl_trainer.train(train_dataset)
537549

550+
if trainer_config.load_checkpoint_only_once:
551+
max_logging.log("Checking if reference model state changed during training.")
552+
ref_state_after = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
553+
check = jax.tree_util.tree_map(jax.numpy.array_equal, ref_state_before, ref_state_after)
554+
if not jax.tree_util.tree_all(check):
555+
raise ValueError("Reference model parameters changed during training!")
556+
max_logging.log("Reference model parameters verified to be unchanged during training.")
557+
538558
max_logging.warning("RL Training Completed Successfully!")
539559

540560
# Let's evaluate our model!

src/maxtext/configs/post_train/rl.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# RL Configuration
1616
# This config consolidates common parameters for RL training across different model sizes
1717

18-
base_config: "base.yml"
18+
base_config: "../base.yml"
1919

2020
# ====== Hardware =====
2121
trainer_devices_fraction: 0.5

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ class Checkpointing(BaseModel):
288288
lora_input_adapters_path: PathStr = Field("", description="Input GCS path for LoRA adapters.")
289289
load_full_state_path: PathStr = Field("", description="Loads the complete training state from a checkpoint path.")
290290
enable_checkpointing: bool = Field(True, description="If True, enables saving checkpoints during training.")
291+
load_checkpoint_only_once: bool = Field(False, description="If True, deep copy the reference model to the actor model.")
291292
async_checkpointing: bool = Field(True, description="If True, uses an asynchronous checkpointer for performance.")
292293
checkpoint_period: int = Field(10_000, description="The frequency (in steps) at which to save checkpoints.")
293294
max_num_checkpoints_to_keep: int | None = Field(None, description="Maximum number of checkpoints to keep.")

0 commit comments

Comments (0)