Commit 01fbe6d

Merge pull request #3409 from AI-Hypercomputer:nicogrande/rl-integration-test
PiperOrigin-RevId: 884647443
2 parents 93d1b83 + c45b59c

3 files changed

Lines changed: 319 additions & 84 deletions


src/maxtext/configs/inference/vllm.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -14,6 +14,8 @@
 
 base_config: "base.yml"
 attention: "vllm_rpa"
+model_call_mode: "inference"
+
 # NNX required for vLLM integration
 enable_nnx: True
 # Avoid re-initializing JAX distributed system when using vLLM
```
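For context, `train_rl.py` loads this file at runtime to recover the sampler's logical axis rules (see the hunks below). A minimal sketch of that load, assuming a MaxText checkout where `MAXTEXT_CONFIGS_DIR` resolves to `src/maxtext/configs`; the import paths here are assumptions, not from this commit:

```python
import os

from maxtext import pyconfig  # assumed import path for MaxText's config loader
from maxtext.globals import MAXTEXT_CONFIGS_DIR  # assumed home of this constant

# Mirrors the load in train_rl.py: argv[0] is ignored, argv[1] is the config
# file, and any further entries are key=value overrides.
vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
vllm_config = pyconfig.initialize(["", str(vllm_config_path), "log_config=False"])

print(vllm_config.model_call_mode)  # "inference", per the line added above
```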

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 96 additions & 84 deletions
```diff
@@ -283,39 +283,18 @@ def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices):
   return rollout_kwargs
 
 
-def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
-  """
-  Run RL training with the provided configuration.
-
-  Args:
-    trainer_config: MaxText configuration for the trainer.
-    sampler_config: MaxText configuration for the sampler.
-    trainer_devices: JAX devices for the trainer.
-    sampler_devices: JAX devices for the sampler.
-  """
-  if not trainer_config.debug.rl:
-    # Apply filter to suppress noisy logs
-    noise_filter = max_logging.NoisyLogFilter()
-    logging.getLogger().addFilter(noise_filter)
-    absl_logging.get_absl_logger().addFilter(noise_filter)
-
-  max_logging.log("Starting RL Training")
-  max_logging.log(f"Ensuring TensorBoard log directory exists: {trainer_config.tensorboard_dir}")
-  if not epath.Path(trainer_config.tensorboard_dir).exists():
-    epath.Path(trainer_config.tensorboard_dir).mkdir(parents=True, exist_ok=True)
-
-  if not epath.Path(trainer_config.checkpoint_dir).exists():
-    epath.Path(trainer_config.checkpoint_dir).mkdir(parents=True)
-
-  # Number of training steps.
-  max_train_steps = int(
+def get_max_train_steps(trainer_config):
+  """Calculate the total number of training steps."""
+  return int(
       trainer_config.num_batches
       * trainer_config.rl.num_iterations
       * trainer_config.train_fraction
       * trainer_config.num_epoch
   )
-  # ====== Data ======
-  # Setup data directories
+
+
+def prepare_datasets(trainer_config, model_tokenizer):
+  """Setup and return train and test datasets."""
   home = os.path.expanduser("~") + "/"
   train_data_dir = f"{home}/data/train"
   test_data_dir = f"{home}/data/test"
```
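As a quick sanity check of the extracted `get_max_train_steps` helper, here is the same arithmetic with illustrative values; the numbers are ours, not from the commit:

```python
# Hypothetical config values, chosen only to illustrate the formula.
num_batches = 100      # trainer_config.num_batches
num_iterations = 2     # trainer_config.rl.num_iterations
train_fraction = 0.9   # trainer_config.train_fraction
num_epoch = 1          # trainer_config.num_epoch

max_train_steps = int(num_batches * num_iterations * train_fraction * num_epoch)
print(max_train_steps)  # 180
```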
```diff
@@ -324,9 +303,6 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   if not os.path.exists(test_data_dir):
     os.makedirs(test_data_dir)
 
-  # Create model tokenizer
-  model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path)
-
   # Load datasets
   if trainer_config.dataset_name == "huggingface:nvidia/OpenMathInstruct-2":
     import datasets  # pylint: disable=import-outside-toplevel
```
```diff
@@ -335,7 +311,6 @@ def prepare_openinstructmath2_dataset(
         split: str = "train_1M",
         seed: int = 42,
         test_size: float = 0.05,
-        output_key: str = "expected_answer",
     ):
       """Load and split the OpenMathInstruct-2 dataset into train and validation sets using HF's train_test_split."""
       max_logging.log(
```
```diff
@@ -420,41 +395,16 @@ def _filter_long_prompts(x):
   test_dataset = test_dataset[: trainer_config.num_test_batches * trainer_config.batch_size]
 
   test_dataset = test_dataset.to_iter_dataset().batch(trainer_config.batch_size)
+  return train_dataset, test_dataset
 
-  if trainer_config.debug.rl:
-    # Let's see how one batch of the dataset looks like!
-    if trainer_config.debug.rl:
-      for i, ele in enumerate(train_dataset):
-        if i >= 5:
-          break
-        pprint(ele)
-    if trainer_config.debug.rl:
-      for i, ele in enumerate(test_dataset):
-        if i >= 5:
-          break
-        pprint(ele)
-
-  # Load reference model
+
+def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sampler_devices):
+  """Create reference and actor models and their respective meshes."""
   max_logging.log("Creating reference model and also meshes for reference and rollout")
   reference_model, reference_mesh = get_maxtext_model(trainer_config, trainer_devices)
   devices_array = maxtext_utils.create_device_mesh(sampler_config, sampler_devices)
-  # if trainer_devices=sampler_devices, then rollout_mesh=reference_mesh
-  # else rollout_mesh uses sampler_devices
   rollout_mesh = Mesh(devices_array, sampler_config.mesh_axes)
-  if trainer_config.debug.rl:
-    max_logging.log("Reference Model initialized successfully")
-    nnx.display(reference_model)
-    max_logging.log(f"Reference mesh shape: {reference_mesh.shape}")
 
-  # Sanity check that weights are loaded correctly.
-  _maxtext_state_flatten = nnx.state(reference_model).flat_state()
-  maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}
-  max_logging.log(
-      f"maxtext_state_flatten[base.token_embedder.embedding].value=\
-      {maxtext_state_flatten['base.token_embedder.embedding'][...]}"
-  )
-
-  # TODO: @mazumdera: change this to use lora
   if trainer_config.load_checkpoint_only_once:
     max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.")
     with reference_mesh:
```
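The `rollout_mesh` built above is a plain `jax.sharding.Mesh` over a device array. A standalone sketch of the same construction, assuming a single-host runtime and a hypothetical one-axis `("data",)` layout in place of `sampler_config.mesh_axes`:

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# MaxText builds devices_array via maxtext_utils.create_device_mesh(...);
# mesh_utils.create_device_mesh is the stock JAX equivalent for the simple case.
devices_array = mesh_utils.create_device_mesh((jax.device_count(),))
rollout_mesh = Mesh(devices_array, axis_names=("data",))
print(rollout_mesh.shape)  # e.g. {'data': 8} on an 8-device host
```

When `trainer_devices == sampler_devices` the rollout mesh covers the same hardware as the reference mesh; otherwise it spans the sampler devices only, which is what the deleted inline comment used to note.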
```diff
@@ -467,11 +417,22 @@ def _filter_long_prompts(x):
     max_logging.log("Creating policy model with same config as reference model on trainer mesh")
     actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)
 
-  if trainer_config.debug.rl:
-    max_logging.log("Policy Model initialized successfully")
-    nnx.display(actor_model)
-    max_logging.log(f"Policy mesh shape: {actor_mesh.shape}")
-
+  return reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh
+
+
+def create_rl_components(
+    trainer_config,
+    sampler_config,
+    sampler_devices,
+    actor_model,
+    actor_mesh,
+    reference_model,
+    reference_mesh,
+    rollout_mesh,
+    model_tokenizer,
+    max_train_steps,
+):
+  """Setup RL cluster, trainer, and optimizer."""
   # Setup optimizer
   optimizer = utils_rl.get_optimizer(trainer_config, max_train_steps)
 
```
```diff
@@ -484,7 +445,6 @@ def _filter_long_prompts(x):
   micro_batch_size = None if trainer_config.micro_batch_size == -1 else trainer_config.micro_batch_size
 
   # Setup metrics logging
-  max_logging.log(f"Tensorboard logs directory: {trainer_config.tensorboard_dir}")
   metrics_logging_options = metrics_logger.MetricsLoggerOptions(
       log_dir=trainer_config.tensorboard_dir, flush_every_n_steps=trainer_config.log_period
   )
```
```diff
@@ -502,25 +462,18 @@ def _filter_long_prompts(x):
   rollout_additional_config = None
   if trainer_config.vllm_additional_config:
     if isinstance(trainer_config.vllm_additional_config, dict):
-      # It's already parsed into a dict
       rollout_additional_config = trainer_config.vllm_additional_config
     elif isinstance(trainer_config.vllm_additional_config, str):
-      # It's a string, so we need to parse it
       try:
         rollout_additional_config = json.loads(trainer_config.vllm_additional_config)
       except json.JSONDecodeError as e:
         raise ValueError(f"Failed to parse additional_config JSON: {e}") from e
 
-  max_logging.log(f"Parsed additional config: {rollout_additional_config}")
-
   # We need to parse vLLM config to get the logical axis rules for the sampler config.
   vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
   argv_list = ["", str(vllm_config_path), "log_config=False"]
   vllm_config = pyconfig.initialize(argv_list)
 
-  # RL Cluster config
-  # Note that we use vLLM as the rollout engine.
-  # and we are using Tensor Parallelism for rollout
   cluster_config = rl_cluster_lib.ClusterConfig(
       role_to_mesh={
           rl_cluster_lib.Role.ACTOR: actor_mesh,
```
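The dict-or-string handling kept in this hunk is self-contained enough to restate as a helper; a minimal sketch, where the function name and example payload are ours:

```python
import json


def parse_additional_config(value):
  """Accept a pre-parsed dict or a JSON string, mirroring the kept branches."""
  if isinstance(value, dict):
    return value
  if isinstance(value, str):
    try:
      return json.loads(value)
    except json.JSONDecodeError as e:
      raise ValueError(f"Failed to parse additional_config JSON: {e}") from e
  return None


assert parse_additional_config('{"enable_prefix_caching": true}') == {"enable_prefix_caching": True}
assert parse_additional_config({"k": 1}) == {"k": 1}
```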
```diff
@@ -538,15 +491,11 @@ def _filter_long_prompts(x):
           actor_optimizer=optimizer,
           eval_every_n_steps=trainer_config.eval_interval,
           max_steps=max_train_steps,
-          # Micro batching
           mini_batch_size=trainer_config.batch_size,
           train_micro_batch_size=micro_batch_size,
           rollout_micro_batch_size=micro_batch_size,
-          # Metrics logging
           metrics_logging_options=metrics_logging_options,
-          # Profiling
           profiler_options=profiler_options,
-          # Checkpoint saving
           checkpoint_root_directory=trainer_config.checkpoint_dir,
           checkpointing_options=checkpointing_options,
       ),
```
```diff
@@ -580,6 +529,7 @@ def _filter_long_prompts(x):
           **get_rollout_kwargs_for_parallelism(sampler_config, len(sampler_devices)),
       ),
   )
+
   grpo_config = GrpoConfig(
       num_generations=trainer_config.rl.num_generations,
       num_iterations=trainer_config.rl.num_iterations,
```
```diff
@@ -596,9 +546,6 @@ def _filter_long_prompts(x):
     from tunix.perf import export as perf_export  # pylint: disable=import-outside-toplevel
     from tunix.perf import metrics as perf_metrics  # pylint: disable=import-outside-toplevel
 
-    max_logging.log(
-        "enable_tunix_perf_metrics is True and tunix.perf modules are available, enabling Tunix-managed metrics."
-    )
     perf_config = perf_metrics.PerfMetricsConfig()
     perf_config.custom_export_fn = perf_export.PerfMetricsExport.create_metrics_export_fn(cluster_config)
     rl_cluster_kwargs["perf_config"] = perf_config
```
```diff
@@ -628,9 +575,76 @@ def _filter_long_prompts(x):
       algo_config=grpo_config,
   )
 
+  return rl_cluster, rl_trainer, optimizer
+
+
+def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
+  """
+  Run RL training with the provided configuration.
+
+  Args:
+    trainer_config: MaxText configuration for the trainer.
+    sampler_config: MaxText configuration for the sampler.
+    trainer_devices: JAX devices for the trainer.
+    sampler_devices: JAX devices for the sampler.
+  """
+  if not trainer_config.debug.rl:
+    # Apply filter to suppress noisy logs
+    noise_filter = max_logging.NoisyLogFilter()
+    logging.getLogger().addFilter(noise_filter)
+    absl_logging.get_absl_logger().addFilter(noise_filter)
+
+  max_logging.log("Starting RL Training")
+  if not epath.Path(trainer_config.tensorboard_dir).exists():
+    epath.Path(trainer_config.tensorboard_dir).mkdir(parents=True, exist_ok=True)
+
+  if not epath.Path(trainer_config.checkpoint_dir).exists():
+    epath.Path(trainer_config.checkpoint_dir).mkdir(parents=True)
+
+  max_train_steps = get_max_train_steps(trainer_config)
+
+  # Create model tokenizer
+  model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path)
+
+  train_dataset, test_dataset = prepare_datasets(trainer_config, model_tokenizer)
+
+  if trainer_config.debug.rl:
+    for i, ele in enumerate(train_dataset):
+      if i >= 5:
+        break
+      pprint(ele)
+    for i, ele in enumerate(test_dataset):
+      if i >= 5:
+        break
+      pprint(ele)
+
+  reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = create_models_and_meshes(
+      trainer_config, sampler_config, trainer_devices, sampler_devices
+  )
+
+  if trainer_config.debug.rl:
+    max_logging.log("Reference Model initialized successfully")
+    nnx.display(reference_model)
+    max_logging.log(f"Reference mesh shape: {reference_mesh.shape}")
+    max_logging.log("Policy Model initialized successfully")
+    nnx.display(actor_model)
+    max_logging.log(f"Policy mesh shape: {actor_mesh.shape}")
+
+  rl_cluster, rl_trainer, _ = create_rl_components(
+      trainer_config,
+      sampler_config,
+      sampler_devices,
+      actor_model,
+      actor_mesh,
+      reference_model,
+      reference_mesh,
+      rollout_mesh,
+      model_tokenizer,
+      max_train_steps,
+  )
+
   # Before we train the model, let's evaluate the model on the test set so we can
   # see the improvement post training.
-  #
   (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate(
       trainer_config,
       test_dataset,
```
```diff
@@ -639,11 +653,9 @@ def _filter_long_prompts(x):
       corr_lst=trainer_config.eval_corr_lst,
       make_lst=trainer_config.eval_make_lst,
   )
-  # TODO: @mazumdera: Change this to max_logging.log once b/473703277 is resolved
   max_logging.warning(f"Pre RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")
 
   # Start training
-
   if trainer_config.load_checkpoint_only_once:
     max_logging.log("Capturing reference model state before training.")
     ref_state_before = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
```
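Taken together, the refactor reduces `rl_train` to a thin orchestrator over `get_max_train_steps`, `prepare_datasets`, `create_models_and_meshes`, and `create_rl_components`. A hypothetical driver sketch; the device split and config choices below are ours, not the commit's:

```python
import jax

from maxtext import pyconfig  # assumed import path
from maxtext.trainers.post_train.rl.train_rl import rl_train  # module path per this PR's file list

# Hypothetical split: first half of the local devices trains, second half samples.
devices = jax.devices()
half = max(1, len(devices) // 2)
trainer_devices, sampler_devices = devices[:half], devices[half:] or devices

# Hypothetical configs; a real run would pass a full RL config plus overrides.
trainer_config = pyconfig.initialize(["", "src/maxtext/configs/base.yml", "log_config=False"])
sampler_config = pyconfig.initialize(["", "src/maxtext/configs/inference/vllm.yml", "log_config=False"])

rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)
```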
