Skip to content

Commit 76d9f94

Browse files
Merge pull request #3349 from AI-Hypercomputer:hengtaoguo-test
PiperOrigin-RevId: 880983467
2 parents e2f6b0e + 749dd33 commit 76d9f94

5 files changed

Lines changed: 136 additions & 5 deletions

File tree

dependencies/requirements/base_requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ jax
1717
jaxlib
1818
jaxtyping
1919
jsonlines
20+
math-verify
2021
ml-collections
2122
ml-goodput-measurement
2223
numpy

dependencies/requirements/generated_requirements/tpu-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ lxml>=6.0.2
120120
markdown-it-py>=4.0.0
121121
markdown>=3.10
122122
markupsafe>=3.0.3
123+
math-verify>=0.9.0
123124
matplotlib>=3.10.7
124125
mccabe>=0.7.0
125126
mdurl>=0.1.2

src/maxtext/trainers/post_train/rl/evaluate_rl.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def score_responses(tmvp_config, question, responses, answer):
100100
Tuple of (is_correct, is_partially_correct, has_correct_format)
101101
"""
102102
match_format = utils_rl.get_match_format_regex(tmvp_config)
103+
answer_fallback = utils_rl.get_answer_fallback_regex(tmvp_config)
103104

104105
if tmvp_config.debug.rl:
105106
max_logging.log("========================================")
@@ -113,10 +114,19 @@ def score_responses(tmvp_config, question, responses, answer):
113114
has_correct_format = False
114115

115116
for response in responses:
116-
# Extract numerical response
117-
extracted_response = guess.group(1) if (guess := match_format.search(response)) is not None else "-1000000"
117+
# Extract answer: prefer the full format match; fall back to the last
118+
# <answer>...</answer> tag if full format match is not found, so result
119+
# scoring is decoupled from format.
120+
full_match = match_format.search(response)
121+
if full_match is not None:
122+
extracted_response = full_match.group(1)
123+
else:
124+
# Find the *last* occurrence of the answer tag (most likely the final answer).
125+
fallback_matches = answer_fallback.findall(response)
126+
extracted_response = fallback_matches[-1].strip() if fallback_matches else "-1000000"
118127
if tmvp_config.debug.rl:
119-
max_logging.log(f"Evaluation extracted_response: {extracted_response}")
128+
used = "full format" if full_match is not None else "answer-tag fallback"
129+
max_logging.log(f"Evaluation extracted_response ({used}): {extracted_response}")
120130

121131
# Check exact correctness
122132
try:
@@ -146,8 +156,8 @@ def score_responses(tmvp_config, question, responses, answer):
146156
max_logging.log(f"Evaluation Exception: {e}")
147157
max_logging.log("SKIPPED")
148158

149-
# Check format correctness
150-
if match_format.search(response) is not None:
159+
# Check format correctness (requires the full <reasoning>...</reasoning><answer>...</answer> structure)
160+
if full_match is not None:
151161
has_correct_format = True
152162

153163
# Early exit if all criteria are met

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,19 @@ def get_match_format_regex(tmvp_config):
118118
return match_format
119119

120120

121+
def get_answer_fallback_regex(tmvp_config):
122+
"""Returns a compiled regex that finds the *last* answer tag in a completion.
123+
124+
Used as a fallback when the full <reasoning>...</reasoning><answer>...</answer>
125+
format is incomplete (e.g. missing the closing reasoning tag). The result
126+
reward can still be computed independently from the format reward.
127+
"""
128+
return re.compile(
129+
rf"{re.escape(tmvp_config.solution_start_token)}(.+?){re.escape(tmvp_config.solution_end_token)}",
130+
flags=re.MULTILINE | re.DOTALL,
131+
)
132+
133+
121134
def match_format_exactly(prompts, completions, tmvp_config, **kargs):
122135
"""
123136
Give the model a reward of tmvp_config.reward_exact_format_match points if the format matches exactly.

