
Commit 5a3ada4

Merge pull request #3037 from AI-Hypercomputer:anisha-dapo2
PiperOrigin-RevId: 862813823
2 parents: 56fc8af + b0bba52

3 files changed: 223 additions & 43 deletions

src/MaxText/configs/rl.yml

Lines changed: 3 additions & 2 deletions
@@ -92,7 +92,7 @@ enable_tunix_perf_metrics: False
 batch_size: 1
 # Increase `batch_size` and `MAX_STEPS` for better results.
 # num_batches: 3738
-num_batches: 4 # 200
+num_batches: 4
 # A batch can be split into multiple micro batches for memory management
 # and/or async sampling and training.
 micro_batch_size: -1
@@ -171,7 +171,8 @@ skip_jax_distributed_system: True

 # # TODO(@mazumdera): fix this
 # Dataset Configuration
-dataset_name: 'gsm8k'
+dataset_name: 'gsm8k' # huggingface:open-r1/DAPO-Math-17k-Processed
+eval_dataset_name: 'gsm8k' # huggingface:BytedTsinghua-SIA/AIME-2024
 train_split: 'train'
 eval_split: 'test'
 tokenizer_type: 'huggingface'

src/MaxText/rl/train_rl.py

Lines changed: 52 additions & 41 deletions
@@ -104,50 +104,47 @@ def get_maxtext_model(config, devices=None):
   return tunix_model, mesh


-def get_dataset(model_tokenizer, tmvp_config, data_dir, split="train") -> grain.MapDataset:
+def get_dataset(
+    model_tokenizer, tmvp_config, data_dir, split="train", data_files=None, dataset_name=None
+) -> grain.MapDataset:
   """Download data"""
   if not os.path.exists(data_dir):
     os.makedirs(data_dir)

-  data = tfds.data_source(
-      tmvp_config.dataset_name,
-      split=split,
-      data_dir=data_dir,
-      builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
-      download=True,
-  )
+  if dataset_name is None:
+    raise ValueError("dataset_name must be provided")
+
+  if dataset_name.startswith("huggingface:"):
+    import datasets  # pylint: disable=import-outside-toplevel
+
+    if data_files is None:
+      hf_dataset_name = dataset_name.replace("huggingface:", "")
+      data = datasets.load_dataset(hf_dataset_name, split=split, cache_dir=data_dir)
+      if tmvp_config.debug.rl:
+        max_logging.log(f"Loaded Hugging Face dataset {hf_dataset_name} with split {split}. Size: {len(data)}")
+    else:  # data_files have been provided, useful for using slices of large datasets like nvidia/OpenMathInstruct-2
+      data = datasets.load_dataset(
+          "parquet",
+          data_files={tmvp_config.train_split: data_files},
+          split=split,
+          cache_dir=data_dir,
+      )
+  else:
+    builder_kwargs = {"file_format": tfds.core.FileFormat.ARRAY_RECORD}
+    data = tfds.data_source(
+        dataset_name,
+        split=split,
+        data_dir=data_dir,
+        builder_kwargs=builder_kwargs,
+        download=True,
+    )

   template_config = load_template_from_file(tmvp_config.chat_template_path)
+
   loaded_dataset = (
       grain.MapDataset.source(data)
       .shuffle(seed=tmvp_config.data_shuffle_seed)
-      .map(
-          lambda x: {
-              # passed to model forward pass
-              "prompts": model_tokenizer.apply_chat_template(
-                  [
-                      {
-                          "role": "user",
-                          "content": template_config["TEMPLATE"].format(
-                              system_prompt=template_config["SYSTEM_PROMPT"].format(
-                                  reasoning_start_token=tmvp_config.reasoning_start_token,
-                                  reasoning_end_token=tmvp_config.reasoning_end_token,
-                                  solution_start_token=tmvp_config.solution_start_token,
-                                  solution_end_token=tmvp_config.solution_end_token,
-                              ),
-                              question=x["question"].decode("utf-8"),
-                          ),
-                      },
-                  ],
-                  tokenize=False,
-                  add_generation_prompt=True,
-              ),
-              # passed to reward functions
-              "question": x["question"].decode("utf-8"),
-              # passed to reward functions
-              "answer": utils_rl.extract_hash_answer(x["answer"].decode("utf-8")),
-          }
-      )
+      .map(lambda x: utils_rl.process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x))
   )
   return loaded_dataset

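For illustration (not part of the diff above), a minimal standalone sketch of the dataset-name convention the new get_dataset supports: names prefixed with "huggingface:" go through datasets.load_dataset, everything else through TFDS. The helper name and the paths in the example calls are placeholders.

import datasets
import tensorflow_datasets as tfds

def load_rl_data(dataset_name, split, data_dir):
  """Load a Hugging Face dataset ('huggingface:' prefix) or a TFDS dataset."""
  if dataset_name.startswith("huggingface:"):
    hf_name = dataset_name.replace("huggingface:", "")
    return datasets.load_dataset(hf_name, split=split, cache_dir=data_dir)
  return tfds.data_source(
      dataset_name,
      split=split,
      data_dir=data_dir,
      builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
      download=True,
  )

# Example calls (dataset names taken from the config comments above):
# load_rl_data("huggingface:open-r1/DAPO-Math-17k-Processed", "train", "/tmp/rl_data")
# load_rl_data("gsm8k", "train", "/tmp/rl_data")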
@@ -290,19 +287,33 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path)

   # Load datasets
-  dataset = get_dataset(model_tokenizer, trainer_config, train_data_dir, trainer_config.train_split).batch(
-      trainer_config.batch_size
-  )[: trainer_config.num_batches]
+  dataset = get_dataset(
+      model_tokenizer,
+      trainer_config,
+      train_data_dir,
+      trainer_config.train_split,
+      data_files=trainer_config.hf_train_files,
+      dataset_name=trainer_config.dataset_name,
+  ).batch(trainer_config.batch_size)[: trainer_config.num_batches]

   if trainer_config.train_fraction == 1.0:
     train_dataset = dataset.repeat(trainer_config.num_epoch)
   else:
     train_dataset = dataset[: int(len(dataset) * trainer_config.train_fraction)]
     train_dataset = train_dataset.repeat(trainer_config.num_epoch)

-  test_dataset = get_dataset(model_tokenizer, trainer_config, test_data_dir, trainer_config.eval_split).batch(
-      trainer_config.batch_size
-  )[: trainer_config.num_test_batches]
+  eval_dataset_name = getattr(trainer_config, "eval_dataset_name", None)
+  if not eval_dataset_name:
+    eval_dataset_name = trainer_config.dataset_name
+
+  test_dataset = get_dataset(
+      model_tokenizer,
+      trainer_config,
+      test_data_dir,
+      trainer_config.eval_split,
+      data_files=trainer_config.hf_eval_files,
+      dataset_name=eval_dataset_name,
+  ).batch(trainer_config.batch_size)[: trainer_config.num_test_batches]

   # Let's see how one batch of the dataset looks like!
   if trainer_config.debug.rl:

src/MaxText/rl/utils_rl.py

Lines changed: 168 additions & 0 deletions
@@ -19,6 +19,65 @@
 from MaxText import max_logging


+# Constants for normalization
+SUBSTITUTIONS = [
+    ("an ", ""),
+    ("a ", ""),
+    (".$", "$"),
+    ("\\$", ""),
+    (r"\ ", ""),
+    (" ", ""),
+    ("mbox", "text"),
+    (",\\text{and}", ","),
+    ("\\text{and}", ","),
+    ("\\text{m}", "\\text{}"),
+]
+
+REMOVED_EXPRESSIONS = [
+    "square",
+    "ways",
+    "integers",
+    "dollars",
+    "mph",
+    "inches",
+    "hours",
+    "km",
+    "units",
+    "\\ldots",
+    "sue",
+    "points",
+    "feet",
+    "minutes",
+    "digits",
+    "cents",
+    "degrees",
+    "cm",
+    "gm",
+    "pounds",
+    "meters",
+    "meals",
+    "edges",
+    "students",
+    "childrentickets",
+    "multiples",
+    "\\text{s}",
+    "\\text{.}",
+    "\\text{\ns}",
+    "\\text{}^2",
+    "\\text{}^3",
+    "\\text{\n}",
+    "\\text{}",
+    r"\mathrm{th}",
+    r"^\circ",
+    r"^{\circ}",
+    r"\;",
+    r",\!",
+    "{,}",
+    '"',
+    "\\dots",
+]
+
+
 # Let's define a RegEx for checking whether the format matches.
 #
 def get_match_format_regex(tmvp_config):
@@ -90,6 +149,47 @@ def match_format_approximately(prompts, completions, tmvp_config, **kargs):
   return scores


+def normalize_final_answer(final_answer: str) -> str:
+  """Normalize a final answer to a quantitative reasoning question.
+
+  Args:
+    final_answer: The answer string to normalize
+
+  Returns:
+    Normalized answer string
+  """
+  final_answer = final_answer.split("=")[-1]
+
+  # Apply substitutions and removals
+  for before, after in SUBSTITUTIONS:
+    final_answer = final_answer.replace(before, after)
+  for expr in REMOVED_EXPRESSIONS:
+    final_answer = final_answer.replace(expr, "")
+
+  # Extract and normalize LaTeX math
+  final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
+  final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
+  final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
+  final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
+  final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
+
+  # Normalize shorthand TeX:
+  #   \fracab -> \frac{a}{b}
+  #   \frac{abc}{bef} -> \frac{abc}{bef}
+  #   \fracabc -> \frac{a}{b}c
+  #   \sqrta -> \sqrt{a}
+  #   \sqrtab -> sqrt{a}b
+  final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
+  final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
+  final_answer = final_answer.replace("$", "")
+
+  # Normalize numbers
+  if final_answer.replace(",", "").isdigit():
+    final_answer = final_answer.replace(",", "")
+
+  return final_answer.strip()
+
+
 def check_answer(prompts, completions, answer, tmvp_config, **kargs):
   """
   Reward the model if the answer is correct. A reward is also given if the answer
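For illustration (not part of the diff above), a few hypothetical inputs and the values this normalization produces, assuming the SUBSTITUTIONS and REMOVED_EXPRESSIONS lists defined earlier: spaces and unit words are stripped, and LaTeX wrappers such as \boxed{} are unwrapped.

# Hypothetical inputs; comments show the expected return values.
normalize_final_answer("\\boxed{42}")             # -> "42"
normalize_final_answer("72 degrees")              # -> "72"
normalize_final_answer("$\\frac{1}{2}$ dollars")  # -> "\\frac{1}{2}"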
@@ -105,6 +205,9 @@ def check_answer(prompts, completions, answer, tmvp_config, **kargs):
     if guess is None:
       scores.append(0)
       continue
+    if "DAPO" in tmvp_config.dataset_name:
+      guess = normalize_final_answer(guess)
+      true_answer = normalize_final_answer(true_answer)
     # Correct answer gets tmvp_config.reward_exact_format_match points!
     if guess == true_answer:
       score += tmvp_config.reward_exact_format_match
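For illustration (not part of the diff above), why DAPO answers are routed through normalize_final_answer before the exact-match comparison: superficially different renderings of the same value compare equal afterwards. The strings below are hypothetical examples.

# Both sides normalize to "1000", so the exact-match reward can fire.
normalize_final_answer("1,000") == normalize_final_answer("1000")  # True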
@@ -207,3 +310,68 @@ def get_optimizer(tmvp_config, max_train_steps):
       optimizer,
   )
   return optimizer
+
+
+def process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x):
+  """Function to process input dataset"""
+
+  def _to_str(val):
+    if isinstance(val, bytes):
+      return val.decode("utf-8")
+    return str(val)
+
+  # Handle DAPO dataset schema
+  # originally (prompt is a list, answer is in reward_model)
+  # https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k/viewer/default/train?row=0
+  # but using https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed/viewer/all/train?row=1
+  # so question is prompt and answer is solution
+
+  question = x.get("question", x.get("prompt"))
+  answer = x.get("answer")
+  if answer is None and "solution" in x:
+    answer = x["solution"]
+
+  # Handle OpenMathInstruct-2
+  if "problem" in x:
+    question = x["problem"]
+  if "expected_answer" in x:
+    answer = x["expected_answer"]
+
+  # Handle AIME-2024
+  if "extra_info" in x and isinstance(x["extra_info"], dict) and "raw_problem" in x["extra_info"]:
+    question = x["extra_info"]["raw_problem"]
+
+  if "reward_model" in x and isinstance(x["reward_model"], dict) and "ground_truth" in x["reward_model"]:
+    answer = x["reward_model"]["ground_truth"]
+
+  question = _to_str(question)
+  answer = _to_str(answer)
+
+  if dataset_name == "gsm8k":
+    answer = extract_hash_answer(answer)
+
+  return {
+      # passed to model forward pass
+      "prompts": model_tokenizer.apply_chat_template(
+          [
+              {
+                  "role": "user",
+                  "content": template_config["TEMPLATE"].format(
+                      system_prompt=template_config["SYSTEM_PROMPT"].format(
+                          reasoning_start_token=tmvp_config.reasoning_start_token,
+                          reasoning_end_token=tmvp_config.reasoning_end_token,
+                          solution_start_token=tmvp_config.solution_start_token,
+                          solution_end_token=tmvp_config.solution_end_token,
+                      ),
+                      question=question,
+                  ),
+              },
+          ],
+          tokenize=False,
+          add_generation_prompt=True,
+      ),
+      # passed to reward functions
+      "question": question,
+      # passed to reward functions
+      "answer": answer,
+  }
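For illustration (not part of the diff above), a minimal sketch of calling process_data on a row shaped like open-r1/DAPO-Math-17k-Processed. The tokenizer stub, chat template, and token markers are placeholders (real runs use the Hugging Face AutoTokenizer and the template from tmvp_config.chat_template_path), and the import path is assumed from the file layout.

from types import SimpleNamespace

from MaxText.rl import utils_rl  # assumed import path

class StubTokenizer:
  """Stand-in for the Hugging Face tokenizer: returns the user message as-is."""

  def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
    return messages[0]["content"]

# Hypothetical chat template and token markers.
template_config = {
    "TEMPLATE": "{system_prompt}\n{question}",
    "SYSTEM_PROMPT": "Reason between {reasoning_start_token} and {reasoning_end_token}; "
                     "answer between {solution_start_token} and {solution_end_token}.",
}
tmvp_config = SimpleNamespace(
    reasoning_start_token="<think>",
    reasoning_end_token="</think>",
    solution_start_token="<answer>",
    solution_end_token="</answer>",
)

# A row shaped like open-r1/DAPO-Math-17k-Processed ("prompt" + "solution").
row = {"prompt": "What is 2 + 2?", "solution": "4"}
example = utils_rl.process_data(
    "huggingface:open-r1/DAPO-Math-17k-Processed", StubTokenizer(), template_config, tmvp_config, row
)
# example["prompts"]  -> templated text passed to the model forward pass
# example["question"] -> "What is 2 + 2?" (passed to the reward functions)
# example["answer"]   -> "4" (passed to the reward functions)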
