@@ -283,39 +283,18 @@ def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices):
283283 return rollout_kwargs
284284
285285
286- def rl_train (trainer_config , sampler_config , trainer_devices , sampler_devices ):
287- """
288- Run RL training with the provided configuration.
289-
290- Args:
291- trainer_config: MaxText configuration for the trainer.
292- sampler_config: MaxText configuration for the sampler.
293- trainer_devices: JAX devices for the trainer.
294- sampler_devices: JAX devices for the sampler.
295- """
296- if not trainer_config .debug .rl :
297- # Apply filter to suppress noisy logs
298- noise_filter = max_logging .NoisyLogFilter ()
299- logging .getLogger ().addFilter (noise_filter )
300- absl_logging .get_absl_logger ().addFilter (noise_filter )
301-
302- max_logging .log ("Starting RL Training" )
303- max_logging .log (f"Ensuring TensorBoard log directory exists: { trainer_config .tensorboard_dir } " )
304- if not epath .Path (trainer_config .tensorboard_dir ).exists ():
305- epath .Path (trainer_config .tensorboard_dir ).mkdir (parents = True , exist_ok = True )
306-
307- if not epath .Path (trainer_config .checkpoint_dir ).exists ():
308- epath .Path (trainer_config .checkpoint_dir ).mkdir (parents = True )
309-
310- # Number of training steps.
311- max_train_steps = int (
def get_max_train_steps(trainer_config):
  """Calculate the total number of training steps.

  The total is the product of the number of batches, the RL iteration
  count, the fraction of data used for training, and the epoch count,
  truncated to an integer.

  Args:
    trainer_config: MaxText configuration object; must expose
      `num_batches`, `rl.num_iterations`, `train_fraction`, and
      `num_epoch` attributes.

  Returns:
    The total number of training steps as an int.
  """
  total = (
      trainer_config.num_batches
      * trainer_config.rl.num_iterations
      * trainer_config.train_fraction
      * trainer_config.num_epoch
  )
  return int(total)
317- # ====== Data ======
318- # Setup data directories
294+
295+
296+ def prepare_datasets (trainer_config , model_tokenizer ):
297+ """Setup and return train and test datasets."""
319298 home = os .path .expanduser ("~" ) + "/"
320299 train_data_dir = f"{ home } /data/train"
321300 test_data_dir = f"{ home } /data/test"
@@ -324,9 +303,6 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
324303 if not os .path .exists (test_data_dir ):
325304 os .makedirs (test_data_dir )
326305
327- # Create model tokenizer
328- model_tokenizer = AutoTokenizer .from_pretrained (trainer_config .tokenizer_path )
329-
330306 # Load datasets
331307 if trainer_config .dataset_name == "huggingface:nvidia/OpenMathInstruct-2" :
332308 import datasets # pylint: disable=import-outside-toplevel
@@ -335,7 +311,6 @@ def prepare_openinstructmath2_dataset(
335311 split : str = "train_1M" ,
336312 seed : int = 42 ,
337313 test_size : float = 0.05 ,
338- output_key : str = "expected_answer" ,
339314 ):
340315 """Load and split the OpenMathInstruct-2 dataset into train and validation sets using HF's train_test_split."""
341316 max_logging .log (
@@ -420,41 +395,16 @@ def _filter_long_prompts(x):
420395 test_dataset = test_dataset [: trainer_config .num_test_batches * trainer_config .batch_size ]
421396
422397 test_dataset = test_dataset .to_iter_dataset ().batch (trainer_config .batch_size )
398+ return train_dataset , test_dataset
423399
424- if trainer_config .debug .rl :
425- # Let's see how one batch of the dataset looks like!
426- if trainer_config .debug .rl :
427- for i , ele in enumerate (train_dataset ):
428- if i >= 5 :
429- break
430- pprint (ele )
431- if trainer_config .debug .rl :
432- for i , ele in enumerate (test_dataset ):
433- if i >= 5 :
434- break
435- pprint (ele )
436-
437- # Load reference model
400+
401+ def create_models_and_meshes (trainer_config , sampler_config , trainer_devices , sampler_devices ):
402+ """Create reference and actor models and their respective meshes."""
438403 max_logging .log ("Creating reference model and also meshes for reference and rollout" )
439404 reference_model , reference_mesh = get_maxtext_model (trainer_config , trainer_devices )
440405 devices_array = maxtext_utils .create_device_mesh (sampler_config , sampler_devices )
441- # if trainer_devices=sampler_devices, then rollout_mesh=reference_mesh
442- # else rollout_mesh uses sampler_devices
443406 rollout_mesh = Mesh (devices_array , sampler_config .mesh_axes )
444- if trainer_config .debug .rl :
445- max_logging .log ("Reference Model initialized successfully" )
446- nnx .display (reference_model )
447- max_logging .log (f"Reference mesh shape: { reference_mesh .shape } " )
448407
449- # Sanity check that weights are loaded correctly.
450- _maxtext_state_flatten = nnx .state (reference_model ).flat_state ()
451- maxtext_state_flatten = {"." .join (str (key ) for key in keys ): v for keys , v in _maxtext_state_flatten }
452- max_logging .log (
453- f"maxtext_state_flatten[base.token_embedder.embedding].value=\
454- { maxtext_state_flatten ['base.token_embedder.embedding' ][...]} "
455- )
456-
457- # TODO: @mazumdera: change this to use lora
458408 if trainer_config .load_checkpoint_only_once :
459409 max_logging .log ("Creating policy model by copying reference model instead of restoring from checkpoint again." )
460410 with reference_mesh :
@@ -467,11 +417,22 @@ def _filter_long_prompts(x):
467417 max_logging .log ("Creating policy model with same config as reference model on trainer mesh" )
468418 actor_model , actor_mesh = get_maxtext_model (trainer_config , trainer_devices )
469419
470- if trainer_config .debug .rl :
471- max_logging .log ("Policy Model initialized successfully" )
472- nnx .display (actor_model )
473- max_logging .log (f"Policy mesh shape: { actor_mesh .shape } " )
474-
420+ return reference_model , reference_mesh , actor_model , actor_mesh , rollout_mesh
421+
422+
423+ def create_rl_components (
424+ trainer_config ,
425+ sampler_config ,
426+ sampler_devices ,
427+ actor_model ,
428+ actor_mesh ,
429+ reference_model ,
430+ reference_mesh ,
431+ rollout_mesh ,
432+ model_tokenizer ,
433+ max_train_steps ,
434+ ):
435+ """Setup RL cluster, trainer, and optimizer."""
475436 # Setup optimizer
476437 optimizer = utils_rl .get_optimizer (trainer_config , max_train_steps )
477438
@@ -484,7 +445,6 @@ def _filter_long_prompts(x):
484445 micro_batch_size = None if trainer_config .micro_batch_size == - 1 else trainer_config .micro_batch_size
485446
486447 # Setup metrics logging
487- max_logging .log (f"Tensorboard logs directory: { trainer_config .tensorboard_dir } " )
488448 metrics_logging_options = metrics_logger .MetricsLoggerOptions (
489449 log_dir = trainer_config .tensorboard_dir , flush_every_n_steps = trainer_config .log_period
490450 )
@@ -502,25 +462,18 @@ def _filter_long_prompts(x):
502462 rollout_additional_config = None
503463 if trainer_config .vllm_additional_config :
504464 if isinstance (trainer_config .vllm_additional_config , dict ):
505- # It's already parsed into a dict
506465 rollout_additional_config = trainer_config .vllm_additional_config
507466 elif isinstance (trainer_config .vllm_additional_config , str ):
508- # It's a string, so we need to parse it
509467 try :
510468 rollout_additional_config = json .loads (trainer_config .vllm_additional_config )
511469 except json .JSONDecodeError as e :
512470 raise ValueError (f"Failed to parse additional_config JSON: { e } " ) from e
513471
514- max_logging .log (f"Parsed additional config: { rollout_additional_config } " )
515-
516472 # We need to parse vLLM config to get the logical axis rules for the sampler config.
517473 vllm_config_path = os .path .join (MAXTEXT_CONFIGS_DIR , "inference" , "vllm.yml" )
518474 argv_list = ["" , str (vllm_config_path ), "log_config=False" ]
519475 vllm_config = pyconfig .initialize (argv_list )
520476
521- # RL Cluster config
522- # Note that we use vLLM as the rollout engine.
523- # and we are using Tensor Parallelism for rollout
524477 cluster_config = rl_cluster_lib .ClusterConfig (
525478 role_to_mesh = {
526479 rl_cluster_lib .Role .ACTOR : actor_mesh ,
@@ -538,15 +491,11 @@ def _filter_long_prompts(x):
538491 actor_optimizer = optimizer ,
539492 eval_every_n_steps = trainer_config .eval_interval ,
540493 max_steps = max_train_steps ,
541- # Micro batching
542494 mini_batch_size = trainer_config .batch_size ,
543495 train_micro_batch_size = micro_batch_size ,
544496 rollout_micro_batch_size = micro_batch_size ,
545- # Metrics logging
546497 metrics_logging_options = metrics_logging_options ,
547- # Profiling
548498 profiler_options = profiler_options ,
549- # Checkpoint saving
550499 checkpoint_root_directory = trainer_config .checkpoint_dir ,
551500 checkpointing_options = checkpointing_options ,
552501 ),
@@ -580,6 +529,7 @@ def _filter_long_prompts(x):
580529 ** get_rollout_kwargs_for_parallelism (sampler_config , len (sampler_devices )),
581530 ),
582531 )
532+
583533 grpo_config = GrpoConfig (
584534 num_generations = trainer_config .rl .num_generations ,
585535 num_iterations = trainer_config .rl .num_iterations ,
@@ -596,9 +546,6 @@ def _filter_long_prompts(x):
596546 from tunix .perf import export as perf_export # pylint: disable=import-outside-toplevel
597547 from tunix .perf import metrics as perf_metrics # pylint: disable=import-outside-toplevel
598548
599- max_logging .log (
600- "enable_tunix_perf_metrics is True and tunix.perf modules are available, enabling Tunix-managed metrics."
601- )
602549 perf_config = perf_metrics .PerfMetricsConfig ()
603550 perf_config .custom_export_fn = perf_export .PerfMetricsExport .create_metrics_export_fn (cluster_config )
604551 rl_cluster_kwargs ["perf_config" ] = perf_config
@@ -628,9 +575,76 @@ def _filter_long_prompts(x):
628575 algo_config = grpo_config ,
629576 )
630577
578+ return rl_cluster , rl_trainer , optimizer
579+
580+
581+ def rl_train (trainer_config , sampler_config , trainer_devices , sampler_devices ):
582+ """
583+ Run RL training with the provided configuration.
584+
585+ Args:
586+ trainer_config: MaxText configuration for the trainer.
587+ sampler_config: MaxText configuration for the sampler.
588+ trainer_devices: JAX devices for the trainer.
589+ sampler_devices: JAX devices for the sampler.
590+ """
591+ if not trainer_config .debug .rl :
592+ # Apply filter to suppress noisy logs
593+ noise_filter = max_logging .NoisyLogFilter ()
594+ logging .getLogger ().addFilter (noise_filter )
595+ absl_logging .get_absl_logger ().addFilter (noise_filter )
596+
597+ max_logging .log ("Starting RL Training" )
598+ if not epath .Path (trainer_config .tensorboard_dir ).exists ():
599+ epath .Path (trainer_config .tensorboard_dir ).mkdir (parents = True , exist_ok = True )
600+
601+ if not epath .Path (trainer_config .checkpoint_dir ).exists ():
602+ epath .Path (trainer_config .checkpoint_dir ).mkdir (parents = True )
603+
604+ max_train_steps = get_max_train_steps (trainer_config )
605+
606+ # Create model tokenizer
607+ model_tokenizer = AutoTokenizer .from_pretrained (trainer_config .tokenizer_path )
608+
609+ train_dataset , test_dataset = prepare_datasets (trainer_config , model_tokenizer )
610+
611+ if trainer_config .debug .rl :
612+ for i , ele in enumerate (train_dataset ):
613+ if i >= 5 :
614+ break
615+ pprint (ele )
616+ for i , ele in enumerate (test_dataset ):
617+ if i >= 5 :
618+ break
619+ pprint (ele )
620+
621+ reference_model , reference_mesh , actor_model , actor_mesh , rollout_mesh = create_models_and_meshes (
622+ trainer_config , sampler_config , trainer_devices , sampler_devices
623+ )
624+
625+ if trainer_config .debug .rl :
626+ max_logging .log ("Reference Model initialized successfully" )
627+ nnx .display (reference_model )
628+ max_logging .log (f"Reference mesh shape: { reference_mesh .shape } " )
629+ max_logging .log ("Policy Model initialized successfully" )
630+ nnx .display (actor_model )
631+ max_logging .log (f"Policy mesh shape: { actor_mesh .shape } " )
632+
633+ rl_cluster , rl_trainer , _ = create_rl_components (
634+ trainer_config ,
635+ sampler_config ,
636+ sampler_devices ,
637+ actor_model ,
638+ actor_mesh ,
639+ reference_model ,
640+ reference_mesh ,
641+ rollout_mesh ,
642+ model_tokenizer ,
643+ max_train_steps ,
644+ )
645+
631646 # Before we train the model, let's evaluate the model on the test set so we can
632647 # see the improvement post training.
633- #
634648 (corr , total , accuracy , partial_accuracy , format_accuracy ), _ = evaluate (
635649 trainer_config ,
636650 test_dataset ,
@@ -639,11 +653,9 @@ def _filter_long_prompts(x):
639653 corr_lst = trainer_config .eval_corr_lst ,
640654 make_lst = trainer_config .eval_make_lst ,
641655 )
642- # TODO: @mazumdera: Change this to max_logging.log once b/473703277 is resolved
643656 max_logging .warning (f"Pre RL Training: { corr = } , { total = } , { accuracy = } %, { partial_accuracy = } %," f" { format_accuracy = } %" )
644657
645658 # Start training
646-
647659 if trainer_config .load_checkpoint_only_once :
648660 max_logging .log ("Capturing reference model state before training." )
649661 ref_state_before = nnx .to_pure_dict (nnx .state (reference_model .base , nnx .Param ))
0 commit comments