Skip to content

Commit 0a5cce3

Browse files
Merge pull request #3059 from AI-Hypercomputer:anisha-split-openmath
PiperOrigin-RevId: 874640099
2 parents e228e8a + bc9b464 commit 0a5cce3

5 files changed

Lines changed: 296 additions & 49 deletions

File tree

dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ RUN pip install vllm-tpu
3535

3636
RUN pip install --no-deps qwix==0.1.4
3737

38+
RUN pip install math-verify==0.9.0
39+
3840
RUN if [ "$MODE" = "post-training-experimental" ]; then \
3941
pip uninstall -y jax jaxlib libtpu && \
4042
pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \

dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ RUN pip install -e /tpu-inference --no-cache-dir
4040

4141
RUN pip install --no-deps qwix==0.1.4
4242

43+
RUN pip install math-verify==0.9.0
44+
4345
RUN if [ "$MODE" = "post-training-experimental" ]; then \
4446
echo "MODE=post-training-experimental: Re-installing JAX/libtpu"; \
4547
pip uninstall -y jax jaxlib libtpu && \

src/maxtext/trainers/post_train/rl/evaluate_rl.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""
1717
RL Evaluation Module.
1818
"""
19+
from math_verify import parse
1920
from tqdm.auto import tqdm
2021
from tunix.rl.rollout.base_rollout import RolloutConfig
2122

@@ -99,7 +100,6 @@ def score_responses(tmvp_config, question, responses, answer):
99100
Tuple of (is_correct, is_partially_correct, has_correct_format)
100101
"""
101102
match_format = utils_rl.get_match_format_regex(tmvp_config)
102-
match_numbers = utils_rl.get_match_numbers_regex(tmvp_config)
103103

104104
if tmvp_config.debug.rl:
105105
max_logging.log("========================================")
@@ -114,22 +114,32 @@ def score_responses(tmvp_config, question, responses, answer):
114114

115115
for response in responses:
116116
# Extract numerical response
117-
extracted_response = guess.group(1) if (guess := match_numbers.search(response)) is not None else "-1000000"
118-
117+
extracted_response = guess.group(1) if (guess := match_format.search(response)) is not None else "-1000000"
119118
if tmvp_config.debug.rl:
120119
max_logging.log(f"Evaluation extracted_response: {extracted_response}")
121120

122121
# Check exact correctness
123122
try:
124-
# Remove ',' and '$' then convert to float
125-
val_extracted = float(extracted_response.replace(",", "").replace("$", "").strip())
126-
val_answer = float(answer.replace(",", "").replace("$", "").strip())
127-
is_correct = val_extracted == val_answer
128-
129-
# Check partial correctness (within 10%)
130-
ratio = val_extracted / val_answer
131-
if 0.9 <= ratio <= 1.1:
132-
is_partially_correct = True
123+
# Fix LaTeX escaping issues for both ground truth and extracted answer
124+
norm_answer = utils_rl.fix_latex_escaping(answer)
125+
norm_extracted = utils_rl.fix_latex_escaping(extracted_response)
126+
# Normalize for certain datasets and parse
127+
if "DAPO" in tmvp_config.dataset_name or "OpenMathInstruct" in tmvp_config.dataset_name:
128+
norm_extracted = utils_rl.normalize_final_answer(norm_extracted).strip()
129+
norm_answer = utils_rl.normalize_final_answer(answer).strip()
130+
is_correct = utils_rl.math_verify_func([utils_rl.boxed(norm_answer)], [utils_rl.boxed(norm_extracted)])[0] > 0.1
131+
if tmvp_config.debug.rl:
132+
# is_correct is a tuple; a first value of 1.0 means a match,
133+
# 0.0 means a mismatch. e.g. (0.0, (['3', '3'], ['3/5', '\\frac{3}{5}']))
134+
max_logging.log(f"Result is_correct: {is_correct}")
135+
136+
val_extracted = parse(utils_rl.boxed(norm_extracted))
137+
val_answer = parse(utils_rl.boxed(norm_answer))
138+
139+
# Check partial correctness if values can be extracted (within 10%)
140+
if val_extracted and val_answer:
141+
ratio = (val_extracted[0] + utils_rl.EPSILON) / (val_answer[0] + utils_rl.EPSILON)
142+
is_partially_correct = 0.9 <= ratio <= 1.1
133143

134144
except Exception as e:
135145
if tmvp_config.debug.rl:

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 90 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -304,14 +304,82 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
304304
model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path)
305305

