Commit 69adf5d

Merge pull request #2794 from AI-Hypercomputer:hengtaoguo-grpo
PiperOrigin-RevId: 843444835
2 parents 9f0f6a8 + 7173c14 commit 69adf5d

5 files changed

Lines changed: 40 additions & 25 deletions


docs/guides/run_python_notebook.md

Lines changed: 5 additions & 5 deletions
````diff
@@ -69,9 +69,9 @@ You can run Python notebooks on a local JupyterLab environment, giving you full
 
 ### Step 1: Set Up TPU VM
 
-In Google Cloud Console:
+In Google Cloud Console, create a standalone TPU VM:
 
-1.a. **Compute Engine** → **TPU** → **Create TPU**
+1.a. **Compute Engine** → **TPUs** → **Create TPU**
 
 1.b. Example config:
 - **Name:** `maxtext-tpu-node`
@@ -118,12 +118,12 @@ jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root
 
 ### Supervised Fine-Tuning (SFT)
 
-- **`sft_qwen3_demo.ipynb`** → Qwen3-0.6B SFT training and evaluation on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k)
-- **`sft_llama3_demo.ipynb`** → Llama3.1-8B SFT training on [Hugging Face ultrachat_200k dataset](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
+- **`sft_qwen3_demo.ipynb`** → Qwen3-0.6B SFT training and evaluation on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k). This notebook is beginner-friendly and runs successfully on Google Colab's free-tier v5e-1 TPU runtime.
+- **`sft_llama3_demo.ipynb`** → Llama3.1-8B SFT training on [Hugging Face ultrachat_200k dataset](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k). We recommend running this on a v5p-8 TPU VM using the port-forwarding method.
 
 ### Reinforcement Learning (GRPO/GSPO) Training
 
-- **`rl_llama3_demo.ipynb`** → GRPO/GSPO training on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k)
+- **`rl_llama3_demo.ipynb`** → GRPO/GSPO training on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k). We recommend running this on a v5p-8 TPU VM using the port-forwarding method.
 
 ## Common Pitfalls & Debugging
 
````

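The "port-forwarding method" recommended for the v5p-8 notebook demos is typically an SSH tunnel from your workstation to the JupyterLab port on the TPU VM. A minimal sketch, assuming the example VM name and a placeholder zone from the config above (adjust both for your setup):

```shell
# Forward local port 8888 to JupyterLab on the TPU VM.
# maxtext-tpu-node and us-central2-b are example values, not fixed names.
gcloud compute tpus tpu-vm ssh maxtext-tpu-node \
  --zone=us-central2-b \
  -- -L 8888:localhost:8888
# Then open http://localhost:8888 in your local browser.
```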
docs/install_maxtext.md

Lines changed: 3 additions & 3 deletions
````diff
@@ -122,7 +122,7 @@ seed-env \
   --output-dir=generated_gpu_artifacts
 ```
 
-## 4. Update Project Files
+## Step 4: Update Project Files
 
 After generating the new requirements, you need to update the files in the MaxText repository.
 
@@ -133,7 +133,7 @@ After generating the new requirements, you need to update the files in the MaxTe
 2. **Update `extra_deps_from_github.txt` (if necessary):**
    Currently, MaxText uses a few dependencies, such as `mlperf-logging` and `google-jetstream`, that are installed directly from GitHub source. These are defined in `base_requirements/requirements.txt`, and the `seed-env` tool will carry them over to the generated requirements files.
 
-## 5. Verify the New Dependencies
+## Step 5: Verify the New Dependencies
 
 Finally, test that the new dependencies install correctly and that MaxText runs as expected.
 
@@ -155,4 +155,4 @@ uv pip install -e .[tpu] --resolution=lowest
 install_maxtext_github_deps
 ```
 
-3. **Run tests:** Run MaxText tests to ensure there are no regressions.
+3. **Run tests:** Run MaxText tests to ensure there are no regressions.
````

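When verifying regenerated requirements like the ones above, a quick way to spot what changed is to diff the pinned versions between the old and new files. A small illustrative helper, not part of MaxText or `seed-env` (the package names and versions below are made-up examples):

```python
def pinned_versions(lines):
    """Map package name -> version for 'pkg==ver' requirement lines."""
    pins = {}
    for line in lines:
        line = line.strip()
        if "==" in line and not line.startswith("#"):
            name, _, version = line.partition("==")
            pins[name.lower()] = version
    return pins

# Example: compare two requirements snapshots (hypothetical pins).
old = pinned_versions(["jax==0.4.30", "orbax-checkpoint==0.5.0"])
new = pinned_versions(["jax==0.4.35", "orbax-checkpoint==0.5.0"])
changed = {p: (old[p], new[p]) for p in old if p in new and old[p] != new[p]}
# changed == {"jax": ("0.4.30", "0.4.35")}
```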
docs/tutorials/posttraining/rl.md

Lines changed: 25 additions & 12 deletions
````diff
@@ -29,7 +29,7 @@ For efficient model inference and response generation during this process, we re
 Let's get started!
 
 ## Create virtual environment and Install MaxText dependencies
-If you have already completed the [MaxText installation](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/install_maxtext.md), you can skip to the next section for post-training dependency installation. Otherwise, please install `MaxText` using the following commands before proceeding.
+If you have already completed the [MaxText installation](../../install_maxtext.md), you can skip to the next section for post-training dependency installation. Otherwise, please install `MaxText` using the following commands before proceeding.
 ```bash
 # 1. Clone the repository
 git clone https://github.com/AI-Hypercomputer/maxtext.git
@@ -78,12 +78,20 @@ export HF_TOKEN=<Hugging Face access token>
 export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
 
 export RUN_NAME=<name for this run> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
-export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/0/items
 ```
 
 ## Get your model checkpoint
 
-You can convert a Hugging Face checkpoint to MaxText format using the `src/MaxText/utils/ckpt_conversion/to_maxtext.py` script. This is useful if you have a pre-trained model from Hugging Face that you want to use with MaxText.
+### Option 1: Using an existing MaxText checkpoint
+
+If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.
+```bash
+export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
+```
+
+### Option 2: Converting from a Hugging Face checkpoint
+
+Otherwise, you can convert a Hugging Face checkpoint to MaxText format using the `src/MaxText/utils/ckpt_conversion/to_maxtext.py` script. This is useful if you have a pre-trained model from Hugging Face that you want to use with MaxText.
 
 First, ensure you have the necessary dependencies installed. Then, run the conversion script on a CPU machine. For large models, it is recommended to use the `--lazy_load_tensors` flag to reduce memory usage during conversion. This command will download the Hugging Face model and convert it to the MaxText format, saving it to the specified GCS bucket.
 
@@ -93,7 +101,7 @@ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
 python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
   model_name=${HF_MODEL} \
   hf_access_token=${HF_TOKEN} \
-  base_output_directory=${MAXTEXT_CKPT_PATH} \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME} \
   scan_layers=True hardware=cpu skip_jax_distributed_system=true
 
 # Example of converting Llama3.1-70B using --lazy_load_tensors=true which uses around 86GB of RAM
@@ -107,6 +115,11 @@ python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
   --lazy_load_tensors=true
 ```
 
+The converted checkpoint will be saved at the following location. Set this environment variable to use it in the following GRPO/GSPO training sessions:
+```bash
+export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/0/items
+```
+
 
 
 ## Run GRPO
@@ -125,7 +138,7 @@ python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
 
 The overview of what this run will do is as follows:
 
-1. We load a policy model and a reference model. Both are copies of `Llama3.1-8b-Instruct`.
+1. We load a policy model and a reference model. Both are copies of the model checkpoint you specified (e.g., `Llama3.1-8b-Instruct`).
 2. Evaluate the policy model's performance on GSM8K math reasoning benchmark.
 3. Train the policy model using GRPO.
 4. Evaluate the policy model's performance on GSM8K math reasoning benchmark after the post-training with GRPO.
@@ -136,18 +149,18 @@ Run the following command for GSPO:
 
 ```
 python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
-  model_name=llama3.1-8b \
-  tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
-  load_parameters_path=gs://path/to/checkpoint/0/items \
-  run_name=$WORKLOAD \
-  base_output_directory=$OUTPUT_PATH \
-  hf_access_token=$HF_TOKEN \
+  model_name=${MODEL} \
+  tokenizer_path=${TOKENIZER} \
+  load_parameters_path=${MAXTEXT_CKPT_PATH} \
+  run_name=${RUN_NAME} \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
+  hf_access_token=${HF_TOKEN} \
   loss_algo=gspo-token
 ```
 
 The overview of what this run will do is as follows:
 
-1. We load a policy model and a reference model. Both are copies of `Llama3.1-8b-Instruct`.
+1. We load a policy model and a reference model. Both are copies of the model checkpoint you specified (e.g., `Llama3.1-8b-Instruct`).
 2. Evaluate the policy model's performance on GSM8K math reasoning benchmark.
 3. Train the policy model using GSPO.
 4. Evaluate the policy model's performance on GSM8K math reasoning benchmark after the post-training with GSPO.
````

docs/tutorials/posttraining/rl_on_multi_host.md

Lines changed: 4 additions & 3 deletions
````diff
@@ -29,7 +29,7 @@ For efficient model inference and response generation during this process, we re
 Let's get started!
 
 ## Create virtual environment and Install MaxText dependencies
-Follow instructions in [Install MaxText](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/install_maxtext.md), but
+Follow instructions in [Install MaxText](../../install_maxtext.md), but
 recommend creating the virtual environment outside the `maxtext` directory.
 
 
@@ -93,7 +93,7 @@ You can install the required dependencies using either of the following two opti
 ### Option 1: Installing stable releases of tunix and vllm-tpu
 Run the following bash script to create a docker image with all the dependencies of MaxText, Tunix, vLLM and tpu-inference installed.
 
-In addition to MaxText dependencies, primarily, it installs `vllm-tpu` which is [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference) and thereby providing TPU inference for vLLM, with unified JAX and PyTorch support.
+In addition to MaxText dependencies, it primarily installs `vllm-tpu`, which combines [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference) to provide TPU inference for vLLM with unified JAX and PyTorch support. This build process takes approximately 10 to 15 minutes.
 
 ```
 bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training
@@ -109,13 +109,14 @@ bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training PO
 ```
 
 ### Upload the dependency docker image along with MaxText code
+> **Note:** You will need the [**Artifact Registry Writer**](https://docs.cloud.google.com/artifact-registry/docs/access-control#permissions) role to push Docker images to your project's Artifact Registry and to allow the cluster to pull them during workload execution. If you don't have this permission, contact your project administrator to grant you this role through "Google Cloud Console -> IAM -> Grant access".
 ```
 bash dependencies/scripts/docker_upload_runner.sh CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME}
 ```
 
 ## Submit your RL workload via Pathways
 
-Please create a pathways ready GKE cluster as described [here](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster), and you can submit the `train_rl.py` script via [XPK](https://github.com/AI-Hypercomputer/xpk).
+Please create a Pathways-ready GKE cluster as described [here](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster), then submit the `train_rl.py` script via XPK. You can install XPK by following the instructions in the [official documentation](https://github.com/AI-Hypercomputer/xpk/blob/main/docs/installation.md).
 
 ### Submit GRPO workload
 ```
````

docs/tutorials/posttraining/sft_on_multi_host.md

Lines changed: 3 additions & 2 deletions
````diff
@@ -43,12 +43,13 @@ gcloud auth application-default login
 gcloud auth configure-docker
 docker run hello-world
 ```
-Then run the following command to create a local Docker image named `maxtext_base_image`.
+Then run the following command to create a local Docker image named `maxtext_base_image`. This build process takes approximately 10 to 15 minutes.
 ```bash
 bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training
 ```
 
 ### 1.3. Upload the Docker image to Artifact Registry
+> **Note:** You will need the [**Artifact Registry Writer**](https://docs.cloud.google.com/artifact-registry/docs/access-control#permissions) role to push Docker images to your project's Artifact Registry and to allow the cluster to pull them during workload execution. If you don't have this permission, contact your project administrator to grant you this role through "Google Cloud Console -> IAM -> Grant access".
 ```bash
 # Replace `$USER_runner` with your desired image name
 export DOCKER_IMAGE_NAME=${USER}_runner
@@ -57,7 +58,7 @@ bash dependencies/scripts/docker_upload_runner.sh CLOUD_IMAGE_NAME=$DOCKER_IMAGE
 The `docker_upload_runner.sh` script uploads your Docker image to Artifact Registry.
 
 ## 2. Install XPK
-Install XPK by following the instructions in the [official documentation](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#installation-via-pip).
+Install XPK by following the instructions in the [official documentation](https://github.com/AI-Hypercomputer/xpk/blob/main/docs/installation.md).
 
 ## 3. Create GKE cluster
 Use a pathways ready GKE cluster as described [here](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster).
````
