Skip to content

Commit 5f1717b

Browse files
Merge pull request #3180 from AI-Hypercomputer:anisha-rl-refactor
PiperOrigin-RevId: 872606777
2 parents c6f3bc2 + e99b7ec commit 5f1717b

15 files changed

Lines changed: 165 additions & 63 deletions

File tree

.github/workflows/run_jupyter_notebooks.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ jobs:
9898
9999
for notebook in "$MAXTEXT_NOTEBOOKS_ROOT"/{sft,rl}*.ipynb; do
100100
filename=$(basename "$notebook")
101+
if [[ "$filename" == "sft_qwen3_demo.ipynb" ]]; then
102+
echo "Skipping $filename"
103+
continue
104+
fi
101105
output_name="${filename%.ipynb}_output.ipynb"
102106
103107
echo "------------------------------------------------------"

codecov.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ ignore:
4040
- "src/maxtext/scratch_code"
4141
- "src/MaxText/distillation" # code moved to src/maxtext/trainers/post_train/distillation
4242
- "src/MaxText/sft" # code moved to src/maxtext/trainers/post_train/sft
43+
- "src/MaxText/rl" # code moved to src/maxtext/trainers/post_train/rl
4344

4445

4546
flags:

docs/tutorials/posttraining/rl.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucke
161161
Run the following command for GRPO:
162162

163163
```
164-
python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \
164+
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
165165
model_name=${MODEL} \
166166
tokenizer_path=${TOKENIZER} \
167167
load_parameters_path=${MAXTEXT_CKPT_PATH} \
@@ -185,7 +185,7 @@ The overview of what this run will do is as follows:
185185
Run the following command for GSPO:
186186

187187
```
188-
python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \
188+
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
189189
model_name=${MODEL} \
190190
tokenizer_path=${TOKENIZER} \
191191
load_parameters_path=${MAXTEXT_CKPT_PATH} \

docs/tutorials/posttraining/rl_on_multi_host.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ xpk workload create-pathways --workload $WORKLOAD \
208208
--tpu-type=$TPU_TYPE --num-slices=1 \
209209
--project=$PROJECT_ID --priority=high \
210210
--command "HF_TOKEN=${HF_TOKEN} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
211-
python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \
211+
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
212212
model_name=${MODEL} \
213213
tokenizer_path=${TOKENIZER} \
214214
load_parameters_path=${MAXTEXT_CKPT_PATH} \
@@ -225,7 +225,7 @@ xpk workload create-pathways --workload $WORKLOAD \
225225
--tpu-type=$TPU_TYPE --num-slices=1 \
226226
--project=$PROJECT_ID --priority=high \
227227
--command "HF_TOKEN=${HF_TOKEN} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \
228-
python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \
228+
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml \
229229
model_name=${MODEL} \
230230
tokenizer_path=${TOKENIZER} \
231231
load_parameters_path=${MAXTEXT_CKPT_PATH} \

