|
94 | 94 | }, |
95 | 95 | { |
96 | 96 | "cell_type": "code", |
97 | | - "execution_count": 2, |
| 97 | + "execution_count": null, |
98 | 98 | "metadata": {}, |
99 | 99 | "outputs": [], |
100 | 100 | "source": [ |
|
148 | 148 | "from pathlib import Path\n", |
149 | 149 | "import MaxText\n", |
150 | 150 | "from huggingface_hub import login\n", |
| 151 | + "from etils import epath\n", |
151 | 152 | "import jax\n", |
152 | 153 | "\n", |
153 | | - "from MaxText import max_utils\n", |
154 | | - "from MaxText.rl.train_rl import rl_train, setup_configs_and_devices\n", |
| 154 | + "from maxtext.trainers.post_train.rl.train_rl import rl_train, setup_configs_and_devices\n", |
155 | 155 | "\n", |
156 | 156 | "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\n", |
157 | 157 | "os.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n", |
| 158 | + "# Suppress vLLM logging with a severity level below ERROR\n", |
| 159 | + "os.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n", |
158 | 160 | "\n", |
159 | 161 | "MAXTEXT_PKG_DIR = os.path.dirname(MaxText.__file__)\n", |
160 | 162 | "MAXTEXT_REPO_ROOT = os.sep.join([\"maxtext\" if p == \"MaxText\" else p for p in MAXTEXT_PKG_DIR.split(os.sep)])\n", |
|
243 | 245 | "metadata": {}, |
244 | 246 | "outputs": [], |
245 | 247 | "source": [ |
246 | | - "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", |
| 248 | + "if not epath.Path(MODEL_CHECKPOINT_PATH).exists():\n", |
247 | 249 | " # install torch for the conversion script\n", |
248 | 250 | " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n", |
249 | 251 | "\n", |
|
256 | 258 | " scan_layers=true \\\n", |
257 | 259 | " skip_jax_distributed_system=True\n", |
258 | 260 | "\n", |
259 | | - "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", |
260 | | - " raise ValueError(\"Model checkpoint conversion failed. Check the logs above.\")" |
| 261 | + " if not epath.Path(MODEL_CHECKPOINT_PATH).exists():\n", |
| 262 | + " raise ValueError(\"Model checkpoint conversion failed. Check the logs above.\")" |
261 | 263 | ] |
262 | 264 | }, |
263 | 265 | { |
|
276 | 278 | "# Load configuration for RL training\n", |
277 | 279 | "config_argv = [\n", |
278 | 280 | " \"\",\n", |
279 | | - " f\"{MAXTEXT_PKG_DIR}/configs/rl.yml\",\n", |
| 281 | + " f\"{MAXTEXT_PKG_DIR}/configs/post_train/rl.yml\",\n", |
280 | 282 | " f\"model_name={MODEL_NAME}\",\n", |
281 | 283 | " f\"tokenizer_path={TOKENIZER_PATH}\",\n", |
282 | 284 | " f\"run_name={RUN_NAME}\",\n", |
|
344 | 346 | "\n", |
345 | 347 | "- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/rl.html\n", |
346 | 348 | "- **Configuration**: See `src/maxtext/configs/rl.yml` for all available options\n", |
347 | | - "- **Documentation**: Check `src/MaxText/rl/train_rl.py` for the `rl_train` function implementation" |
| 349 | + "- **Documentation**: Check `src/maxtext/trainers/post_train/rl/train_rl.py` for the `rl_train` function implementation" |
348 | 350 | ] |
349 | 351 | } |
350 | 352 | ], |
351 | 353 | "metadata": { |
352 | 354 | "kernelspec": { |
353 | | - "display_name": "Python 3", |
| 355 | + "display_name": "maxtext_venv", |
354 | 356 | "language": "python", |
355 | 357 | "name": "python3" |
356 | 358 | }, |
|
364 | 366 | "name": "python", |
365 | 367 | "nbconvert_exporter": "python", |
366 | 368 | "pygments_lexer": "ipython3", |
367 | | - "version": "3.12.11" |
| 369 | + "version": "3.12.12" |
368 | 370 | } |
369 | 371 | }, |
370 | 372 | "nbformat": 4, |
|
0 commit comments