tests/unit/rl_utils_test.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for RL result parsing and reward scoring (CPU-only)."""
16+
17+
import unittest
18+
import pytest
19+
from types import SimpleNamespace
20+
21+
evaluate_rl = pytest.importorskip(
22+
"maxtext.trainers.post_train.rl.evaluate_rl",
23+
reason="tunix (required by evaluate_rl) is not installed GPU",
24+
)
25+
26+
27+
def _make_config():
28+
"""Create a minimal config object with the parameters required by score_responses."""
29+
return SimpleNamespace(
30+
reasoning_start_token="<reasoning>",
31+
reasoning_end_token="</reasoning>",
32+
solution_start_token="<answer>",
33+
solution_end_token="</answer>",
34+
reward_exact_format_match=2.0,
35+
reward_partial_format_match=0.5,
36+
reward_white_space_format_match=1.5,
37+
reward_ratio_guess_to_answer_high=1.0,
38+
reward_ratio_guess_to_answer_low=0.5,
39+
penalty_incorrect_format=-0.5,
40+
penalty_incorrect_answer=-0.5,
41+
dataset_name="test",
42+
debug=SimpleNamespace(rl=False),
43+
)
44+
45+
46+
class TestScoreResponses(unittest.TestCase):
47+
"""Tests for evaluate_rl.score_responses parsing and correctness logic."""
48+
49+
def setUp(self):
50+
self.config = _make_config()
51+
52+
@pytest.mark.cpu_only
53+
def test_nested_tags(self):
54+
"""Response with nested reasoning tags still extracts the correct answer."""
55+
is_correct, is_partially_correct, has_correct_format = evaluate_rl.score_responses(
56+
tmvp_config=self.config,
57+
question="What is 11/3?",
58+
responses=[
59+
"<reasoning>Need to use <reasoning> and </reasoning>, "
60+
"<answer> and </answer></reasoning><answer>11/3</answer>"
61+
],
62+
answer="11/3",
63+
)
64+
self.assertTrue(is_correct)
65+
self.assertTrue(is_partially_correct)
66+
self.assertTrue(has_correct_format)
67+
68+
@pytest.mark.cpu_only
69+
def test_with_extra_ending_tags(self):
70+
"""Answer with extra ending tags such as <end_of_turn>."""
71+
is_correct, is_partially_correct, has_correct_format = evaluate_rl.score_responses(
72+
tmvp_config=self.config,
73+
question=(
74+
"James buys a new wardrobe. He buys 10 suits and 10 dress pants. "
75+
"He also buys 3 dress shirts per suit. The suits cost $750 each and "
76+
"the dress pants cost 1/5 that cost. The dress shirts were $60 each. "
77+
"How much did everything cost?"
78+
),
79+
responses=[
80+
"<reasoning>This is the sum of the cost of the suits, the pants, and the "
81+
"shirts: $7500 + $1500 + $1800 = $10800.\n\n</reasoning>\n"
82+
"<answer>10800</answer><end_of_turn>"
83+
],
84+
answer="10,800",
85+
)
86+
self.assertTrue(is_correct)
87+
self.assertTrue(is_partially_correct)
88+
self.assertTrue(has_correct_format)
89+
90+
@pytest.mark.cpu_only
91+
def test_with_incomplete_reasoning_tags(self):
92+
"""(1) Incomplete reasoning tags still extracts the correct answer."""
93+
"""(2) Currency symbols works with math_verify."""
94+
is_correct, is_partially_correct, has_correct_format = evaluate_rl.score_responses(
95+
tmvp_config=self.config,
96+
question="What is the price of the item?",
97+
responses=["<reasoning>The item costs $16.<answer>$16</answer>"],
98+
answer="16",
99+
)
100+
self.assertTrue(is_correct)
101+
self.assertTrue(is_partially_correct)
102+
self.assertFalse(has_correct_format)
103+
104+
105+
if __name__ == "__main__":
106+
unittest.main()

0 commit comments

Comments
 (0)