@@ -85,9 +85,9 @@ def get_maxtext_model(config, devices=None):
8585 """
8686 Load MaxText model with Tunix adapter.
8787 # Note: pass the path to your scanned checkpoint for 'load_parameters_path'.
88- # To create a scanned checkpoint, you can use /maxtext/src/MaxText/utils/ckpt_conversion /to_maxtext.py and if
88+ # To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion /to_maxtext.py and if
8989 # using Pathways, please set `checkpoint_storage_use_ocdbt=False checkpoint_storage_use_zarr3=False`
90- # python src/MaxText/utils/ckpt_conversion /to_maxtext.py \
90+ # python src/MaxText/checkpoint_conversion /to_maxtext.py \
9191 # --model_name="gemma2-2b" \
9292 # --base_output_directory="/path/to/your/output/directory" \
9393 # --scan_layers=True \
@@ -304,20 +304,25 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path)
 
   # Load datasets
-  dataset = get_dataset(
+  train_dataset = get_dataset(
       model_tokenizer,
       trainer_config,
       train_data_dir,
       trainer_config.train_split,
       data_files=trainer_config.hf_train_files,
       dataset_name=trainer_config.dataset_name,
-  ).batch(trainer_config.batch_size)[: trainer_config.num_batches]
+  )
 
-  if trainer_config.train_fraction == 1.0:
-    train_dataset = dataset.repeat(trainer_config.num_epoch)
-  else:
-    train_dataset = dataset[: int(len(dataset) * trainer_config.train_fraction)]
-    train_dataset = train_dataset.repeat(trainer_config.num_epoch)
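+  # Keep only examples whose tokenized prompt fits within the sampler's prefill length.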
+  def _filter_long_prompts(x):
+    tokens = model_tokenizer.tokenize(x["prompts"])
+    return len(tokens) <= trainer_config.max_prefill_predict_length
+
+  train_dataset = train_dataset.filter(_filter_long_prompts)
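+  # Keep train_fraction of the requested examples (num_batches * batch_size), then
+  # repeat for num_epoch epochs.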
+  dataset_size = int(trainer_config.num_batches * trainer_config.batch_size * trainer_config.train_fraction)
+  train_dataset = train_dataset[:dataset_size]
+  train_dataset = train_dataset.repeat(trainer_config.num_epoch)
+
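+  # Convert the map-style dataset into an iterable dataset and form fixed-size batches.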
+  train_dataset = train_dataset.to_iter_dataset().batch(trainer_config.batch_size)
 
   eval_dataset_name = getattr(trainer_config, "eval_dataset_name", None)
   if not eval_dataset_name:
@@ -330,12 +335,12 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
       trainer_config.eval_split,
       data_files=trainer_config.hf_eval_files,
       dataset_name=eval_dataset_name,
-  ).batch(trainer_config.batch_size)[: trainer_config.num_test_batches]
+  )
 
-  # Let's see how one batch of the dataset looks like!
-  if trainer_config.debug.rl:
-    for ele in train_dataset[:1]:
-      pprint(ele)
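+  # Apply the same prompt-length filter and batching to the evaluation split.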
+  test_dataset = test_dataset.filter(_filter_long_prompts)
+  test_dataset = test_dataset[: trainer_config.num_test_batches * trainer_config.batch_size]
+
+  test_dataset = test_dataset.to_iter_dataset().batch(trainer_config.batch_size)
 
   # Load reference model
   max_logging.log("Creating reference model and also meshes for reference and rollout")
@@ -358,10 +363,17 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   )
 
   # TODO: @mazumdera: change this to use lora
-  # TODO: @xfgu: instead of restoring a second time from GCS, can we just copy reference_model
-  # Load policy model
-  max_logging.log("Creating policy model with same config as reference model on trainer mesh")
-  actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)
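+  # Reuse the already-restored reference weights for the policy model: nnx.clone
+  # creates an independent copy, so training the actor cannot mutate the reference.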
+  if trainer_config.load_checkpoint_only_once:
+    max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.")
+    with reference_mesh:
+      actor_base_model = nnx.clone(reference_model.base)
+    use_no_op_mappings = "maxtext_config" in trainer_config.vllm_additional_config
+    actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings)
+    actor_model.config = None
+    actor_mesh = reference_mesh
+  else:
+    max_logging.log("Creating policy model with same config as reference model on trainer mesh")
+    actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)
 
   if trainer_config.debug.rl:
     max_logging.log("Policy Model initialized successfully")
@@ -487,7 +499,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
487499 "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
488500 )
489501
490- vllm_config_path = epath . Path (MAXTEXT_CONFIGS_DIR ) / "inference/ vllm.yml"
502+ vllm_config_path = os . path . join (MAXTEXT_CONFIGS_DIR , "inference" , " vllm.yml")
491503 argv_list = ["" , str (vllm_config_path ), "log_config=False" ]
492504 vllm_config = pyconfig .initialize (argv_list )
493505
@@ -529,11 +541,23 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
 
   # Start training
 
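+  # Snapshot the reference model's parameters so we can verify after training
+  # that they were left unchanged.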
+  if trainer_config.load_checkpoint_only_once:
+    max_logging.log("Capturing reference model state before training.")
+    ref_state_before = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
+
   max_logging.warning("Starting RL training...")
 
   with reference_mesh, nn_partitioning.axis_rules(trainer_config.logical_axis_rules):
     rl_trainer.train(train_dataset)
 
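+  # Compare the current reference parameters leaf by leaf against the snapshot;
+  # any mismatch means training unexpectedly updated the reference model.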
+  if trainer_config.load_checkpoint_only_once:
+    max_logging.log("Checking if reference model state changed during training.")
+    ref_state_after = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
+    check = jax.tree_util.tree_map(jax.numpy.array_equal, ref_state_before, ref_state_after)
+    if not jax.tree_util.tree_all(check):
+      raise ValueError("Reference model parameters changed during training!")
+    max_logging.log("Reference model parameters verified to be unchanged during training.")
+
   max_logging.warning("RL Training Completed Successfully!")
 
   # Let's evaluate our model!