Commit 2ade63f

Merge pull request #3215 from AI-Hypercomputer:anisha-rl-refactor-fix
PiperOrigin-RevId: 874181757
2 parents 3737eb3 + f2e762f commit 2ade63f

1 file changed: src/maxtext/trainers/post_train/rl/train_rl.py (43 additions & 19 deletions)
@@ -85,9 +85,9 @@ def get_maxtext_model(config, devices=None):
   """
   Load MaxText model with Tunix adapter.
   # Note: pass the path to your scanned checkpoint for 'load_parameters_path'.
-  # To create a scanned checkpoint, you can use /maxtext/src/MaxText/utils/ckpt_conversion/to_maxtext.py and if
+  # To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if
   # using Pathways, please set `checkpoint_storage_use_ocdbt=False checkpoint_storage_use_zarr3=False`
-  # python src/MaxText/utils/ckpt_conversion/to_maxtext.py \
+  # python src/MaxText/checkpoint_conversion/to_maxtext.py \
   # --model_name="gemma2-2b" \
   # --base_output_directory="/path/to/your/output/directory" \
   # --scan_layers=True \
@@ -304,20 +304,25 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path)

   # Load datasets
-  dataset = get_dataset(
+  train_dataset = get_dataset(
       model_tokenizer,
       trainer_config,
       train_data_dir,
       trainer_config.train_split,
       data_files=trainer_config.hf_train_files,
       dataset_name=trainer_config.dataset_name,
-  ).batch(trainer_config.batch_size)[: trainer_config.num_batches]
+  )

-  if trainer_config.train_fraction == 1.0:
-    train_dataset = dataset.repeat(trainer_config.num_epoch)
-  else:
-    train_dataset = dataset[: int(len(dataset) * trainer_config.train_fraction)]
-    train_dataset = train_dataset.repeat(trainer_config.num_epoch)
+  def _filter_long_prompts(x):
+    tokens = model_tokenizer.tokenize(x["prompts"])
+    return len(tokens) <= trainer_config.max_prefill_predict_length
+
+  train_dataset = train_dataset.filter(_filter_long_prompts)
+  dataset_size = int(trainer_config.num_batches * trainer_config.batch_size * trainer_config.train_fraction)
+  train_dataset = train_dataset[:dataset_size]
+  train_dataset = train_dataset.repeat(trainer_config.num_epoch)
+
+  train_dataset = train_dataset.to_iter_dataset().batch(trainer_config.batch_size)

   eval_dataset_name = getattr(trainer_config, "eval_dataset_name", None)
   if not eval_dataset_name:
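The new pipeline applies the prompt-length filter first, then takes a deterministic slice of num_batches * batch_size * train_fraction examples, repeats for num_epoch passes, and batches only at the end. Below is a minimal runnable sketch of that filter -> slice -> repeat -> batch ordering, assuming the dataset is a grain MapDataset (which the slicing and to_iter_dataset() calls suggest); the data, the whitespace "tokenizer", and the length limit are illustrative stand-ins, not MaxText code.

import grain

examples = [{"prompts": "short"}, {"prompts": "a rather long prompt indeed"}]
ds = grain.MapDataset.source(examples)

max_len = 3  # stand-in for trainer_config.max_prefill_predict_length

ds = ds.filter(lambda x: len(x["prompts"].split()) <= max_len)  # drop long prompts first
ds = ds[:1]                          # deterministic slice, like train_dataset[:dataset_size]
ds = ds.repeat(2)                    # one pass per epoch
ds = ds.to_iter_dataset().batch(1)   # convert to an iterator dataset, then batch last

for batch in ds:
  print(batch)  # two batches, each holding the single short prompt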
@@ -330,12 +335,12 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
       trainer_config.eval_split,
       data_files=trainer_config.hf_eval_files,
       dataset_name=eval_dataset_name,
-  ).batch(trainer_config.batch_size)[: trainer_config.num_test_batches]
+  )

-  # Let's see how one batch of the dataset looks like!
-  if trainer_config.debug.rl:
-    for ele in train_dataset[:1]:
-      pprint(ele)
+  test_dataset = test_dataset.filter(_filter_long_prompts)
+  test_dataset = test_dataset[: trainer_config.num_test_batches * trainer_config.batch_size]
+
+  test_dataset = test_dataset.to_iter_dataset().batch(trainer_config.batch_size)

   # Load reference model
   max_logging.log("Creating reference model and also meshes for reference and rollout")
@@ -358,10 +363,17 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   )

   # TODO: @mazumdera: change this to use lora
-  # TODO: @xfgu: instead of restoring a second time from GCS, can we just copy reference_model
-  # Load policy model
-  max_logging.log("Creating policy model with same config as reference model on trainer mesh")
-  actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)
+  if trainer_config.load_checkpoint_only_once:
+    max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.")
+    with reference_mesh:
+      actor_base_model = nnx.clone(reference_model.base)
+    use_no_op_mappings = "maxtext_config" in trainer_config.vllm_additional_config
+    actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings)
+    actor_model.config = None
+    actor_mesh = reference_mesh
+  else:
+    max_logging.log("Creating policy model with same config as reference model on trainer mesh")
+    actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)

   if trainer_config.debug.rl:
     max_logging.log("Policy Model initialized successfully")
@@ -487,7 +499,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
         "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
     )

-  vllm_config_path = epath.Path(MAXTEXT_CONFIGS_DIR) / "inference/vllm.yml"
+  vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
   argv_list = ["", str(vllm_config_path), "log_config=False"]
   vllm_config = pyconfig.initialize(argv_list)

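The vLLM config path is now assembled with the stdlib os.path.join rather than epath.Path's / operator; for a plain local path both produce the same string on POSIX systems. A tiny illustration (the directory value is made up):

import os

MAXTEXT_CONFIGS_DIR = "/opt/maxtext/src/maxtext/configs"  # illustrative value
vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
print(vllm_config_path)  # /opt/maxtext/src/maxtext/configs/inference/vllm.yml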
@@ -529,11 +541,23 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):

   # Start training

+  if trainer_config.load_checkpoint_only_once:
+    max_logging.log("Capturing reference model state before training.")
+    ref_state_before = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
+
   max_logging.warning("Starting RL training...")

   with reference_mesh, nn_partitioning.axis_rules(trainer_config.logical_axis_rules):
     rl_trainer.train(train_dataset)

+  if trainer_config.load_checkpoint_only_once:
+    max_logging.log("Checking if reference model state changed during training.")
+    ref_state_after = nnx.to_pure_dict(nnx.state(reference_model.base, nnx.Param))
+    check = jax.tree_util.tree_map(jax.numpy.array_equal, ref_state_before, ref_state_after)
+    if not jax.tree_util.tree_all(check):
+      raise ValueError("Reference model parameters changed during training!")
+    max_logging.log("Reference model parameters verified to be unchanged during training.")
+
   max_logging.warning("RL Training Completed Successfully!")

   # Let's evaluate our model!
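The before/after snapshot guards the clone-based setup: if the policy and reference models were accidentally sharing parameter state, training the policy would mutate the reference weights and the check would raise. A self-contained sketch of the same snapshot-and-compare pattern, again using a toy module in place of the MaxText reference model:

import jax
import jax.numpy as jnp
from flax import nnx

class Tiny(nnx.Module):
  def __init__(self):
    self.w = nnx.Param(jnp.ones((2, 2)))

model = Tiny()
state_before = nnx.to_pure_dict(nnx.state(model, nnx.Param))  # snapshot params as a pytree

# ... training of a *different* (cloned) model would happen here ...

state_after = nnx.to_pure_dict(nnx.state(model, nnx.Param))
check = jax.tree_util.tree_map(jnp.array_equal, state_before, state_after)
if not jax.tree_util.tree_all(check):
  raise ValueError("Parameters changed!")
print("Parameters verified unchanged.")

Because JAX arrays are immutable, the "snapshot" can safely hold references to the pre-training arrays: any update replaces a Param's array rather than mutating it in place, so the comparison is meaningful.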
