Skip to content

Commit 8e0aaf5

Browse files
Merge pull request #3379 from AI-Hypercomputer:hengtaoguo-rl
PiperOrigin-RevId: 882201030
2 parents a2a2f34 + 241383e commit 8e0aaf5

1 file changed

Lines changed: 83 additions & 0 deletions

File tree

tests/unit/rl_utils_test.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,89 @@ def test_with_incomplete_reasoning_tags(self):
107107
self.assertFalse(has_correct_format)
108108

109109

110+
class TestNormalizeFinalAnswer(unittest.TestCase):
111+
"""Tests for utils_rl.normalize_final_answer."""
112+
113+
@pytest.mark.cpu_only
114+
def test_comma_boxed_and_currency(self):
115+
# Comma-separated numbers, \\boxed{}, and leading $ are all normalized to plain integers
116+
self.assertEqual(utils_rl.normalize_final_answer("1,000"), "1000")
117+
self.assertEqual(utils_rl.normalize_final_answer("$1,000"), "1000")
118+
self.assertEqual(utils_rl.normalize_final_answer("\\boxed{1,000}"), "1000")
119+
120+
@pytest.mark.cpu_only
121+
def test_equation_splitting_and_unit_removal(self):
122+
# Expressions with '=' are split on '='; trailing unit words are stripped
123+
self.assertEqual(utils_rl.normalize_final_answer("x = 10"), "10")
124+
self.assertEqual(utils_rl.normalize_final_answer("total = 100 meters"), "100")
125+
self.assertEqual(utils_rl.normalize_final_answer("42 mph"), "42")
126+
127+
@pytest.mark.cpu_only
128+
def test_latex_wrappers(self):
129+
# \\text{}, \\textbf{}, and \\overline{} wrappers are removed, leaving inner content
130+
self.assertEqual(utils_rl.normalize_final_answer("\\text{hello}"), "hello")
131+
self.assertEqual(utils_rl.normalize_final_answer("\\textbf{42}"), "42")
132+
self.assertEqual(utils_rl.normalize_final_answer("\\overline{AB}"), "AB")
133+
134+
@pytest.mark.cpu_only
135+
def test_dollar_math_extraction(self):
136+
# Content inside $...$ is extracted
137+
self.assertEqual(utils_rl.normalize_final_answer("The answer is $\\frac{1}{2}$"), "\\frac{1}{2}")
138+
139+
@pytest.mark.cpu_only
140+
def test_shorthand_frac_and_sqrt(self):
141+
# Shorthand \\fracab and \\sqrta are expanded to their full LaTeX forms
142+
self.assertEqual(utils_rl.normalize_final_answer("\\fracab"), "\\frac{a}{b}")
143+
self.assertEqual(utils_rl.normalize_final_answer("\\sqrta"), "\\sqrt{a}")
144+
145+
146+
class TestMatchFormatApproximatelyScores(unittest.TestCase):
147+
"""Tests for utils_rl.match_format_approximately.
148+
149+
Each tag that appears exactly once contributes reward_partial_format_match (0.5).
150+
Each tag that is absent or appears more than once contributes penalty_incorrect_format (-0.5).
151+
With 4 tags the score ranges from -2.0 (all wrong) to 2.0 (all correct).
152+
"""
153+
154+
def setUp(self):
155+
self.config = _make_config()
156+
157+
def _score(self, completion):
158+
return utils_rl.match_format_approximately(None, completion, self.config)
159+
160+
@pytest.mark.cpu_only
161+
def test_score_all_tags_present_exactly_once(self):
162+
# All four tags present exactly once -> 4 * 0.5 = 2.0
163+
self.assertEqual(self._score(["<reasoning>think</reasoning><answer>42</answer>"])[0], 2.0)
164+
165+
@pytest.mark.cpu_only
166+
def test_score_no_tags_present(self):
167+
# No tags at all -> 4 * -0.5 = -2.0
168+
self.assertEqual(self._score(["The answer is 42."])[0], -2.0)
169+
170+
@pytest.mark.cpu_only
171+
def test_score_only_answer_tags_present(self):
172+
# Only <answer>...</answer> present -> 2 * 0.5 + 2 * -0.5 = 0.0
173+
self.assertEqual(self._score(["<answer>42</answer>"])[0], 0.0)
174+
175+
@pytest.mark.cpu_only
176+
def test_score_duplicate_reasoning_start_tag(self):
177+
# Duplicate <reasoning> tag -> 3 * 0.5 + 1 * -0.5 = 1.0
178+
self.assertEqual(self._score(["<reasoning><reasoning>think</reasoning><answer>42</answer>"])[0], 1.0)
179+
180+
@pytest.mark.cpu_only
181+
def test_score_multiple_completions(self):
182+
# Multiple completions at once -> one score per entry
183+
multi_completions = [
184+
"<reasoning>think</reasoning><answer>42</answer>", # 2.0
185+
"no tags here", # -2.0
186+
]
187+
scores = self._score(multi_completions)
188+
self.assertEqual(len(scores), 2)
189+
self.assertEqual(scores[0], 2.0)
190+
self.assertEqual(scores[1], -2.0)
191+
192+
110193
class TestExtractHashAnswer(unittest.TestCase):
111194
"""Tests for utils_rl.extract_hash_answer."""
112195

0 commit comments

Comments
 (0)