Skip to content

Commit 023e416

Browse files
committed
Add option to start test_batch in train_rl from a specific index, also add default tokenizer_path for default model
1 parent ca7e2df commit 023e416

4 files changed

Lines changed: 7 additions & 1 deletion

File tree

src/maxtext/configs/post_train/rl.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ micro_batch_size: -1
100100
# Keep `num_test_batches` low so that evaluation runs quickly. It can be
101101
# increased to a max. of 330 (if batch size is 4).
102102
num_test_batches: 5 # 200
103+
test_batch_start_index: 0
103104
train_fraction: 1.0
104105

105106
eval_interval: 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`.

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1636,6 +1636,7 @@ class RLDataset(BaseModel):
16361636
batch_size: int = Field(1, description="Global batch size for the dataset loader in RL.")
16371637
num_batches: int = Field(4, description="Number of batches for RL training.")
16381638
num_test_batches: int = Field(5, description="Number of batches for RL evaluation.")
1639+
test_batch_start_index: int = Field(0, description="Start index for the test dataset")
16391640
train_fraction: float = Field(1.0, description="Fraction of the dataset to be used for training.")
16401641
micro_batch_size: int = Field(-1, description="Micro batch size for rollout and training.")
16411642

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,9 @@ def _filter_long_prompts(x):
416416
train_dataset = train_dataset.to_iter_dataset().batch(trainer_config.batch_size)
417417

418418
test_dataset = test_dataset.filter(_filter_long_prompts)
419-
test_dataset = test_dataset[: trainer_config.num_test_batches * trainer_config.batch_size]
419+
test_dataset = test_dataset[
420+
trainer_config.test_batch_start_index : trainer_config.num_test_batches * trainer_config.batch_size
421+
]
420422

421423
test_dataset = test_dataset.to_iter_dataset().batch(trainer_config.batch_size)
422424

src/maxtext/utils/globals.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@
7474
"olmo3-7b": "allenai/Olmo-3-7B-Instruct",
7575
"olmo3-7b-pt": "allenai/Olmo-3-1025-7B",
7676
"olmo3-32b": "allenai/Olmo-3-32B-Think",
77+
# "default" is not an HF model, but adding it to avoid a confusing warning about tokenizer_path
78+
"default": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers/tokenizer.llama2"),
7779
}
7880

7981
__all__ = [

0 commit comments

Comments
 (0)