Skip to content

Commit e99b7ec

Browse files
committed
Move RL code from src/MaxText/rl/ to src/maxtext/trainers/post_train/rl/
1 parent 2b06b9c commit e99b7ec

14 files changed

Lines changed: 1260 additions & 1158 deletions

File tree

.github/workflows/run_jupyter_notebooks.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ jobs:
9898
9999
for notebook in "$MAXTEXT_NOTEBOOKS_ROOT"/{sft,rl}*.ipynb; do
100100
filename=$(basename "$notebook")
101+
if [[ "$filename" == "sft_qwen3_demo.ipynb" ]]; then
102+
echo "Skipping $filename"
103+
continue
104+
fi
101105
output_name="${filename%.ipynb}_output.ipynb"
102106
103107
echo "------------------------------------------------------"

codecov.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ ignore:
4040
- "src/maxtext/scratch_code"
4141
- "src/MaxText/distillation" # code moved to src/maxtext/trainers/post_train/distillation
4242
- "src/MaxText/sft" # code moved to src/maxtext/trainers/post_train/sft
43+
- "src/MaxText/rl" # code moved to src/maxtext/trainers/post_train/rl
4344

4445

4546
flags:

docs/tutorials/posttraining/rl.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucke
153153
Run the following command for GRPO:
154154

155155
```
156-
python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \
156+
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
157157
model_name=${MODEL} \
158158
tokenizer_path=${TOKENIZER} \
159159
load_parameters_path=${MAXTEXT_CKPT_PATH} \
@@ -176,7 +176,7 @@ The overview of what this run will do is as follows:
176176
Run the following command for GSPO:
177177

178178
```
179-
python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \
179+
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
180180
model_name=${MODEL} \
181181
tokenizer_path=${TOKENIZER} \
182182
load_parameters_path=${MAXTEXT_CKPT_PATH} \

docs/tutorials/posttraining/rl_on_multi_host.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ xpk workload create-pathways --workload $WORKLOAD \
208208
--tpu-type=$TPU_TYPE --num-slices=1 \
209209
--project=$PROJECT_ID --priority=high \
210210
--command "HF_TOKEN=${HF_TOKEN} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
211-
python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \
211+
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
212212
model_name=${MODEL} \
213213
tokenizer_path=${TOKENIZER} \
214214
load_parameters_path=${MAXTEXT_CKPT_PATH} \
@@ -225,7 +225,7 @@ xpk workload create-pathways --workload $WORKLOAD \
225225
--tpu-type=$TPU_TYPE --num-slices=1 \
226226
--project=$PROJECT_ID --priority=high \
227227
--command "HF_TOKEN=${HF_TOKEN} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
228-
python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \
228+
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
229229
model_name=${MODEL} \
230230
tokenizer_path=${TOKENIZER} \
231231
load_parameters_path=${MAXTEXT_CKPT_PATH} \

src/MaxText/rl/evaluate_rl.py

Lines changed: 10 additions & 214 deletions
Original file line numberDiff line numberDiff line change
@@ -12,222 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# pylint: disable=bare-except, consider-using-generator
16-
"""
17-
RL Evaluation Module.
18-
"""
19-
from tqdm.auto import tqdm
20-
from tunix.rl.rollout.base_rollout import RolloutConfig
15+
"""Shim for RL Evaluation in `src/maxtext/trainers/post_train/rl`."""
2116

22-
from MaxText.rl import utils_rl
23-
from maxtext.utils import max_logging
24-
25-
# ## Evaluate
26-
# We evaluate it in two ways:
27-
#
28-
# **Quantitative**
29-
#
30-
# * **Answer Accuracy**: percentage of samples for which the model predicts the
31-
# correct final numerical answer
32-
# * **Answer (Partial) Accuracy**: percentage of samples for which the model
33-
# predicts a final numerical answer such that the \`model answer / answer\`
34-
# ratio lies between 0.9 and 1.1.
35-
# * **Format Accuracy**: percentage of samples for which the model outputs the
36-
# correct format, i.e., reasoning between the reasoning special tokens, and the
37-
# final answer between the \`\<start\_answer\>\`, \`\<end\_answer\>\` tokens.
38-
#
39-
# **Qualitative**
40-
#
41-
# We'll also print outputs for a few given questions so that we can compare the generated output later.
42-
#
43-
# pylint: disable=broad-exception-caught
44-
45-
46-
def generate_responses(
47-
tmvp_config,
48-
prompts,
49-
rl_cluster,
50-
num_passes=1,
51-
):
52-
"""
53-
Generate responses for a batch of prompts across potentially multiple passes.
54-
55-
Args:
56-
tmvp_config: Configuration object
57-
prompts: List of prompts to generate responses for
58-
rl_cluster: Model cluster for generation
59-
num_passes: Number of generation passes
60-
61-
Returns:
62-
List of lists containing responses for each prompt across passes
63-
"""
64-
multiple_call_responses = [[] for _ in range(len(prompts))]
65-
eval_strategy = tmvp_config.generation_configs[tmvp_config.eval_sampling_strategy]
66-
67-
for p in range(num_passes):
68-
responses = rl_cluster.rollout.generate(
69-
prompts,
70-
rollout_config=RolloutConfig(
71-
max_tokens_to_generate=tmvp_config.max_target_length - tmvp_config.max_prefill_predict_length,
72-
temperature=eval_strategy["eval_temperature"],
73-
top_k=eval_strategy["eval_top_k"],
74-
top_p=eval_strategy["eval_top_p"],
75-
),
76-
)
77-
responses = responses.text
78-
79-
if tmvp_config.debug.rl:
80-
max_logging.log(f"Pass {p+1}/{num_passes}, responses: {responses}")
81-
82-
for idx, response in enumerate(responses):
83-
multiple_call_responses[idx].append(response)
84-
85-
return multiple_call_responses
86-
87-
88-
def score_responses(tmvp_config, question, responses, answer):
89-
"""
90-
Score a set of responses for a single question.
91-
92-
Args:
93-
tmvp_config: Configuration object
94-
question: The evaluation question
95-
responses: List of generated responses for this question
96-
answer: The correct answer
97-
98-
Returns:
99-
Tuple of (is_correct, is_partially_correct, has_correct_format)
100-
"""
101-
match_format = utils_rl.get_match_format_regex(tmvp_config)
102-
match_numbers = utils_rl.get_match_numbers_regex(tmvp_config)
103-
104-
if tmvp_config.debug.rl:
105-
max_logging.log("========================================")
106-
max_logging.log(f"Evaluation Question: {question}")
107-
max_logging.log(f"Evaluation Answer: {answer}")
108-
max_logging.log(f"Evaluation Responses: {responses}")
109-
max_logging.log("========================================")
17+
import importlib
11018

