12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
15 | | -# pylint: disable=bare-except, consider-using-generator |
16 | | -""" |
17 | | -RL Evaluation Module. |
18 | | -""" |
19 | | -from tqdm.auto import tqdm |
20 | | -from tunix.rl.rollout.base_rollout import RolloutConfig |
| 15 | +"""Shim for RL Evaluation in `src/maxtext/trainers/post_train/rl`.""" |
21 | 16 |
22 | | -from MaxText.rl import utils_rl |
23 | | -from maxtext.utils import max_logging |
24 | | - |
25 | | -# ## Evaluate |
26 | | -# We evaluate it in two ways: |
27 | | -# |
28 | | -# **Quantitative** |
29 | | -# |
30 | | -# * **Answer Accuracy**: percentage of samples for which the model predicts the |
31 | | -# correct final numerical answer |
32 | | -# * **Answer (Partial) Accuracy**: percentage of samples for which the model |
33 | | -# predicts a final numerical answer such that the `model answer / answer`
34 | | -# ratio lies between 0.9 and 1.1. |
35 | | -# * **Format Accuracy**: percentage of samples for which the model outputs the |
36 | | -# correct format, i.e., reasoning between the reasoning special tokens, and the |
37 | | -# final answer between the `<start_answer>`, `<end_answer>` tokens.
38 | | -# |
39 | | -# **Qualitative** |
40 | | -# |
41 | | -# We'll also print outputs for a few given questions so that we can compare the generated output later. |
42 | | -# |
43 | | -# pylint: disable=broad-exception-caught |
44 | | - |
45 | | - |
46 | | -def generate_responses( |
47 | | - tmvp_config, |
48 | | - prompts, |
49 | | - rl_cluster, |
50 | | - num_passes=1, |
51 | | -): |
52 | | - """ |
53 | | - Generate responses for a batch of prompts across potentially multiple passes. |
54 | | -
55 | | - Args: |
56 | | - tmvp_config: Configuration object |
57 | | - prompts: List of prompts to generate responses for |
58 | | - rl_cluster: Model cluster for generation |
59 | | - num_passes: Number of generation passes |
60 | | -
61 | | - Returns: |
62 | | - List of lists containing responses for each prompt across passes |
63 | | - """ |
64 | | - multiple_call_responses = [[] for _ in range(len(prompts))] |
65 | | - eval_strategy = tmvp_config.generation_configs[tmvp_config.eval_sampling_strategy] |
66 | | - |
67 | | - for p in range(num_passes): |
68 | | - responses = rl_cluster.rollout.generate( |
69 | | - prompts, |
70 | | - rollout_config=RolloutConfig( |
71 | | - max_tokens_to_generate=tmvp_config.max_target_length - tmvp_config.max_prefill_predict_length, |
72 | | - temperature=eval_strategy["eval_temperature"], |
73 | | - top_k=eval_strategy["eval_top_k"], |
74 | | - top_p=eval_strategy["eval_top_p"], |
75 | | - ), |
76 | | - ) |
77 | | - responses = responses.text |
78 | | - |
79 | | - if tmvp_config.debug.rl: |
80 | | - max_logging.log(f"Pass {p+1}/{num_passes}, responses: {responses}") |
81 | | - |
82 | | - for idx, response in enumerate(responses): |
83 | | - multiple_call_responses[idx].append(response) |
84 | | - |
85 | | - return multiple_call_responses |
86 | | - |
87 | | - |
88 | | -def score_responses(tmvp_config, question, responses, answer): |
89 | | - """ |
90 | | - Score a set of responses for a single question. |
91 | | -
92 | | - Args: |
93 | | - tmvp_config: Configuration object |
94 | | - question: The evaluation question |
95 | | - responses: List of generated responses for this question |
96 | | - answer: The correct answer |
97 | | -
98 | | - Returns: |
99 | | - Tuple of (is_correct, is_partially_correct, has_correct_format) |
100 | | - """ |
101 | | - match_format = utils_rl.get_match_format_regex(tmvp_config) |
102 | | - match_numbers = utils_rl.get_match_numbers_regex(tmvp_config) |
103 | | - |
104 | | - if tmvp_config.debug.rl: |
105 | | - max_logging.log("========================================") |
106 | | - max_logging.log(f"Evaluation Question: {question}") |
107 | | - max_logging.log(f"Evaluation Answer: {answer}") |
108 | | - max_logging.log(f"Evaluation Responses: {responses}") |
109 | | - max_logging.log("========================================") |
| 17 | +import importlib |
110 | 18 |
111 | | - is_correct = False |
112 | | - is_partially_correct = False |
113 | | - has_correct_format = False |
114 | | - |
115 | | - for response in responses: |
116 | | - # Extract numerical response |
117 | | - extracted_response = guess.group(1) if (guess := match_numbers.search(response)) is not None else "-1000000" |
118 | | - |
119 | | - if tmvp_config.debug.rl: |
120 | | - max_logging.log(f"Evaluation extracted_response: {extracted_response}") |
121 | | - |
122 | | - # Check exact correctness |
123 | | - try: |
124 | | - # Remove ',' and '$' then convert to float |
125 | | - val_extracted = float(extracted_response.replace(",", "").replace("$", "").strip()) |
126 | | - val_answer = float(answer.replace(",", "").replace("$", "").strip()) |
127 | | - is_correct = val_extracted == val_answer |
128 | | - |
129 | | - # Check partial correctness (within 10%) |
130 | | - ratio = val_extracted / val_answer |
131 | | - if 0.9 <= ratio <= 1.1: |
132 | | - is_partially_correct = True |
133 | | - |
134 | | - except Exception as e: |
135 | | - if tmvp_config.debug.rl: |
136 | | - max_logging.log(f"Evaluation Exception: {e}") |
137 | | - max_logging.log("SKIPPED") |
138 | | - |
139 | | - # Check format correctness |
140 | | - if match_format.search(response) is not None: |
141 | | - has_correct_format = True |
142 | | - |
143 | | - # Early exit if all criteria are met |
144 | | - if is_correct and is_partially_correct and has_correct_format: |
145 | | - break |
146 | | - |
147 | | - return is_correct, is_partially_correct, has_correct_format |
148 | | - |
149 | | - |
150 | | -def evaluate( |
151 | | - tmvp_config, |
152 | | - dataset, |
153 | | - rl_cluster, |
154 | | - num_passes=1, |
155 | | - corr_lst=False, |
156 | | - make_lst=False, |
157 | | -): |
158 | | - """ |
159 | | - Computes accuracy and percentage of outputs matching the format. |
160 | | -
161 | | - Args: |
162 | | - tmvp_config: Configuration object |
163 | | - dataset: The evaluation dataset |
164 | | - rl_cluster: Model cluster for generation. |
165 | | - num_passes: Number of generation passes |
166 | | - corr_lst: If True, only include correct responses in the list |
167 | | - make_lst: If True, return a list of (question, answer, responses) |
168 | | -
169 | | - Returns: |
170 | | - Tuple of statistics and optionally the response list |
171 | | - """ |
172 | | - response_lst = [] |
173 | | - corr = 0 |
174 | | - partially_corr = 0 |
175 | | - corr_format = 0 |
176 | | - total = 0 |
177 | | - |
178 | | - for batch in tqdm(dataset): |
179 | | - answers = batch["answer"] |
180 | | - questions = batch["question"] |
181 | | - prompts = batch["prompts"] |
182 | | - |
183 | | - # Generate responses for all prompts in the batch |
184 | | - multiple_call_responses = generate_responses( |
185 | | - tmvp_config=tmvp_config, |
186 | | - prompts=prompts, |
187 | | - rl_cluster=rl_cluster, |
188 | | - num_passes=num_passes, |
189 | | - ) |
190 | | - |
191 | | - # Score each question-answer pair |
192 | | - for question, responses, answer in zip(questions, multiple_call_responses, answers): |
193 | | - is_correct, is_partially_correct, has_correct_format = score_responses( |
194 | | - tmvp_config=tmvp_config, |
195 | | - question=question, |
196 | | - responses=responses, |
197 | | - answer=answer, |
198 | | - ) |
199 | | - |
200 | | - # Update counters |
201 | | - if is_correct: |
202 | | - corr += 1 |
203 | | - if corr_lst and make_lst: |
204 | | - response_lst.append((question, answer, responses)) |
205 | | - else: |
206 | | - if not corr_lst and make_lst: |
207 | | - response_lst.append((question, answer, responses)) |
208 | | - |
209 | | - if is_partially_correct: |
210 | | - partially_corr += 1 |
211 | | - |
212 | | - if has_correct_format: |
213 | | - corr_format += 1 |
214 | | - |
215 | | - total += 1 |
| 19 | +from maxtext.utils import max_logging |
216 | 20 |
217 | | - # Print progress every 10 items |
218 | | - if total % 10 == 0: |
219 | | - max_logging.log( |
220 | | - f"===> {corr=}, {total=}, {corr / total * 100=}, " |
221 | | - f"{partially_corr / total * 100=}, {corr_format / total * 100=}" |
222 | | - ) |
| 21 | +OLD_MODULE_PATH = "MaxText.rl.evaluate_rl" |
| 22 | +NEW_MODULE_PATH = "maxtext.trainers.post_train.rl.evaluate_rl" |
223 | 23 |
224 | | - # Prepare return values |
225 | | - to_return = ( |
226 | | - corr, |
227 | | - total, |
228 | | - corr / total * 100, |
229 | | - partially_corr / total * 100, |
230 | | - corr_format / total * 100, |
231 | | - ) |
| 24 | +max_logging.warning(f"'{OLD_MODULE_PATH}' is deprecated; use '{NEW_MODULE_PATH}' instead.\n") |
| 25 | +_new_module = importlib.import_module(NEW_MODULE_PATH) |
232 | 26 |
233 | | - return to_return, response_lst |
| 27 | +evaluate = _new_module.evaluate |
| 28 | +generate_responses = _new_module.generate_responses |
| 29 | +score_responses = _new_module.score_responses |
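
Since the replacement file is a pure re-export shim, existing call sites keep working and simply log the deprecation warning. A minimal usage sketch under that assumption (the script below is illustrative, not part of the commit):

```python
# Hypothetical downstream script: with the shim in place, both import paths
# resolve to the same function object, because the shim re-exports rather
# than copies. Importing the old path logs the deprecation warning once,
# at first import (the module is cached afterwards).
from MaxText.rl.evaluate_rl import evaluate as evaluate_old
from maxtext.trainers.post_train.rl.evaluate_rl import evaluate as evaluate_new

assert evaluate_old is evaluate_new
```

Migrating a caller is then a one-line import change, with identical behavior either way.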
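A design note for comparison: the shim imports eagerly at module load, so the warning fires as soon as anything touches the old path. An alternative (not what this commit does) is a lazy shim via PEP 562 module-level `__getattr__`, which defers both the import and the warning until an attribute is actually looked up:

```python
# Lazy-shim sketch (PEP 562); NOT the commit's approach. Python calls a
# module-level __getattr__ only when normal attribute lookup fails, so the
# new module is imported on first access (and cached in sys.modules).
import importlib

_NEW_MODULE_PATH = "maxtext.trainers.post_train.rl.evaluate_rl"

def __getattr__(name):
  # Resolve evaluate / generate_responses / score_responses on demand;
  # getattr raises AttributeError for names the new module lacks.
  return getattr(importlib.import_module(_NEW_MODULE_PATH), name)
```

The eager version in the diff is the simpler choice here: only three names are re-exported, and the deprecation warning is guaranteed to appear at import time.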