6262from flax .linen import partitioning as nn_partitioning
6363from jax .sharding import Mesh
6464from orbax import checkpoint as ocp
65- from pprint import pprint
6665from transformers import AutoTokenizer
6766from tunix .rl import rl_cluster as rl_cluster_lib
6867from tunix .rl .rollout import base_rollout
@@ -304,20 +303,28 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
304303 model_tokenizer = AutoTokenizer .from_pretrained (trainer_config .tokenizer_path )
305304
306305 # Load datasets
307- dataset = get_dataset (
306+ train_dataset = get_dataset (
308307 model_tokenizer ,
309308 trainer_config ,
310309 train_data_dir ,
311310 trainer_config .train_split ,
312311 data_files = trainer_config .hf_train_files ,
313312 dataset_name = trainer_config .dataset_name ,
314- ). batch ( trainer_config . batch_size )[: trainer_config . num_batches ]
313+ )
315314
316- if trainer_config .train_fraction == 1.0 :
317- train_dataset = dataset .repeat (trainer_config .num_epoch )
318- else :
319- train_dataset = dataset [: int (len (dataset ) * trainer_config .train_fraction )]
320- train_dataset = train_dataset .repeat (trainer_config .num_epoch )
315+ def _filter_long_prompts (x ):
316+ tokens = model_tokenizer .tokenize (x ["prompts" ])
317+ return len (tokens ) <= trainer_config .max_prefill_predict_length
318+
319+ train_dataset = train_dataset .filter (_filter_long_prompts )
320+ dataset_size = int (trainer_config .num_batches * trainer_config .batch_size * trainer_config .train_fraction )
321+ train_dataset = train_dataset [:dataset_size ]
322+ train_dataset = train_dataset .repeat (trainer_config .num_epoch )
323+
324+ train_dataset = (
325+ train_dataset .to_iter_dataset ()
326+ .batch (trainer_config .batch_size )
327+ )
321328
322329 eval_dataset_name = getattr (trainer_config , "eval_dataset_name" , None )
323330 if not eval_dataset_name :
@@ -330,12 +337,15 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
330337 trainer_config .eval_split ,
331338 data_files = trainer_config .hf_eval_files ,
332339 dataset_name = eval_dataset_name ,
333- ). batch ( trainer_config . batch_size )[: trainer_config . num_test_batches ]
340+ )
334341
335- # Let's see how one batch of the dataset looks like!
336- if trainer_config .debug .rl :
337- for ele in train_dataset [:1 ]:
338- pprint (ele )
342+ test_dataset = test_dataset .filter (_filter_long_prompts )
343+ test_dataset = test_dataset [: trainer_config .num_test_batches * trainer_config .batch_size ]
344+
345+ test_dataset = (
346+ test_dataset .to_iter_dataset ()
347+ .batch (trainer_config .batch_size )
348+ )
339349
340350 # Load reference model
341351 max_logging .log ("Creating reference model and also meshes for reference and rollout" )
0 commit comments