|
107 | 107 | "from pathlib import Path\n", |
108 | 108 | "import MaxText\n", |
109 | 109 | "from huggingface_hub import login\n", |
110 | | - "import jax\n", |
111 | 110 | "\n", |
112 | 111 | "# Set up paths (adjust if needed)\n", |
113 | 112 | "MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n", |
|
127 | 126 | " raise RuntimeError(\"OUTPUT_DIRECTORY is not set\")\n", |
128 | 127 | " \n", |
129 | 128 | "os.environ[\"HF_TOKEN\"] = HF_TOKEN\n", |
| 129 | + "if \"MAXTEXT_PKG_DIR\" not in os.environ:\n", |
| 130 | + " os.environ[\"MAXTEXT_PKG_DIR\"] = MAXTEXT_REPO_ROOT\n", |
130 | 131 | "\n", |
131 | 132 | "if HF_TOKEN:\n", |
132 | 133 | " login(token=HF_TOKEN)\n", |
133 | 134 | " print(\"Authenticated with Hugging Face\")\n", |
134 | 135 | "else:\n", |
135 | 136 | " print(\"Authentication failed: Hugging Face token not set\")\n", |
136 | 137 | "\n", |
137 | | - "# Optional: Override training parameters\n", |
138 | | - "LEARNING_RATE = 3e-6\n", |
139 | | - "NUM_GENERATIONS = 2\n", |
140 | | - "GRPO_BETA = 0.08\n", |
141 | | - "GRPO_EPSILON = 0.2\n", |
142 | | - "CHIPS_PER_VM = 1\n", |
143 | 138 | "\n", |
144 | 139 | "print(f\"📁 MaxText Home: {MAXTEXT_REPO_ROOT}\")\n", |
145 | 140 | "print(f\"🤖 Model: {MODEL_NAME}\")\n", |
146 | 141 | "print(f\"📦 Checkpoint: {MODEL_CHECKPOINT_PATH}\")\n", |
147 | 142 | "print(f\"💾 Output: {OUTPUT_DIRECTORY}\")\n", |
148 | 143 | "print(f\"🔑 HF Token: {'✅ Set' if HF_TOKEN else '❌ Missing - set HF_TOKEN env var'}\")\n", |
149 | | - "print(f\"📊 Steps: {STEPS}\")\n", |
150 | 144 | "print(f\"Loss Algorithm : {LOSS_ALGO}\")" |
151 | 145 | ] |
152 | 146 | }, |
|
178 | 172 | "outputs": [], |
179 | 173 | "source": [ |
180 | 174 | "# Build configuration for GRPO training\n", |
181 | | - "config_file = os.path.join(MAXTEXT_REPO_ROOT, \"configs/rl.yml\")\n", |
| 175 | + "config_file = os.path.join(MAXTEXT_REPO_ROOT, \"configs\", \"rl.yml\")\n", |
182 | 176 | "\n", |
183 | 177 | "# Verify chat template exists\n", |
184 | | - "if not os.path.exists(CHAT_TEMPLATE_PATH)):\n", |
| 178 | + "if not os.path.exists(CHAT_TEMPLATE_PATH):\n", |
185 | 179 | " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n", |
186 | 180 | "\n", |
187 | 181 | "# Build argv list for setup_configs_and_devices()\n", |
|
195 | 189 | " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n", |
196 | 190 | " f\"base_output_directory={OUTPUT_DIRECTORY}\",\n", |
197 | 191 | " f\"hf_access_token={HF_TOKEN}\",\n", |
198 | | - " f\"learning_rate={LEARNING_RATE}\",\n", |
199 | | - " f\"num_generations={NUM_GENERATIONS}\",\n", |
200 | | - " f\"grpo_beta={GRPO_BETA}\",\n", |
201 | | - " f\"grpo_epsilon={GRPO_EPSILON}\",\n", |
202 | | - " f\"chips_per_vm={CHIPS_PER_VM}\",\n", |
203 | | - " f\"loss_algo={LOSS_ALGO}\",\n", |
| 192 | + " f\"debug.rl=False\",\n", |
| 193 | + " f\"rl.loss_algo={LOSS_ALGO}\",\n", |
204 | 194 | " \"use_pathways=False\"\n", |
205 | 195 | "]\n", |
206 | 196 | "\n", |
207 | 197 | "# Initialize configuration\n", |
208 | 198 | "print(f\"🔧 Initializing configuration from: {config_file}\")\n", |
209 | | - "config = pyconfig.initialize(config_argv)\n", |
| 199 | + "trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(config_argv)\n", |
| 200 | + "\n", |
| 201 | + "rl_train_steps = int(\n", |
| 202 | + " trainer_config.num_batches\n", |
| 203 | + " * trainer_config.rl.num_iterations\n", |
| 204 | + " * trainer_config.train_fraction\n", |
| 205 | + " * trainer_config.num_epoch\n", |
| 206 | + " )\n", |
210 | 207 | "\n", |
211 | 208 | "print(\"\\n✅ Configuration initialized successfully\")\n", |
212 | | - "print(f\"📊 Training steps: {config.steps}\")\n", |
213 | | - "print(f\"📁 Output directory: {config.base_output_directory}\")\n", |
214 | | - "print(f\"🤖 Model: {config.model_name}\")" |
| 209 | + "print(f\"📁 Output directory: {trainer_config.base_output_directory}\")\n", |
| 210 | + "print(f\"🤖 Model: {trainer_config.model_name}\")\n", |
| 211 | + "print(f\"📊 RL Train Steps: {rl_train_steps}\")" |
215 | 212 | ] |
216 | 213 | }, |
217 | 214 | { |
|
224 | 221 | "print(\"\\n\" + \"=\"*80)\n", |
225 | 222 | "print(\"🚀 Starting Training...\")\n", |
226 | 223 | "print(\"=\"*80)\n", |
227 | | - "print(1)\n", |
228 | 224 | "try:\n", |
229 | 225 | " # Call the rl_train function (it handles everything internally)\n", |
230 | | - " rl_train(config)\n", |
| 226 | + " rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)\n", |
231 | 227 | " \n", |
232 | 228 | " print(\"\\n\" + \"=\"*80)\n", |
233 | 229 | " print(\"✅ Training Completed Successfully!\")\n", |
| 230 | + " print(f\"✍️ Note the improved evaluation accuracy metrics with just {rl_train_steps} RL training steps!\")\n", |
234 | 231 | " print(\"=\"*80)\n", |
235 | | - " print(f\"📁 Checkpoints saved to: {config.checkpoint_dir}\")\n", |
236 | | - " print(f\"📊 TensorBoard logs: {config.tensorboard_dir}\")\n", |
| 232 | + " print(f\"📁 Checkpoints saved to: {trainer_config.checkpoint_dir}\")\n", |
| 233 | + " print(f\"📊 TensorBoard logs: {trainer_config.tensorboard_dir}\")\n", |
237 | 234 | " print(f\"🎯 Model ready for inference!\")\n", |
238 | 235 | " \n", |
239 | 236 | "except Exception as e:\n", |
|
264 | 261 | ], |
265 | 262 | "metadata": { |
266 | 263 | "kernelspec": { |
267 | | - "display_name": "Python 3", |
| 264 | + "display_name": "maxtext_venv", |
268 | 265 | "language": "python", |
269 | 266 | "name": "python3" |
270 | 267 | }, |
|
278 | 275 | "name": "python", |
279 | 276 | "nbconvert_exporter": "python", |
280 | 277 | "pygments_lexer": "ipython3", |
281 | | - "version": "3.8.5" |
| 278 | + "version": "3.12.11" |
282 | 279 | } |
283 | 280 | }, |
284 | 281 | "nbformat": 4, |
|
0 commit comments