Skip to content

Commit c9ffb30

Browse files
SurbhiJainUSCGoogle-ML-Automation
authored and committed
Fix chat template pathing in RL config and template loading
PiperOrigin-RevId: 886940444
1 parent 785ac61 commit c9ffb30

3 files changed

Lines changed: 20 additions & 1 deletion

File tree

src/maxtext/configs/post_train/rl.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ reasoning_start_token: '<reasoning>'
177177
reasoning_end_token: '</reasoning>'
178178
solution_start_token: '<answer>'
179179
solution_end_token: '</answer>'
180-
chat_template_path: 'src/maxtext/examples/chat_templates/gsm8k_rl.json'
180+
chat_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json'
181181
skip_jax_distributed_system: True
182182

183183
# # TODO(@mazumdera): fix this

src/maxtext/input_pipeline/instruction_data_processing.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
def load_template_from_file(template_path):
2525
"""Loads a template from a file."""
2626
template_config = None
27+
current_dir = os.path.dirname(os.path.abspath(__file__))
28+
repo_root = os.path.abspath(os.path.join(current_dir, "..", ".."))
29+
template_path = os.path.join(repo_root, template_path)
2730
if os.path.isfile(template_path) and template_path.endswith(".json"):
2831
with open(template_path, encoding="utf-8") as f:
2932
template_config = json.load(f)

tests/unit/instruction_data_processing_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,22 @@
2121

2222
class InstructionDataProcessingTest(unittest.TestCase):
2323

24+
def test_load_template_from_file(self):
    """Checks that the gsm8k_rl.json chat template loads into the expected dict.

    Uses a repo-relative path; load_template_from_file resolves it against the
    repository root internally.
    """
    # Build the expected template config up front so the assertion reads cleanly.
    expected = {
        "SYSTEM_PROMPT": (
            "You are given a problem. Think about the problem and provide"
            " your reasoning. Place it between {reasoning_start_token} and"
            " {reasoning_end_token}. Then, provide the final answer (i.e.,"
            " just one numerical value) between {solution_start_token} and"
            " {solution_end_token}."
        ),
        "TEMPLATE": "<start_of_turn>user\n{system_prompt}\n\n{question}<end_of_turn>\n<start_of_turn>model",
    }
    loaded = instruction_data_processing.load_template_from_file("maxtext/examples/chat_templates/gsm8k_rl.json")
    self.assertEqual(loaded, expected)
39+
2440
def test_map_qa_data_to_conversation_with_prompt_completion_template(self):
2541
template_config = {
2642
"PROMPT_TEMPLATE": "This is a question: {question}",

0 commit comments

Comments (0)