Skip to content

Commit 5c55054

Browse files
committed
Use pydantic natively in RL ; nest GRPO ; Fix rl_llama3_demo.ipynb to demo rl
1 parent c32eb92 commit 5c55054

8 files changed

Lines changed: 118 additions & 87 deletions

File tree

src/MaxText/configs/rl.yml

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,27 @@ rollout_tensor_parallelism: -1
3131
# ====== Reproducibility ======
3232
data_shuffle_seed: 42
3333

34-
# ====== GRPO ======
35-
36-
# The number of times the policy generates multiple responses for a given prompt
37-
# within a single training step. This corresponds to `G` in Algorithm 1 in the
38-
# paper. The "group" in GRPO comes from here.
39-
num_generations: 2
40-
41-
# === other GRPO configs ===
42-
# The number of iterations per batch (𝜇 in GRPO algo 1).
43-
num_iterations: 1
44-
45-
# The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function.
46-
# Important to keep a high enough value for this, otherwise, the KL divergence
47-
# can increase unchecked.
48-
grpo_beta: 0.08
49-
# Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, for
50-
# stable updates.
51-
grpo_epsilon: 0.2
52-
loss_algo: 'grpo' # grpo or gspo-token
34+
# ====== RL ======
35+
# This config includes RL algorithm variations such as grpo or gspo-token
36+
rl:
37+
# ====== GRPO/GSPO-Token ======
38+
# The number of responses the policy generates for a given prompt
39+
# within a single training step. This corresponds to `G` in Algorithm 1 in the
40+
# paper. The "group" in GRPO comes from here.
41+
num_generations: 2
42+
43+
# === other GRPO configs ===
44+
# The number of iterations per batch (𝜇 in GRPO algo 1).
45+
num_iterations: 1
46+
47+
# The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function.
48+
# Important to keep a high enough value for this, otherwise, the KL divergence
49+
# can increase unchecked.
50+
grpo_beta: 0.08
51+
# Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, for
52+
# stable updates.
53+
grpo_epsilon: 0.2
54+
loss_algo: 'grpo' # grpo or gspo-token
5355

5456

5557
# ====== Models ======

src/MaxText/configs/types.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,14 +1398,14 @@ class VLLM(BaseModel):
13981398
vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
13991399

14001400

1401-
class GRPO(BaseModel):
1402-
"""Configuration for Group Relative Policy Optimization (GRPO)."""
1401+
class RL(BaseModel):
1402+
"""Configuration for RL algorithms like Group Relative Policy Optimization (GRPO) among others."""
14031403

14041404
num_generations: int = Field(2, description="Number of responses to generate per prompt (G in GRPO paper).")
14051405
num_iterations: int = Field(1, description="Number of iterations per batch (μ in GRPO paper).")
14061406
grpo_beta: float = Field(0.08, description="Coefficient for the KL divergence penalty (β).")
14071407
grpo_epsilon: float = Field(0.2, description="Epsilon value for clipping in the GRPO loss.")
1408-
loss_algo: str = Field("grpo", description="Loss algorithm, e.g., 'grpo' or 'gspo-token'.")
1408+
loss_algo: Literal["grpo", "gspo-token"] = Field("grpo", description="Loss algorithm, i.e., 'grpo' or 'gspo-token'.")
14091409

14101410

