@@ -358,10 +358,18 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
358358 )
359359
360360 # TODO: @mazumdera: change this to use lora
361- # TODO: @xfgu: instead of restoring a second time from GCS, can we just copy reference_model
362- # Load policy model
363- max_logging .log ("Creating policy model with same config as reference model on trainer mesh" )
364- actor_model , actor_mesh = get_maxtext_model (trainer_config , trainer_devices )
361+ if trainer_config .load_checkpoint_only_once :
362+ max_logging .log ("Creating policy model by copying reference model instead of restoring from checkpoint again." )
363+ with reference_mesh :
364+ actor_base_model = nnx .clone (reference_model .base )
365+ use_no_op_mappings = "maxtext_config" in trainer_config .vllm_additional_config
366+ actor_model = TunixMaxTextAdapter (base_model = actor_base_model , use_no_op_mappings = use_no_op_mappings )
367+ actor_model .config = None
368+ actor_mesh = reference_mesh
369+ else :
370+ max_logging .log ("Creating policy model with same config as reference model on trainer mesh" )
371+ actor_model , actor_mesh = get_maxtext_model (trainer_config , trainer_devices )
372+
365373
366374 if trainer_config .debug .rl :
367375 max_logging .log ("Policy Model initialized successfully" )
@@ -530,11 +538,23 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
530538
531539 # Start training
532540
541+ if trainer_config .load_checkpoint_only_once :
542+ max_logging .log ("Capturing reference model state before training." )
543+ ref_state_before = nnx .to_pure_dict (nnx .state (reference_model .base , nnx .Param ))
544+
533545 max_logging .warning ("Starting RL training..." )
534546
535547 with reference_mesh , nn_partitioning .axis_rules (trainer_config .logical_axis_rules ):
536548 rl_trainer .train (train_dataset )
537549
550+ if trainer_config .load_checkpoint_only_once :
551+ max_logging .log ("Checking if reference model state changed during training." )
552+ ref_state_after = nnx .to_pure_dict (nnx .state (reference_model .base , nnx .Param ))
553+ check = jax .tree_util .tree_map (jax .numpy .array_equal , ref_state_before , ref_state_after )
554+ if not jax .tree_util .tree_all (check ):
555+ raise ValueError ("Reference model parameters changed during training!" )
556+ max_logging .log ("Reference model parameters verified to be unchanged during training." )
557+
538558 max_logging .warning ("RL Training Completed Successfully!" )
539559
540560 # Let's evaluate our model!
0 commit comments