Commit 578c777

Add prompt length filtering

1 parent 496ed40 commit 578c777
1 file changed: 23 additions & 13 deletions

src/MaxText/rl/train_rl.py

@@ -62,7 +62,6 @@
 from flax.linen import partitioning as nn_partitioning
 from jax.sharding import Mesh
 from orbax import checkpoint as ocp
-from pprint import pprint
 from transformers import AutoTokenizer
 from tunix.rl import rl_cluster as rl_cluster_lib
 from tunix.rl.rollout import base_rollout
@@ -304,20 +303,28 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path)
 
   # Load datasets
-  dataset = get_dataset(
+  train_dataset = get_dataset(
       model_tokenizer,
       trainer_config,
       train_data_dir,
       trainer_config.train_split,
       data_files=trainer_config.hf_train_files,
       dataset_name=trainer_config.dataset_name,
-  ).batch(trainer_config.batch_size)[: trainer_config.num_batches]
+  )
 
-  if trainer_config.train_fraction == 1.0:
-    train_dataset = dataset.repeat(trainer_config.num_epoch)
-  else:
-    train_dataset = dataset[: int(len(dataset) * trainer_config.train_fraction)]
-    train_dataset = train_dataset.repeat(trainer_config.num_epoch)
+  def _filter_long_prompts(x):
+    tokens = model_tokenizer.tokenize(x["prompts"])
+    return len(tokens) <= trainer_config.max_prefill_predict_length
+
+  train_dataset = train_dataset.filter(_filter_long_prompts)
+  dataset_size = int(trainer_config.num_batches * trainer_config.batch_size * trainer_config.train_fraction)
+  train_dataset = train_dataset[:dataset_size]
+  train_dataset = train_dataset.repeat(trainer_config.num_epoch)
+
+  train_dataset = (
+      train_dataset.to_iter_dataset()
+      .batch(trainer_config.batch_size)
+  )
 
   eval_dataset_name = getattr(trainer_config, "eval_dataset_name", None)
   if not eval_dataset_name:
@@ -330,12 +337,15 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
       trainer_config.eval_split,
       data_files=trainer_config.hf_eval_files,
       dataset_name=eval_dataset_name,
-  ).batch(trainer_config.batch_size)[: trainer_config.num_test_batches]
+  )
 
-  # Let's see how one batch of the dataset looks like!
-  if trainer_config.debug.rl:
-    for ele in train_dataset[:1]:
-      pprint(ele)
+  test_dataset = test_dataset.filter(_filter_long_prompts)
+  test_dataset = test_dataset[: trainer_config.num_test_batches * trainer_config.batch_size]
+
+  test_dataset = (
+      test_dataset.to_iter_dataset()
+      .batch(trainer_config.batch_size)
+  )
 
   # Load reference model
   max_logging.log("Creating reference model and also meshes for reference and rollout")
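The heart of the change is the new _filter_long_prompts predicate: any example whose tokenized prompt exceeds trainer_config.max_prefill_predict_length is dropped from both the train and eval datasets before batching. Below is a minimal, self-contained sketch of that predicate, assuming a HuggingFace tokenizer; the model name and the length cap are illustrative stand-ins, not values from the MaxText config.

# Sketch of the prompt-length predicate. "gpt2" and the cap of 16 are
# assumptions standing in for trainer_config.tokenizer_path and
# trainer_config.max_prefill_predict_length.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
max_prefill_predict_length = 16  # assumed cap

def filter_long_prompts(example):
  # Keep only examples whose tokenized prompt fits the prefill budget.
  return len(tokenizer.tokenize(example["prompts"])) <= max_prefill_predict_length

examples = [
    {"prompts": "What is 2 + 2?"},
    {"prompts": "Summarize the following transcript in detail. " * 20},
]
print([filter_long_prompts(e) for e in examples])  # -> [True, False]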

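The pipeline shape also changes: the size cap is now expressed in elements (num_batches * batch_size * train_fraction) rather than in batches, and batching moves to the end, after filtering and slicing, via to_iter_dataset().batch(...). Here is a toy sketch of the same ordering, assuming Grain's MapDataset API (which the filter/slice/to_iter_dataset calls in the diff suggest); all names and sizes are illustrative.

# Toy sketch of the reshaped pipeline on integer records, assuming the
# dataset is a Grain MapDataset. The even-number filter stands in for
# _filter_long_prompts.
import grain.python as grain

batch_size, num_batches, train_fraction, num_epoch = 2, 3, 1.0, 2

ds = grain.MapDataset.source(list(range(20)))
ds = ds.filter(lambda x: x % 2 == 0)          # drop "too-long" records
dataset_size = int(num_batches * batch_size * train_fraction)
ds = ds[:dataset_size]                        # cap in elements, not batches
ds = ds.repeat(num_epoch)
ds = ds.to_iter_dataset().batch(batch_size)   # batch last, after filtering

for batch in ds:
  print(batch)

One consequence worth noting: because the filter runs before the slice, the realized number of examples can come in under dataset_size whenever prompts are dropped.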