Skip to content

Commit 00ef5de

Browse files
Merge pull request #3402 from AI-Hypercomputer:hengtaoguo-rl
PiperOrigin-RevId: 882856046
2 parents e69f8ab + 13531a7 commit 00ef5de

1 file changed

Lines changed: 123 additions & 0 deletions

File tree

tests/unit/rl_utils_test.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,129 @@ def test_score_multiple_completions(self):
190190
self.assertEqual(scores[1], -2.0)
191191

192192

193+
class TestCheckNumbers(unittest.TestCase):
194+
"""Tests for utils_rl.check_numbers.
195+
196+
Covers two scenarios:
197+
1. Whether the regex can extract an answer from the completion.
198+
2. Whether the extracted value matches (or does not match) the reference answer.
199+
"""
200+
201+
def setUp(self):
202+
self.config = _make_config()
203+
204+
def _check(self, completions, answer):
205+
return utils_rl.check_numbers(
206+
prompts=None,
207+
completions=completions,
208+
answer=answer,
209+
tmvp_config=self.config,
210+
question=["test question"] * len(completions),
211+
)
212+
213+
# ---------------------------------------------------------------
214+
# Scenario 1: regex extraction succeeds / fails
215+
# ---------------------------------------------------------------
216+
217+
@pytest.mark.cpu_only
218+
def test_extraction_succeeds_full_format(self):
219+
"""Full <reasoning>…</reasoning><answer>…</answer> format allows extraction."""
220+
scores = self._check(
221+
completions=["<reasoning>40 + 2 = 42</reasoning><answer>42</answer>"],
222+
answer=["42"],
223+
)
224+
self.assertEqual(scores[0], 1.5)
225+
226+
@pytest.mark.cpu_only
227+
def test_extraction_fails_no_tags(self):
228+
"""Plain-text completion without any tags yields score 0 (cannot extract)."""
229+
scores = self._check(
230+
completions=["The answer is 42."],
231+
answer=["42"],
232+
)
233+
self.assertEqual(scores[0], 0)
234+
235+
@pytest.mark.cpu_only
236+
def test_extraction_fails_answer_tags_only(self):
237+
"""<answer> tag alone (no <reasoning> block) is not matched by the regex, score 0."""
238+
scores = self._check(
239+
completions=["<answer>42</answer>"],
240+
answer=["42"],
241+
)
242+
self.assertEqual(scores[0], 0)
243+
244+
@pytest.mark.cpu_only
245+
def test_extraction_fails_reasoning_tags_only(self):
246+
"""<reasoning> block with no <answer> tag cannot be extracted, score 0."""
247+
scores = self._check(
248+
completions=["<reasoning>The answer is 42.</reasoning>"],
249+
answer=["42"],
250+
)
251+
self.assertEqual(scores[0], 0)
252+
253+
@pytest.mark.cpu_only
254+
def test_extraction_batch_mixed(self):
255+
"""Batch with one extractable and one non-extractable completion."""
256+
scores = self._check(
257+
completions=[
258+
"<reasoning>thinking</reasoning><answer>7</answer>", # extractable
259+
"just 7", # not extractable
260+
],
261+
answer=["7", "7"],
262+
)
263+
self.assertEqual(scores[0], 1.5)
264+
self.assertEqual(scores[1], 0)
265+
266+
# ---------------------------------------------------------------
267+
# Scenario 2: extraction succeeds, value matches/mismatches the answer
268+
# ---------------------------------------------------------------
269+
270+
@pytest.mark.cpu_only
271+
def test_extracted_matches_integer_answer(self):
272+
"""Extracted integer equal to reference answer earns 1.5."""
273+
scores = self._check(
274+
completions=["<reasoning>simple</reasoning><answer>100</answer>"],
275+
answer=["100"],
276+
)
277+
self.assertEqual(scores[0], 1.5)
278+
279+
@pytest.mark.cpu_only
280+
def test_extracted_does_not_match_answer(self):
281+
"""Extracted number that differs from the reference answer earns 0.0."""
282+
scores = self._check(
283+
completions=["<reasoning>wrong path</reasoning><answer>99</answer>"],
284+
answer=["42"],
285+
)
286+
self.assertEqual(scores[0], 0.0)
287+
288+
@pytest.mark.cpu_only
289+
def test_extracted_matches_comma_formatted_number(self):
290+
"""Comma-formatted guess (e.g. '1,000') normalizes to match integer answer '1000'."""
291+
scores = self._check(
292+
completions=["<reasoning>cost calculation</reasoning><answer>1,000</answer>"],
293+
answer=["1000"],
294+
)
295+
self.assertEqual(scores[0], 1.5)
296+
297+
@pytest.mark.cpu_only
298+
def test_extracted_matches_with_currency_prefix(self):
299+
"""Leading '$' in extracted answer is normalized away before comparison."""
300+
scores = self._check(
301+
completions=["<reasoning>price is $16</reasoning><answer>$16</answer>"],
302+
answer=["16"],
303+
)
304+
self.assertEqual(scores[0], 1.5)
305+
306+
@pytest.mark.cpu_only
307+
def test_extracted_non_numeric_no_match(self):
308+
"""Non-numeric extraction that cannot be float-converted and does not math-verify returns 0."""
309+
scores = self._check(
310+
completions=["<reasoning>thinking</reasoning><answer>blue</answer>"],
311+
answer=["red"],
312+
)
313+
self.assertEqual(scores[0], 0.0)
314+
315+
193316
class TestExtractHashAnswer(unittest.TestCase):
194317
"""Tests for utils_rl.extract_hash_answer."""
195318

0 commit comments

Comments
 (0)