Skip to content

Commit 5cd1acb

Browse files
Merge pull request #3419 from AI-Hypercomputer:anisha-test-batch-start
PiperOrigin-RevId: 884662772
2 parents b842fe3 + 023e416 commit 5cd1acb

4 files changed

Lines changed: 7 additions & 1 deletion

File tree

src/maxtext/configs/post_train/rl.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ micro_batch_size: -1
100100
# Keep `num_test_batches` low so that evaluation runs quickly. It can be
101101
# increased to a max. of 330 (if batch size is 4).
102102
num_test_batches: 5 # 200
103+
test_batch_start_index: 0
103104
train_fraction: 1.0
104105

105106
eval_interval: 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`.

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1637,6 +1637,7 @@ class RLDataset(BaseModel):
16371637
batch_size: int = Field(1, description="Global batch size for the dataset loader in RL.")
16381638
num_batches: int = Field(4, description="Number of batches for RL training.")
16391639
num_test_batches: int = Field(5, description="Number of batches for RL evaluation.")
1640+
test_batch_start_index: int = Field(0, description="Start index for the test dataset.")
16401641
train_fraction: float = Field(1.0, description="Fraction of the dataset to be used for training.")
16411642
micro_batch_size: int = Field(-1, description="Micro batch size for rollout and training.")
16421643

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,9 @@ def _filter_long_prompts(x):
392392
train_dataset = train_dataset.to_iter_dataset().batch(trainer_config.batch_size)
393393

394394
test_dataset = test_dataset.filter(_filter_long_prompts)
395-
test_dataset = test_dataset[: trainer_config.num_test_batches * trainer_config.batch_size]
395+
test_dataset = test_dataset[
396+
trainer_config.test_batch_start_index : trainer_config.num_test_batches * trainer_config.batch_size
397+
]
396398

397399
test_dataset = test_dataset.to_iter_dataset().batch(trainer_config.batch_size)
398400
return train_dataset, test_dataset

src/maxtext/utils/globals.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@
7474
"olmo3-7b": "allenai/Olmo-3-7B-Instruct",
7575
"olmo3-7b-pt": "allenai/Olmo-3-1025-7B",
7676
"olmo3-32b": "allenai/Olmo-3-32B-Think",
77+
# "default" is not an HF model, but adding it to avoid a confusing warning about tokenizer_path
78+
"default": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers/tokenizer.llama2"),
7779
}
7880

7981
__all__ = [

0 commit comments

Comments
 (0)