306306
# Load datasets
307-
train_dataset = get_dataset(
308-
model_tokenizer,
309-
trainer_config,
310-
train_data_dir,
311-
trainer_config.train_split,
312-
data_files=trainer_config.hf_train_files,
313-
dataset_name=trainer_config.dataset_name,
314-
)
307+
if trainer_config.dataset_name == "huggingface:nvidia/OpenMathInstruct-2":
308+
import datasets # pylint: disable=import-outside-toplevel
309+
310+
def prepare_openinstructmath2_dataset(
311+
split: str = "train_1M",
312+
seed: int = 42,
313+
test_size: float = 0.05,
314+
output_key: str = "expected_answer",
315+
):
316+
"""Load and split the OpenMathInstruct-2 dataset into train and validation sets using HF's train_test_split."""
317+
max_logging.log(
318+
"WARNING: For reproducible experiments, preprocess the dataset once and "
319+
"define your own HfDataset subclass that directly uses the preprocessed datasets."
320+
)
321+
322+
# Load the original dataset
323+
original_ds = datasets.load_dataset(
324+
"parquet",
325+
data_files={trainer_config.train_split: trainer_config.hf_train_files},
326+
split=split,
327+
cache_dir=train_data_dir,
328+
)
329+
330+
# Split into train and validation sets using HF's train_test_split
331+
split_ds = original_ds.train_test_split(test_size=test_size, seed=seed)
332+
333+
return {
334+
"train": split_ds["train"],
335+
"validation": split_ds["test"],
336+
}
337+
338+
split_name = trainer_config.train_split if trainer_config.train_split != "train" else "train_1M"
339+
splits = prepare_openinstructmath2_dataset(split=split_name)
340+
template_config = load_template_from_file(trainer_config.chat_template_path)
341+
342+
train_dataset = (
343+
grain.MapDataset.source(splits["train"])
344+
.shuffle(seed=trainer_config.data_shuffle_seed)
345+
.map(
346+
lambda x: utils_rl.process_data(
347+
trainer_config.dataset_name, model_tokenizer, template_config, trainer_config, x
348+
)
349+
)
350+
)
351+
352+
test_dataset = (
353+
grain.MapDataset.source(splits["validation"])
354+
.shuffle(seed=trainer_config.data_shuffle_seed)
355+
.map(
356+
lambda x: utils_rl.process_data(
357+
trainer_config.dataset_name, model_tokenizer, template_config, trainer_config, x
358+
)
359+
)
360+
)
361+
else:
362+
train_dataset = get_dataset(
363+
model_tokenizer,
364+
trainer_config,
365+
train_data_dir,
366+
trainer_config.train_split,
367+
data_files=trainer_config.hf_train_files,
368+
dataset_name=trainer_config.dataset_name,
369+
)
370+
371+
eval_dataset_name = getattr(trainer_config, "eval_dataset_name", None)
372+
if not eval_dataset_name:
373+
eval_dataset_name = trainer_config.dataset_name
374+
375+
test_dataset = get_dataset(
376+
model_tokenizer,
377+
trainer_config,
378+
test_data_dir,
379+
trainer_config.eval_split,
380+
data_files=trainer_config.hf_eval_files,
381+
dataset_name=eval_dataset_name,
382+
)
315383

316384
def _filter_long_prompts(x):
317385
tokens = model_tokenizer.tokenize(x["prompts"])
@@ -324,24 +392,24 @@ def _filter_long_prompts(x):
324392

325393
train_dataset = train_dataset.to_iter_dataset().batch(trainer_config.batch_size)
326394

327-
eval_dataset_name = getattr(trainer_config, "eval_dataset_name", None)
328-
if not eval_dataset_name:
329-
eval_dataset_name = trainer_config.dataset_name
330-
331-
test_dataset = get_dataset(
332-
model_tokenizer,
333-
trainer_config,
334-
test_data_dir,
335-
trainer_config.eval_split,
336-
data_files=trainer_config.hf_eval_files,
337-
dataset_name=eval_dataset_name,
338-
)
339-
340395
test_dataset = test_dataset.filter(_filter_long_prompts)
341396
test_dataset = test_dataset[: trainer_config.num_test_batches * trainer_config.batch_size]
342397

343398
test_dataset = test_dataset.to_iter_dataset().batch(trainer_config.batch_size)
344399

400+
if trainer_config.debug.rl:
401+
# Let's see what one batch of the dataset looks like!
402+
if trainer_config.debug.rl:
403+
for i, ele in enumerate(train_dataset):
404+
if i >= 5:
405+
break
406+
pprint(ele)
407+
if trainer_config.debug.rl:
408+
for i, ele in enumerate(test_dataset):
409+
if i >= 5:
410+
break
411+
pprint(ele)
412+
345413
# Load reference model
346414
max_logging.log("Creating reference model and also meshes for reference and rollout")
347415
reference_model, reference_mesh = get_maxtext_model(trainer_config, trainer_devices)
@@ -499,7 +567,7 @@ def _filter_long_prompts(x):
499567
"enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
500568
)
501569

502-
vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
570+
vllm_config_path = epath.Path(MAXTEXT_CONFIGS_DIR) / "inference/vllm.yml"
503571
argv_list = ["", str(vllm_config_path), "log_config=False"]
504572
vllm_config = pyconfig.initialize(argv_list)
505573

0 commit comments

Comments
 (0)