Commit c45b59c

adding integration test for GRPO.

1 parent 00ef5de commit c45b59c

3 files changed: 319 additions & 84 deletions

src/maxtext/configs/inference/vllm.yml

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@

 base_config: "base.yml"
 attention: "vllm_rpa"
+model_call_mode: "inference"
+
 # NNX required for vLLM integration
 enable_nnx: True
 # Avoid re-initializing JAX distributed system when using vLLM
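Note: this file is loaded through pyconfig.initialize, exactly as train_rl.py does below. A minimal standalone sketch of reading the new key; the import paths here are assumptions, only the initialize call itself is taken from this diff:

import os

from maxtext import pyconfig  # import path is an assumption
from maxtext.globals import MAXTEXT_CONFIGS_DIR  # import path is an assumption

# Mirror the pyconfig.initialize call made in train_rl.py below.
vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
vllm_config = pyconfig.initialize(["", str(vllm_config_path), "log_config=False"])
assert vllm_config.model_call_mode == "inference"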

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 96 additions & 84 deletions
@@ -282,39 +282,18 @@ def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices):
   return rollout_kwargs


-def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
-  """
-  Run RL training with the provided configuration.
-
-  Args:
-    trainer_config: MaxText configuration for the trainer.
-    sampler_config: MaxText configuration for the sampler.
-    trainer_devices: JAX devices for the trainer.
-    sampler_devices: JAX devices for the sampler.
-  """
-  if not trainer_config.debug.rl:
-    # Apply filter to suppress noisy logs
-    noise_filter = max_logging.NoisyLogFilter()
-    logging.getLogger().addFilter(noise_filter)
-    absl_logging.get_absl_logger().addFilter(noise_filter)
-
-  max_logging.log("Starting RL Training")
-  max_logging.log(f"Ensuring TensorBoard log directory exists: {trainer_config.tensorboard_dir}")
-  if not epath.Path(trainer_config.tensorboard_dir).exists():
-    epath.Path(trainer_config.tensorboard_dir).mkdir(parents=True, exist_ok=True)
-
-  if not epath.Path(trainer_config.checkpoint_dir).exists():
-    epath.Path(trainer_config.checkpoint_dir).mkdir(parents=True)
-
-  # Number of training steps.
-  max_train_steps = int(
+def get_max_train_steps(trainer_config):
+  """Calculate the total number of training steps."""
+  return int(
       trainer_config.num_batches
       * trainer_config.rl.num_iterations
       * trainer_config.train_fraction
       * trainer_config.num_epoch
   )
-  # ====== Data ======
-  # Setup data directories
+
+
+def prepare_datasets(trainer_config, model_tokenizer):
+  """Setup and return train and test datasets."""
   home = os.path.expanduser("~") + "/"
   train_data_dir = f"{home}/data/train"
   test_data_dir = f"{home}/data/test"
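For intuition, the step count returned by the extracted get_max_train_steps is just the product of four config values, truncated to an int. A worked example with hypothetical values:

# Hypothetical config values, for illustration only.
num_batches, num_iterations, train_fraction, num_epoch = 100, 2, 0.9, 1
max_train_steps = int(num_batches * num_iterations * train_fraction * num_epoch)
assert max_train_steps == 180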
@@ -323,9 +302,6 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   if not os.path.exists(test_data_dir):
     os.makedirs(test_data_dir)

-  # Create model tokenizer
-  model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path)
-
   # Load datasets
   if trainer_config.dataset_name == "huggingface:nvidia/OpenMathInstruct-2":
     import datasets  # pylint: disable=import-outside-toplevel
@@ -334,7 +310,6 @@ def prepare_openinstructmath2_dataset(
         split: str = "train_1M",
         seed: int = 42,
         test_size: float = 0.05,
-        output_key: str = "expected_answer",
     ):
       """Load and split the OpenMathInstruct-2 dataset into train and validation sets using HF's train_test_split."""
       max_logging.log(
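The docstring above refers to Hugging Face's train_test_split. A self-contained sketch of that split using the defaults shown in this diff (split "train_1M", seed=42, test_size=0.05):

import datasets

# Load the dataset/split named in the config above, then split it.
ds = datasets.load_dataset("nvidia/OpenMathInstruct-2", split="train_1M")
splits = ds.train_test_split(test_size=0.05, seed=42)
train_ds, val_ds = splits["train"], splits["test"]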
@@ -419,41 +394,16 @@ def _filter_long_prompts(x):
   test_dataset = test_dataset[: trainer_config.num_test_batches * trainer_config.batch_size]

   test_dataset = test_dataset.to_iter_dataset().batch(trainer_config.batch_size)
+  return train_dataset, test_dataset

-  if trainer_config.debug.rl:
-    # Let's see how one batch of the dataset looks like!
-    if trainer_config.debug.rl:
-      for i, ele in enumerate(train_dataset):
-        if i >= 5:
-          break
-        pprint(ele)
-    if trainer_config.debug.rl:
-      for i, ele in enumerate(test_dataset):
-        if i >= 5:
-          break
-        pprint(ele)
-
-  # Load reference model
+
+
+def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sampler_devices):
+  """Create reference and actor models and their respective meshes."""
   max_logging.log("Creating reference model and also meshes for reference and rollout")
   reference_model, reference_mesh = get_maxtext_model(trainer_config, trainer_devices)
   devices_array = maxtext_utils.create_device_mesh(sampler_config, sampler_devices)
-  # if trainer_devices=sampler_devices, then rollout_mesh=reference_mesh
-  # else rollout_mesh uses sampler_devices
   rollout_mesh = Mesh(devices_array, sampler_config.mesh_axes)
-  if trainer_config.debug.rl:
-    max_logging.log("Reference Model initialized successfully")
-    nnx.display(reference_model)
-    max_logging.log(f"Reference mesh shape: {reference_mesh.shape}")

-  # Sanity check that weights are loaded correctly.
-  _maxtext_state_flatten = nnx.state(reference_model).flat_state()
-  maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}
-  max_logging.log(
-      f"maxtext_state_flatten[base.token_embedder.embedding].value=\
-      {maxtext_state_flatten['base.token_embedder.embedding'][...]}"
-  )
-
-  # TODO: @mazumdera: change this to use lora
   if trainer_config.load_checkpoint_only_once:
     max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.")
     with reference_mesh:
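The rollout_mesh built above is a plain jax.sharding.Mesh over the sampler devices. A standalone sketch; the (1, n) shape and axis names are assumptions for illustration, since MaxText derives both from sampler_config:

import numpy as np
import jax
from jax.sharding import Mesh

# Arrange the available devices into a 2-D array and name the mesh axes.
devices_array = np.array(jax.devices()).reshape(1, -1)
rollout_mesh = Mesh(devices_array, axis_names=("data", "model"))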
@@ -466,11 +416,22 @@ def _filter_long_prompts(x):
     max_logging.log("Creating policy model with same config as reference model on trainer mesh")
     actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)

-  if trainer_config.debug.rl:
-    max_logging.log("Policy Model initialized successfully")
-    nnx.display(actor_model)
-    max_logging.log(f"Policy mesh shape: {actor_mesh.shape}")
-
+  return reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh
+
+
+def create_rl_components(
+    trainer_config,
+    sampler_config,
+    sampler_devices,
+    actor_model,
+    actor_mesh,
+    reference_model,
+    reference_mesh,
+    rollout_mesh,
+    model_tokenizer,
+    max_train_steps,
+):
+  """Setup RL cluster, trainer, and optimizer."""
   # Setup optimizer
   optimizer = utils_rl.get_optimizer(trainer_config, max_train_steps)

@@ -483,7 +444,6 @@ def _filter_long_prompts(x):
   micro_batch_size = None if trainer_config.micro_batch_size == -1 else trainer_config.micro_batch_size

   # Setup metrics logging
-  max_logging.log(f"Tensorboard logs directory: {trainer_config.tensorboard_dir}")
   metrics_logging_options = metrics_logger.MetricsLoggerOptions(
       log_dir=trainer_config.tensorboard_dir, flush_every_n_steps=trainer_config.log_period
   )
@@ -501,25 +461,18 @@ def _filter_long_prompts(x):
   rollout_additional_config = None
   if trainer_config.vllm_additional_config:
     if isinstance(trainer_config.vllm_additional_config, dict):
-      # It's already parsed into a dict
       rollout_additional_config = trainer_config.vllm_additional_config
     elif isinstance(trainer_config.vllm_additional_config, str):
-      # It's a string, so we need to parse it
       try:
         rollout_additional_config = json.loads(trainer_config.vllm_additional_config)
       except json.JSONDecodeError as e:
         raise ValueError(f"Failed to parse additional_config JSON: {e}") from e

-  max_logging.log(f"Parsed additional config: {rollout_additional_config}")
-
   # We need to parse vLLM config to get the logical axis rules for the sampler config.
   vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
   argv_list = ["", str(vllm_config_path), "log_config=False"]
   vllm_config = pyconfig.initialize(argv_list)

-  # RL Cluster config
-  # Note that we use vLLM as the rollout engine.
-  # and we are using Tensor Parallelism for rollout
   cluster_config = rl_cluster_lib.ClusterConfig(
       role_to_mesh={
           rl_cluster_lib.Role.ACTOR: actor_mesh,
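The dict-or-JSON-string handling for vllm_additional_config above is self-contained enough to exercise in isolation. A sketch of the same logic as a helper; the function name and the sample key are hypothetical:

import json

def parse_additional_config(raw):
  """Accept a dict as-is, parse a JSON string, otherwise return None."""
  if isinstance(raw, dict):
    return raw
  if isinstance(raw, str):
    try:
      return json.loads(raw)
    except json.JSONDecodeError as e:
      raise ValueError(f"Failed to parse additional_config JSON: {e}") from e
  return None

assert parse_additional_config('{"swap_space": 2}') == {"swap_space": 2}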
@@ -537,15 +490,11 @@ def _filter_long_prompts(x):
           actor_optimizer=optimizer,
           eval_every_n_steps=trainer_config.eval_interval,
           max_steps=max_train_steps,
-          # Micro batching
           mini_batch_size=trainer_config.batch_size,
           train_micro_batch_size=micro_batch_size,
           rollout_micro_batch_size=micro_batch_size,
-          # Metrics logging
           metrics_logging_options=metrics_logging_options,
-          # Profiling
           profiler_options=profiler_options,
-          # Checkpoint saving
           checkpoint_root_directory=trainer_config.checkpoint_dir,
           checkpointing_options=checkpointing_options,
       ),
@@ -579,6 +528,7 @@ def _filter_long_prompts(x):
           **get_rollout_kwargs_for_parallelism(sampler_config, len(sampler_devices)),
       ),
   )
+
   grpo_config = GrpoConfig(
       num_generations=trainer_config.rl.num_generations,
       num_iterations=trainer_config.rl.num_iterations,
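For reference, a minimal sketch of constructing the same GrpoConfig with literal values in place of the trainer config; the import path and values are assumptions, only the two keyword names come from this diff:

from tunix.rl.grpo.grpo_learner import GrpoConfig  # import path is an assumption

grpo_config = GrpoConfig(
    num_generations=4,  # hypothetical value for trainer_config.rl.num_generations
    num_iterations=1,   # hypothetical value for trainer_config.rl.num_iterations
)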
@@ -595,9 +545,6 @@ def _filter_long_prompts(x):
     from tunix.perf import export as perf_export  # pylint: disable=import-outside-toplevel
     from tunix.perf import metrics as perf_metrics  # pylint: disable=import-outside-toplevel

-    max_logging.log(
-        "enable_tunix_perf_metrics is True and tunix.perf modules are available, enabling Tunix-managed metrics."
-    )
    perf_config = perf_metrics.PerfMetricsConfig()
    perf_config.custom_export_fn = perf_export.PerfMetricsExport.create_metrics_export_fn(cluster_config)
    rl_cluster_kwargs["perf_config"] = perf_config
@@ -627,9 +574,76 @@ def _filter_long_prompts(x):
       algo_config=grpo_config,
   )

+  return rl_cluster, rl_trainer, optimizer
+
+
+def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
+  """
+  Run RL training with the provided configuration.
+
+  Args:
+    trainer_config: MaxText configuration for the trainer.
+    sampler_config: MaxText configuration for the sampler.
+    trainer_devices: JAX devices for the trainer.
+    sampler_devices: JAX devices for the sampler.
+  """
+  if not trainer_config.debug.rl:
+    # Apply filter to suppress noisy logs
+    noise_filter = max_logging.NoisyLogFilter()
+    logging.getLogger().addFilter(noise_filter)
+    absl_logging.get_absl_logger().addFilter(noise_filter)
+
+  max_logging.log("Starting RL Training")
+  if not epath.Path(trainer_config.tensorboard_dir).exists():
+    epath.Path(trainer_config.tensorboard_dir).mkdir(parents=True, exist_ok=True)
+
+  if not epath.Path(trainer_config.checkpoint_dir).exists():
+    epath.Path(trainer_config.checkpoint_dir).mkdir(parents=True)
+
+  max_train_steps = get_max_train_steps(trainer_config)
+
+  # Create model tokenizer
+  model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path)
+
+  train_dataset, test_dataset = prepare_datasets(trainer_config, model_tokenizer)
+
+  if trainer_config.debug.rl:
+    for i, ele in enumerate(train_dataset):
+      if i >= 5:
+        break
+      pprint(ele)
+    for i, ele in enumerate(test_dataset):
+      if i >= 5:
+        break
+      pprint(ele)
+
+  reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = create_models_and_meshes(
+      trainer_config, sampler_config, trainer_devices, sampler_devices
+  )
+
+  if trainer_config.debug.rl:
+    max_logging.log("Reference Model initialized successfully")
+    nnx.display(reference_model)
+    max_logging.log(f"Reference mesh shape: {reference_mesh.shape}")
+    max_logging.log("Policy Model initialized successfully")
+    nnx.display(actor_model)
+    max_logging.log(f"Policy mesh shape: {actor_mesh.shape}")
+
+  rl_cluster, rl_trainer, _ = create_rl_components(
+      trainer_config,
+      sampler_config,
+      sampler_devices,
+      actor_model,
+      actor_mesh,
+      reference_model,
+      reference_mesh,
+      rollout_mesh,
+      model_tokenizer,
+      max_train_steps,
+  )
+
   # Before we train the model, let's evaluate the model on the test set so we can
   # see the improvement post training.
-  #
   (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
       trainer_config,
       test_dataset,
@@ -638,11 +652,9 @@ def _filter_long_prompts(x):
       corr_lst=trainer_config.eval_corr_lst,
       make_lst=trainer_config.eval_make_lst,
   )
-  # TODO: @mazumdera: Change this to max_logging.log once b/473703277 is resolved
   max_logging.warning(f"Pre RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")

   # Start training
-
   if trainer_config.load_checkpoint_only_once:
     max_logging.log("Capturing reference model state before training.")
     ref_state_before = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
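The commit message mentions a GRPO integration test, but that third file is not shown in this view. A hypothetical sketch of how a test could drive the refactored rl_train entry point; the device split and the fixture parameters are illustrative, only the module path and signature come from this diff:

import jax

from maxtext.trainers.post_train.rl.train_rl import rl_train

def test_grpo_integration(trainer_config, sampler_config):
  # Illustrative split: first half of devices for training, second half for sampling.
  devices = jax.devices()
  half = len(devices) // 2
  rl_train(trainer_config, sampler_config, devices[:half], devices[half:])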
