Skip to content

Commit 6a9f060

Browse files
Merge pull request #3485 from AI-Hypercomputer:anisha-fix-reward-fallback
PiperOrigin-RevId: 888297733
2 parents dc29039 + 727c9e1 commit 6a9f060

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,16 @@ def check_answer(prompts, completions, answer, tmvp_config, **kargs):
229229
value.
230230
"""
231231
match_format = get_match_format_regex(tmvp_config)
232-
extracted_responses = [guess.group(1) if (guess := match_format.search(c)) is not None else None for c in completions]
232+
answer_fallback = get_answer_fallback_regex(tmvp_config)
233+
234+
extracted_responses = []
235+
for c in completions:
236+
full_match = match_format.search(c)
237+
if full_match is not None:
238+
extracted_responses.append(full_match.group(1))
239+
else:
240+
fallback_matches = answer_fallback.findall(c)
241+
extracted_responses.append(fallback_matches[-1].strip() if fallback_matches else None)
233242

234243
scores = []
235244
for guess, true_answer in zip(extracted_responses, answer):
@@ -408,7 +417,16 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
408417

409418
# Extract full answer content from solution tags (not just first number)
410419
match_format = get_match_format_regex(tmvp_config)
411-
extracted_responses = [guess.group(1) if (guess := match_format.search(c)) is not None else None for c in completions]
420+
answer_fallback = get_answer_fallback_regex(tmvp_config)
421+
422+
extracted_responses = []
423+
for c in completions:
424+
full_match = match_format.search(c)
425+
if full_match is not None:
426+
extracted_responses.append(full_match.group(1))
427+
else:
428+
fallback_matches = answer_fallback.findall(c)
429+
extracted_responses.append(fallback_matches[-1].strip() if fallback_matches else None)
412430

413431
scores = []
414432
if tmvp_config.debug.rl:

tests/post_training/unit/rl_utils_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,12 +236,12 @@ def test_extraction_fails_no_tags(self):
236236

237237
@pytest.mark.cpu_only
238238
def test_extraction_fails_answer_tags_only(self):
239-
"""<answer> tag alone (no <reasoning> block) is not matched by the regex, score 0."""
239+
"""<answer> tag alone (no <reasoning> block) is matched by the regex as a fallback, score 1.5."""
240240
scores = self._check(
241241
completions=["<answer>42</answer>"],
242242
answer=["42"],
243243
)
244-
self.assertEqual(scores[0], 0)
244+
self.assertEqual(scores[0], 1.5)
245245

246246
@pytest.mark.cpu_only
247247
def test_extraction_fails_reasoning_tags_only(self):

0 commit comments

Comments
 (0)