@@ -304,14 +304,82 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
304304 model_tokenizer = AutoTokenizer .from_pretrained (trainer_config .tokenizer_path )
305305
306306 # Load datasets
307- train_dataset = get_dataset (
308- model_tokenizer ,
309- trainer_config ,
310- train_data_dir ,
311- trainer_config .train_split ,
312- data_files = trainer_config .hf_train_files ,
313- dataset_name = trainer_config .dataset_name ,
314- )
# Load train/eval datasets. OpenMathInstruct-2 ships a single parquet split,
# so it needs a bespoke train/validation carve-out; every other dataset goes
# through the generic get_dataset() path.
if trainer_config.dataset_name == "huggingface:nvidia/OpenMathInstruct-2":
  import datasets  # pylint: disable=import-outside-toplevel

  def prepare_openinstructmath2_dataset(
      split: str = "train_1M",
      seed: int = 42,
      test_size: float = 0.05,
      output_key: str = "expected_answer",  # currently unused; kept for call-site compatibility
  ):
    """Load and split the OpenMathInstruct-2 dataset into train and validation sets using HF's train_test_split."""
    max_logging.log(
        "WARNING: For reproducible experiments, preprocess the dataset once and "
        "define your own HfDataset subclass that directly uses the preprocessed datasets."
    )

    # Load the original dataset. Key `data_files` by `split` (not by
    # trainer_config.train_split): the caller remaps train_split == "train"
    # to "train_1M", and a mismatched key would make `split=split` request a
    # split name that does not exist in `data_files`.
    original_ds = datasets.load_dataset(
        "parquet",
        data_files={split: trainer_config.hf_train_files},
        split=split,
        cache_dir=train_data_dir,
    )

    # Split into train and validation sets using HF's train_test_split.
    split_ds = original_ds.train_test_split(test_size=test_size, seed=seed)
    return {
        "train": split_ds["train"],
        "validation": split_ds["test"],
    }

  split_name = trainer_config.train_split if trainer_config.train_split != "train" else "train_1M"
  splits = prepare_openinstructmath2_dataset(split=split_name)
  template_config = load_template_from_file(trainer_config.chat_template_path)

  def _to_grain_dataset(hf_split):
    """Wrap one HF split in a shuffled grain MapDataset with RL preprocessing applied."""
    return (
        grain.MapDataset.source(hf_split)
        .shuffle(seed=trainer_config.data_shuffle_seed)
        .map(
            lambda x: utils_rl.process_data(
                trainer_config.dataset_name, model_tokenizer, template_config, trainer_config, x
            )
        )
    )

  # The train and validation pipelines were identical copy-pastes; build both
  # through the shared helper instead.
  train_dataset = _to_grain_dataset(splits["train"])
  test_dataset = _to_grain_dataset(splits["validation"])
else:
  train_dataset = get_dataset(
      model_tokenizer,
      trainer_config,
      train_data_dir,
      trainer_config.train_split,
      data_files=trainer_config.hf_train_files,
      dataset_name=trainer_config.dataset_name,
  )

  # Fall back to the train dataset name when no eval-specific name is configured.
  eval_dataset_name = getattr(trainer_config, "eval_dataset_name", None)
  if not eval_dataset_name:
    eval_dataset_name = trainer_config.dataset_name

  test_dataset = get_dataset(
      model_tokenizer,
      trainer_config,
      test_data_dir,
      trainer_config.eval_split,
      data_files=trainer_config.hf_eval_files,
      dataset_name=eval_dataset_name,
  )
315383
316384 def _filter_long_prompts (x ):
317385 tokens = model_tokenizer .tokenize (x ["prompts" ])
@@ -324,24 +392,24 @@ def _filter_long_prompts(x):
324392
# Convert the (already filtered) train pipeline into an iterable dataset and
# group it into training batches.
iterable_train = train_dataset.to_iter_dataset()
train_dataset = iterable_train.batch(trainer_config.batch_size)
326394
327- eval_dataset_name = getattr (trainer_config , "eval_dataset_name" , None )
328- if not eval_dataset_name :
329- eval_dataset_name = trainer_config .dataset_name
330-
331- test_dataset = get_dataset (
332- model_tokenizer ,
333- trainer_config ,
334- test_data_dir ,
335- trainer_config .eval_split ,
336- data_files = trainer_config .hf_eval_files ,
337- dataset_name = eval_dataset_name ,
338- )
339-
# Drop over-length prompts, cap the eval set to the configured number of
# batches, then batch it for evaluation.
eval_cap = trainer_config.num_test_batches * trainer_config.batch_size
filtered_eval = test_dataset.filter(_filter_long_prompts)
test_dataset = filtered_eval[:eval_cap]

test_dataset = test_dataset.to_iter_dataset().batch(trainer_config.batch_size)
344399
if trainer_config.debug.rl:
  # Debug aid: print the first few elements of each dataset so their structure
  # can be eyeballed before training starts. A single guard suffices — the
  # original re-tested trainer_config.debug.rl twice more inside this branch.
  for dataset in (train_dataset, test_dataset):
    for i, ele in enumerate(dataset):
      if i >= 5:
        break
      pprint(ele)
412+
345413 # Load reference model
346414 max_logging .log ("Creating reference model and also meshes for reference and rollout" )
347415 reference_model , reference_mesh = get_maxtext_model (trainer_config , trainer_devices )
@@ -499,7 +567,7 @@ def _filter_long_prompts(x):
499567 "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
500568 )
501569
502- vllm_config_path = os . path . join (MAXTEXT_CONFIGS_DIR , "inference" , " vllm.yml")
# Resolve the bundled vLLM inference config and initialize it with config
# logging disabled.
vllm_yaml = epath.Path(MAXTEXT_CONFIGS_DIR) / "inference" / "vllm.yml"
argv_list = ["", str(vllm_yaml), "log_config=False"]
vllm_config = pyconfig.initialize(argv_list)
505573
0 commit comments