14111411
class RLDataset(BaseModel):
@@ -1639,7 +1639,6 @@ class MaxTextConfig(
16391639
# Reinforcement Learning
16401640
RLHardware,
16411641
VLLM,
1642-
GRPO,
16431642
RLDataset,
16441643
RLEvaluation,
16451644
Reward,
@@ -1689,6 +1688,9 @@ class MaxTextConfig(
16891688
"""
16901689

16911690
debug: Debug = Field(default_factory=Debug, description="Configuration for debugging options.")
1691+
rl: RL = Field(
1692+
default_factory=RL, description="Configuration for RL algorithms like Group Relative Policy Optimization (GRPO)."
1693+
)
16921694
model_config = ConfigDict(extra="forbid", protected_namespaces=())
16931695

16941696
@model_validator(mode="before")
@@ -2134,7 +2136,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
21342136
raise ValueError("`eval_steps` must be > 0 when `generate_padding_batch_eval` is True.")
21352137
if self.dataset_type == "hf" and self.num_epoch != 1:
21362138
raise ValueError("HuggingFace pipeline only supports num_epoch=1.")
2137-
if self.loss_algo == "grpo":
2139+
if self.rl.loss_algo == "grpo":
21382140
self.use_grpo = True
21392141
else:
21402142
self.use_grpo = False

src/MaxText/examples/rl_llama3_demo.ipynb

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@
107107
"from pathlib import Path\n",
108108
"import MaxText\n",
109109
"from huggingface_hub import login\n",
110-
"import jax\n",
111110
"\n",
112111
"# Set up paths (adjust if needed)\n",
113112
"MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n",
@@ -127,26 +126,21 @@
127126
" raise RuntimeError(\"OUTPUT_DIRECTORY is not set\")\n",
128127
" \n",
129128
"os.environ[\"HF_TOKEN\"] = HF_TOKEN\n",
129+
"if \"MAXTEXT_PKG_DIR\" not in os.environ:\n",
130+
" os.environ[\"MAXTEXT_PKG_DIR\"] = MAXTEXT_REPO_ROOT\n",
130131
"\n",
131132
"if HF_TOKEN:\n",
132133
" login(token=HF_TOKEN)\n",
133134
" print(\"Authenticated with Hugging Face\")\n",
134135
"else:\n",
135136
" print(\"Authentication failed: Hugging Face token not set\")\n",
136137
"\n",
137-
"# Optional: Override training parameters\n",
138-
"LEARNING_RATE = 3e-6\n",
139-
"NUM_GENERATIONS = 2\n",
140-
"GRPO_BETA = 0.08\n",
141-
"GRPO_EPSILON = 0.2\n",
142-
"CHIPS_PER_VM = 1\n",
143138
"\n",
144139
"print(f\"📁 MaxText Home: {MAXTEXT_REPO_ROOT}\")\n",
145140
"print(f\"🤖 Model: {MODEL_NAME}\")\n",
146141
"print(f\"📦 Checkpoint: {MODEL_CHECKPOINT_PATH}\")\n",
147142
"print(f\"💾 Output: {OUTPUT_DIRECTORY}\")\n",
148143
"print(f\"🔑 HF Token: {'✅ Set' if HF_TOKEN else '❌ Missing - set HF_TOKEN env var'}\")\n",
149-
"print(f\"📊 Steps: {STEPS}\")\n",
150144
"print(f\"Loss Algorithm : {LOSS_ALGO}\")"
151145
]
152146
},
@@ -178,10 +172,10 @@
178172
"outputs": [],
179173
"source": [
180174
"# Build configuration for GRPO training\n",
181-
"config_file = os.path.join(MAXTEXT_REPO_ROOT, \"configs/rl.yml\")\n",
175+
"config_file = os.path.join(MAXTEXT_REPO_ROOT, \"configs\", \"rl.yml\")\n",
182176
"\n",
183177
"# Verify chat template exists\n",
184-
"if not os.path.exists(CHAT_TEMPLATE_PATH)):\n",
178+
"if not os.path.exists(CHAT_TEMPLATE_PATH):\n",
185179
" raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n",
186180
"\n",
187181
"# Build argv list for pyconfig.initialize()\n",
@@ -195,23 +189,26 @@
195189
" f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n",
196190
" f\"base_output_directory={OUTPUT_DIRECTORY}\",\n",
197191
" f\"hf_access_token={HF_TOKEN}\",\n",
198-
" f\"learning_rate={LEARNING_RATE}\",\n",
199-
" f\"num_generations={NUM_GENERATIONS}\",\n",
200-
" f\"grpo_beta={GRPO_BETA}\",\n",
201-
" f\"grpo_epsilon={GRPO_EPSILON}\",\n",
202-
" f\"chips_per_vm={CHIPS_PER_VM}\",\n",
203-
" f\"loss_algo={LOSS_ALGO}\",\n",
192+
" f\"debug.rl=False\",\n",
193+
" f\"rl.loss_algo={LOSS_ALGO}\",\n",
204194
" \"use_pathways=False\"\n",
205195
"]\n",
206196
"\n",
207197
"# Initialize configuration\n",
208198
"print(f\"🔧 Initializing configuration from: {config_file}\")\n",
209-
"config = pyconfig.initialize(config_argv)\n",
199+
"trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(config_argv)\n",
200+
"\n",
201+
"rl_train_steps = int(\n",
202+
" trainer_config.num_batches\n",
203+
" * trainer_config.rl.num_iterations\n",
204+
" * trainer_config.train_fraction\n",
205+
" * trainer_config.num_epoch\n",
206+
" )\n",
210207
"\n",
211208
"print(\"\\n✅ Configuration initialized successfully\")\n",
212-
"print(f\"📊 Training steps: {config.steps}\")\n",
213-
"print(f\"📁 Output directory: {config.base_output_directory}\")\n",
214-
"print(f\"🤖 Model: {config.model_name}\")"
209+
"print(f\"📁 Output directory: {trainer_config.base_output_directory}\")\n",
210+
"print(f\"🤖 Model: {trainer_config.model_name}\")\n",
211+
"print(f\"📊 RL Train Steps: {rl_train_steps}\")"
215212
]
216213
},
217214
{
@@ -224,16 +221,16 @@
224221
"print(\"\\n\" + \"=\"*80)\n",
225222
"print(\"🚀 Starting Training...\")\n",
226223
"print(\"=\"*80)\n",
227-
"print(1)\n",
228224
"try:\n",
229225
" # Call the rl_train function (it handles everything internally)\n",
230-
" rl_train(config)\n",
226+
" rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)\n",
231227
" \n",
232228
" print(\"\\n\" + \"=\"*80)\n",
233229
" print(\"✅ Training Completed Successfully!\")\n",
230+
" print(f\"✍️ Note the improved evaluation accuracy metrics with just {rl_train_steps} RL training steps!\")\n",
234231
" print(\"=\"*80)\n",
235-
" print(f\"📁 Checkpoints saved to: {config.checkpoint_dir}\")\n",
236-
" print(f\"📊 TensorBoard logs: {config.tensorboard_dir}\")\n",
232+
" print(f\"📁 Checkpoints saved to: {trainer_config.checkpoint_dir}\")\n",
233+
" print(f\"📊 TensorBoard logs: {trainer_config.tensorboard_dir}\")\n",
237234
" print(f\"🎯 Model ready for inference!\")\n",
238235
" \n",
239236
"except Exception as e:\n",
@@ -264,7 +261,7 @@
264261
],
265262
"metadata": {
266263
"kernelspec": {
267-
"display_name": "Python 3",
264+
"display_name": "maxtext_venv",
268265
"language": "python",
269266
"name": "python3"
270267
},
@@ -278,7 +275,7 @@
278275
"name": "python",
279276
"nbconvert_exporter": "python",
280277
"pygments_lexer": "ipython3",
281-
"version": "3.8.5"
278+
"version": "3.12.11"
282279
}
283280
},
284281
"nbformat": 4,

src/MaxText/max_logging.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Logging utilities."""
16+
import logging as std_logging
1617
from absl import logging
1718

1819

@@ -40,3 +41,20 @@ def warning(user_str):
4041
def error(user_str):
4142
"""Logs a message at the ERROR level."""
4243
logging.error(user_str, stacklevel=2)
44+
45+
46+
# Define filter at module level to avoid pickling issues and ensure visibility
47+
class NoisyLogFilter(std_logging.Filter):
  """Logging filter that drops records whose message contains a known-noisy fragment.

  A record is suppressed when its message contains any of the substrings
  listed in `NOISY_PATTERNS`; every other record passes through unchanged.
  """

  # Substrings that mark a record as noise.
  NOISY_PATTERNS = (
      "Type mismatch on",  # "Type mismatch" warnings from tunix/generate/utils.py
      "No mapping for flat state",
  )

  def filter(self, record):
    """Decide whether `record` should be emitted.

    Args:
      record: The `logging.LogRecord` under consideration.

    Returns:
      False if the message contains one of `NOISY_PATTERNS` (record is
      dropped), True otherwise (record is kept).
    """
    try:
      # Formatted message, i.e. `record.msg % record.args`.
      msg = record.getMessage()
    except (TypeError, ValueError, KeyError):
      # The format string and its args do not match. A filter must not raise
      # into the code that issued the log call, so match against the raw
      # message and let the handler report the formatting problem as usual.
      msg = str(record.msg)
    return not any(pattern in msg for pattern in self.NOISY_PATTERNS)

src/MaxText/rl/evaluate_rl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def generate_responses(
7676
)
7777
responses = responses.text
7878

79-
if tmvp_config.debug["rl"]:
79+
if tmvp_config.debug.rl:
8080
max_logging.log(f"Pass {p+1}/{num_passes}, responses: {responses}")
8181

8282
for idx, response in enumerate(responses):
@@ -101,7 +101,7 @@ def score_responses(tmvp_config, question, responses, answer):
101101
match_format = utils_rl.get_match_format_regex(tmvp_config)
102102
match_numbers = utils_rl.get_match_numbers_regex(tmvp_config)
103103

104-
if tmvp_config.debug["rl"]:
104+
if tmvp_config.debug.rl:
105105
max_logging.log("========================================")
106106
max_logging.log(f"Evaluation Question: {question}")
107107
max_logging.log(f"Evaluation Answer: {answer}")
@@ -116,7 +116,7 @@ def score_responses(tmvp_config, question, responses, answer):
116116
# Extract numerical response
117117
extracted_response = guess.group(1) if (guess := match_numbers.search(response)) is not None else "-1000000"
118118

119-
if tmvp_config.debug["rl"]:
119+
if tmvp_config.debug.rl:
120120
max_logging.log(f"Evaluation extracted_response: {extracted_response}")
121121

122122
# Check exact correctness
@@ -132,7 +132,7 @@ def score_responses(tmvp_config, question, responses, answer):
132132
is_partially_correct = True
133133

134134
except Exception as e:
135-
if tmvp_config.debug["rl"]:
135+
if tmvp_config.debug.rl:
136136
max_logging.log(f"Evaluation Exception: {e}")
137137
max_logging.log("SKIPPED")
138138

0 commit comments

Comments
 (0)