src/maxtext/examples/rl_llama3_demo.ipynb

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@
9494
},
9595
{
9696
"cell_type": "code",
97-
"execution_count": 2,
97+
"execution_count": null,
9898
"metadata": {},
9999
"outputs": [],
100100
"source": [
@@ -148,13 +148,15 @@
148148
"from pathlib import Path\n",
149149
"import MaxText\n",
150150
"from huggingface_hub import login\n",
151+
"from etils import epath\n",
151152
"import jax\n",
152153
"\n",
153-
"from MaxText import max_utils\n",
154-
"from MaxText.rl.train_rl import rl_train, setup_configs_and_devices\n",
154+
"from maxtext.trainers.post_train.rl.train_rl import rl_train, setup_configs_and_devices\n",
155155
"\n",
156156
"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\n",
157157
"os.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n",
158+
"# Suppress vLLM logging with a severity level below ERROR\n",
159+
"os.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n",
158160
"\n",
159161
"MAXTEXT_PKG_DIR = os.path.dirname(MaxText.__file__)\n",
160162
"MAXTEXT_REPO_ROOT = os.sep.join([\"maxtext\" if p == \"MaxText\" else p for p in MAXTEXT_PKG_DIR.split(os.sep)])\n",
@@ -243,7 +245,7 @@
243245
"metadata": {},
244246
"outputs": [],
245247
"source": [
246-
"if not os.path.exists(MODEL_CHECKPOINT_PATH):\n",
248+
"if not epath.Path(MODEL_CHECKPOINT_PATH).exists():\n",
247249
" # install torch for the conversion script\n",
248250
" !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n",
249251
"\n",
@@ -256,8 +258,8 @@
256258
" scan_layers=true \\\n",
257259
" skip_jax_distributed_system=True\n",
258260
"\n",
259-
"if not os.path.exists(MODEL_CHECKPOINT_PATH):\n",
260-
" raise ValueError(\"Model checkpoint conversion failed. Check the logs above.\")"
261+
" if not epath.Path(MODEL_CHECKPOINT_PATH).exists():\n",
262+
" raise ValueError(\"Model checkpoint conversion failed. Check the logs above.\")"
261263
]
262264
},
263265
{
@@ -276,7 +278,7 @@
276278
"# Load configuration for RL training\n",
277279
"config_argv = [\n",
278280
" \"\",\n",
279-
" f\"{MAXTEXT_PKG_DIR}/configs/rl.yml\",\n",
281+
" f\"{MAXTEXT_PKG_DIR}/configs/post_train/rl.yml\",\n",
280282
" f\"model_name={MODEL_NAME}\",\n",
281283
" f\"tokenizer_path={TOKENIZER_PATH}\",\n",
282284
" f\"run_name={RUN_NAME}\",\n",
@@ -344,13 +346,13 @@
344346
"\n",
345347
"- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/rl.html\n",
346348
"- **Configuration**: See `src/maxtext/configs/rl.yml` for all available options\n",
347-
"- **Documentation**: Check `src/MaxText/rl/train_rl.py` for the `rl_train` function implementation"
349+
"- **Documentation**: Check `src/maxtext/trainers/post_train/rl/train_rl.py` for the `rl_train` function implementation"
348350
]
349351
}
350352
],
351353
"metadata": {
352354
"kernelspec": {
353-
"display_name": "Python 3",
355+
"display_name": "maxtext_venv",
354356
"language": "python",
355357
"name": "python3"
356358
},
@@ -364,7 +366,7 @@
364366
"name": "python",
365367
"nbconvert_exporter": "python",
366368
"pygments_lexer": "ipython3",
367-
"version": "3.12.11"
369+
"version": "3.12.12"
368370
}
369371
},
370372
"nbformat": 4,

src/maxtext/examples/sft_qwen3_demo.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@
377377
"config = pyconfig.initialize(\n",
378378
" [\n",
379379
" \"\",\n",
380-
" f\"{MAXTEXT_PKG_DIR}/configs/sft.yml\",\n",
380+
" f\"{MAXTEXT_PKG_DIR}/configs/post_train/sft.yml\",\n",
381381
" f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n",
382382
" f\"model_name={MODEL_NAME}\",\n",
383383
" f\"hf_access_token={HF_TOKEN}\",\n",

src/maxtext/examples/sft_train_and_evaluate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ def create_vllm_rollout(config, model, mesh, tokenizer):
301301
rollout_vllm_hbm_utilization=0.2,
302302
rollout_vllm_init_with_random_weights=True,
303303
rollout_vllm_tpu_backend_type="jax",
304+
data_type="bfloat16",
304305
),
305306
)
306307

src/maxtext/rl/evaluate_rl.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Shim for RL Evaluation in `src/maxtext/trainers/post_train/rl`."""
16+
17+
import importlib
18+
19+
from maxtext.utils import max_logging
20+
21+
OLD_MODULE_PATH = "MaxText.rl.evaluate_rl"
22+
NEW_MODULE_PATH = "maxtext.trainers.post_train.rl.evaluate_rl"
23+
24+
max_logging.warning(f"'{OLD_MODULE_PATH}' is deprecated; use '{NEW_MODULE_PATH}' instead.\n")
25+
_new_module = importlib.import_module(NEW_MODULE_PATH)
26+
27+
evaluate = _new_module.evaluate
28+
generate_responses = _new_module.generate_responses
29+
score_responses = _new_module.score_responses

src/maxtext/rl/train_rl.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Shim for RL Trainer in `src/maxtext/trainers/post_train/rl`."""
16+
17+
import sys
18+
import importlib
19+
20+
from absl import logging
21+
22+
from maxtext.utils import max_logging
23+
24+
OLD_MODULE_PATH = "MaxText.rl.train_rl"
25+
NEW_MODULE_PATH = "maxtext.trainers.post_train.rl.train_rl"
26+
27+
if __name__ == "__main__":
28+
try:
29+
logging.set_verbosity(logging.INFO)
30+
_new_module = importlib.import_module(NEW_MODULE_PATH)
31+
if hasattr(_new_module, "main"):
32+
max_logging.warning(f"'{OLD_MODULE_PATH}' is deprecated; use '{NEW_MODULE_PATH}' instead.\n")
33+
_new_module.main(sys.argv)
34+
except ImportError as e:
35+
max_logging.error(f"Shim could not find target module: '{NEW_MODULE_PATH}'\n")
36+
raise e

0 commit comments

Comments (0)