Skip to content

Commit 05c5083

Browse files
Merge pull request #3365 from AI-Hypercomputer:rl-utils-tests
PiperOrigin-RevId: 881558474
2 parents f2d2ec8 + f11d00d commit 05c5083

1 file changed

Lines changed: 22 additions & 0 deletions

File tree

tests/unit/rl_utils_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
reason="tunix (required by evaluate_rl) is not installed GPU",
2424
)
2525

26+
utils_rl = pytest.importorskip(
27+
"maxtext.trainers.post_train.rl.utils_rl",
28+
reason="tunix (required by utils_rl) is not installed GPU",
29+
)
30+
2631

2732
def _make_config():
2833
"""Create a minimal config object with the parameters required by score_responses."""
@@ -102,5 +107,22 @@ def test_with_incomplete_reasoning_tags(self):
102107
self.assertFalse(has_correct_format)
103108

104109

110+
class TestExtractHashAnswer(unittest.TestCase):
111+
"""Tests for utils_rl.extract_hash_answer."""
112+
113+
@pytest.mark.cpu_only
114+
def test_with_hash(self):
115+
"""Test extraction when #### is present."""
116+
self.assertEqual(utils_rl.extract_hash_answer("The answer is #### 42"), "42")
117+
self.assertEqual(utils_rl.extract_hash_answer("Some reasoning #### 123.45 "), "123.45")
118+
self.assertEqual(utils_rl.extract_hash_answer("####"), "")
119+
120+
@pytest.mark.cpu_only
121+
def test_without_hash(self):
122+
"""Test extraction when #### is not present."""
123+
self.assertIsNone(utils_rl.extract_hash_answer("The answer is 42"))
124+
self.assertIsNone(utils_rl.extract_hash_answer(""))
125+
126+
105127
if __name__ == "__main__":
106128
unittest.main()

0 commit comments

Comments
 (0)