111-
is_correct = False
112-
is_partially_correct = False
113-
has_correct_format = False
114-
115-
for response in responses:
116-
# Extract numerical response
117-
extracted_response = guess.group(1) if (guess := match_numbers.search(response)) is not None else "-1000000"
118-
119-
if tmvp_config.debug.rl:
120-
max_logging.log(f"Evaluation extracted_response: {extracted_response}")
121-
122-
# Check exact correctness
123-
try:
124-
# Remove ',' and '$' then convert to float
125-
val_extracted = float(extracted_response.replace(",", "").replace("$", "").strip())
126-
val_answer = float(answer.replace(",", "").replace("$", "").strip())
127-
is_correct = val_extracted == val_answer
128-
129-
# Check partial correctness (within 10%)
130-
ratio = val_extracted / val_answer
131-
if 0.9 <= ratio <= 1.1:
132-
is_partially_correct = True
133-
134-
except Exception as e:
135-
if tmvp_config.debug.rl:
136-
max_logging.log(f"Evaluation Exception: {e}")
137-
max_logging.log("SKIPPED")
138-
139-
# Check format correctness
140-
if match_format.search(response) is not None:
141-
has_correct_format = True
142-
143-
# Early exit if all criteria are met
144-
if is_correct and is_partially_correct and has_correct_format:
145-
break
146-
147-
return is_correct, is_partially_correct, has_correct_format
148-
149-
150-
def evaluate(
151-
tmvp_config,
152-
dataset,
153-
rl_cluster,
154-
num_passes=1,
155-
corr_lst=False,
156-
make_lst=False,
157-
):
158-
"""
159-
Computes accuracy and percentage of outputs matching the format.
160-
161-
Args:
162-
tmvp_config: Configuration object
163-
dataset: The evaluation dataset
164-
rl_cluster: Model cluster for generation.
165-
num_passes: Number of generation passes
166-
corr_lst: If True, only include correct responses in the list
167-
make_lst: If True, return a list of (question, answer, responses)
168-
169-
Returns:
170-
Tuple of statistics and optionally the response list
171-
"""
172-
response_lst = []
173-
corr = 0
174-
partially_corr = 0
175-
corr_format = 0
176-
total = 0
177-
178-
for batch in tqdm(dataset):
179-
answers = batch["answer"]
180-
questions = batch["question"]
181-
prompts = batch["prompts"]
182-
183-
# Generate responses for all prompts in the batch
184-
multiple_call_responses = generate_responses(
185-
tmvp_config=tmvp_config,
186-
prompts=prompts,
187-
rl_cluster=rl_cluster,
188-
num_passes=num_passes,
189-
)
190-
191-
# Score each question-answer pair
192-
for question, responses, answer in zip(questions, multiple_call_responses, answers):
193-
is_correct, is_partially_correct, has_correct_format = score_responses(
194-
tmvp_config=tmvp_config,
195-
question=question,
196-
responses=responses,
197-
answer=answer,
198-
)
199-
200-
# Update counters
201-
if is_correct:
202-
corr += 1
203-
if corr_lst and make_lst:
204-
response_lst.append((question, answer, responses))
205-
else:
206-
if not corr_lst and make_lst:
207-
response_lst.append((question, answer, responses))
208-
209-
if is_partially_correct:
210-
partially_corr += 1
211-
212-
if has_correct_format:
213-
corr_format += 1
214-
215-
total += 1
19+
from maxtext.utils import max_logging
21620

217-
# Print progress every 10 items
218-
if total % 10 == 0:
219-
max_logging.log(
220-
f"===> {corr=}, {total=}, {corr / total * 100=}, "
221-
f"{partially_corr / total * 100=}, {corr_format / total * 100=}"
222-
)
21+
OLD_MODULE_PATH = "MaxText.rl.evaluate_rl"
22+
NEW_MODULE_PATH = "maxtext.trainers.post_train.rl.evaluate_rl"
22323

224-
# Prepare return values
225-
to_return = (
226-
corr,
227-
total,
228-
corr / total * 100,
229-
partially_corr / total * 100,
230-
corr_format / total * 100,
231-
)
24+
max_logging.warning(f"'{OLD_MODULE_PATH}' is deprecated; use '{NEW_MODULE_PATH}' instead.\n")
25+
_new_module = importlib.import_module(NEW_MODULE_PATH)
23226

233-
return to_return, response_lst
27+
evaluate = _new_module.evaluate
28+
generate_responses = _new_module.generate_responses
29+
score_responses = _new_module.score_responses

0 commit comments

Comments (0)