1919from MaxText import max_logging
2020
2121
22+ # Constants for normalization
23+ SUBSTITUTIONS = [
24+ ("an " , "" ),
25+ ("a " , "" ),
26+ (".$" , "$" ),
27+ ("\\ $" , "" ),
28+ (r"\ " , "" ),
29+ (" " , "" ),
30+ ("mbox" , "text" ),
31+ (",\\ text{and}" , "," ),
32+ ("\\ text{and}" , "," ),
33+ ("\\ text{m}" , "\\ text{}" ),
34+ ]
35+
36+ REMOVED_EXPRESSIONS = [
37+ "square" ,
38+ "ways" ,
39+ "integers" ,
40+ "dollars" ,
41+ "mph" ,
42+ "inches" ,
43+ "hours" ,
44+ "km" ,
45+ "units" ,
46+ "\\ ldots" ,
47+ "sue" ,
48+ "points" ,
49+ "feet" ,
50+ "minutes" ,
51+ "digits" ,
52+ "cents" ,
53+ "degrees" ,
54+ "cm" ,
55+ "gm" ,
56+ "pounds" ,
57+ "meters" ,
58+ "meals" ,
59+ "edges" ,
60+ "students" ,
61+ "childrentickets" ,
62+ "multiples" ,
63+ "\\ text{s}" ,
64+ "\\ text{.}" ,
65+ "\\ text{\n s}" ,
66+ "\\ text{}^2" ,
67+ "\\ text{}^3" ,
68+ "\\ text{\n }" ,
69+ "\\ text{}" ,
70+ r"\mathrm{th}" ,
71+ r"^\circ" ,
72+ r"^{\circ}" ,
73+ r"\;" ,
74+ r",\!" ,
75+ "{,}" ,
76+ '"' ,
77+ "\\ dots" ,
78+ ]
79+
80+
2281# Let's define a RegEx for checking whether the format matches.
2382#
2483def get_match_format_regex (tmvp_config ):
@@ -90,6 +149,47 @@ def match_format_approximately(prompts, completions, tmvp_config, **kargs):
90149 return scores
91150
92151
152+ def normalize_final_answer (final_answer : str ) -> str :
153+ """Normalize a final answer to a quantitative reasoning question.
154+
155+ Args:
156+ final_answer: The answer string to normalize
157+
158+ Returns:
159+ Normalized answer string
160+ """
161+ final_answer = final_answer .split ("=" )[- 1 ]
162+
163+ # Apply substitutions and removals
164+ for before , after in SUBSTITUTIONS :
165+ final_answer = final_answer .replace (before , after )
166+ for expr in REMOVED_EXPRESSIONS :
167+ final_answer = final_answer .replace (expr , "" )
168+
169+ # Extract and normalize LaTeX math
170+ final_answer = re .sub (r"(.*?)(\$)(.*?)(\$)(.*)" , "$\\ 3$" , final_answer )
171+ final_answer = re .sub (r"(\\text\{)(.*?)(\})" , "\\ 2" , final_answer )
172+ final_answer = re .sub (r"(\\textbf\{)(.*?)(\})" , "\\ 2" , final_answer )
173+ final_answer = re .sub (r"(\\overline\{)(.*?)(\})" , "\\ 2" , final_answer )
174+ final_answer = re .sub (r"(\\boxed\{)(.*)(\})" , "\\ 2" , final_answer )
175+
176+ # Normalize shorthand TeX:
177+ # \fracab -> \frac{a}{b}
178+ # \frac{abc}{bef} -> \frac{abc}{bef}
179+ # \fracabc -> \frac{a}{b}c
180+ # \sqrta -> \sqrt{a}
181+ # \sqrtab -> sqrt{a}b
182+ final_answer = re .sub (r"(frac)([^{])(.)" , "frac{\\ 2}{\\ 3}" , final_answer )
183+ final_answer = re .sub (r"(sqrt)([^{])" , "sqrt{\\ 2}" , final_answer )
184+ final_answer = final_answer .replace ("$" , "" )
185+
186+ # Normalize numbers
187+ if final_answer .replace ("," , "" ).isdigit ():
188+ final_answer = final_answer .replace ("," , "" )
189+
190+ return final_answer .strip ()
191+
192+
93193def check_answer (prompts , completions , answer , tmvp_config , ** kargs ):
94194 """
95195 Reward the model if the answer is correct. A reward is also given if the answer
@@ -105,6 +205,9 @@ def check_answer(prompts, completions, answer, tmvp_config, **kargs):
105205 if guess is None :
106206 scores .append (0 )
107207 continue
208+ if "DAPO" in tmvp_config .dataset_name :
209+ guess = normalize_final_answer (guess )
210+ true_answer = normalize_final_answer (true_answer )
108211 # Correct answer gets tmvp_config.reward_exact_format_match points!
109212 if guess == true_answer :
110213 score += tmvp_config .reward_exact_format_match
@@ -207,3 +310,68 @@ def get_optimizer(tmvp_config, max_train_steps):
207310 optimizer ,
208311 )
209312 return optimizer
313+
314+
315+ def process_data (dataset_name , model_tokenizer , template_config , tmvp_config , x ):
316+ """Function to process input dataset"""
317+
318+ def _to_str (val ):
319+ if isinstance (val , bytes ):
320+ return val .decode ("utf-8" )
321+ return str (val )
322+
323+ # Handle DAPO dataset schema
324+ # originally (prompt is a list, answer is in reward_model)
325+ # https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k/viewer/default/train?row=0
326+ # but using https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed/viewer/all/train?row=1
327+ # so question is prompt and answer is solution
328+
329+ question = x .get ("question" , x .get ("prompt" ))
330+ answer = x .get ("answer" )
331+ if answer is None and "solution" in x :
332+ answer = x ["solution" ]
333+
334+ # Handle OpenMathInstruct-2
335+ if "problem" in x :
336+ question = x ["problem" ]
337+ if "expected_answer" in x :
338+ answer = x ["expected_answer" ]
339+
340+ # Handle AIME-2024
341+ if "extra_info" in x and isinstance (x ["extra_info" ], dict ) and "raw_problem" in x ["extra_info" ]:
342+ question = x ["extra_info" ]["raw_problem" ]
343+
344+ if "reward_model" in x and isinstance (x ["reward_model" ], dict ) and "ground_truth" in x ["reward_model" ]:
345+ answer = x ["reward_model" ]["ground_truth" ]
346+
347+ question = _to_str (question )
348+ answer = _to_str (answer )
349+
350+ if dataset_name == "gsm8k" :
351+ answer = extract_hash_answer (answer )
352+
353+ return {
354+ # passed to model forward pass
355+ "prompts" : model_tokenizer .apply_chat_template (
356+ [
357+ {
358+ "role" : "user" ,
359+ "content" : template_config ["TEMPLATE" ].format (
360+ system_prompt = template_config ["SYSTEM_PROMPT" ].format (
361+ reasoning_start_token = tmvp_config .reasoning_start_token ,
362+ reasoning_end_token = tmvp_config .reasoning_end_token ,
363+ solution_start_token = tmvp_config .solution_start_token ,
364+ solution_end_token = tmvp_config .solution_end_token ,
365+ ),
366+ question = question ,
367+ ),
368+ },
369+ ],
370+ tokenize = False ,
371+ add_generation_prompt = True ,
372+ ),
373+ # passed to reward functions
374+ "question" : question ,
375+ # passed to reward functions
376+ "answer" : answer ,
377+ }
0 commit comments