Skip to content

Commit 828b668

Browse files
Merge pull request #2975 from AI-Hypercomputer:rl_commit
PiperOrigin-RevId: 860239730
2 parents 56bcd76 + d53b80f commit 828b668

7 files changed

Lines changed: 268 additions & 94 deletions

File tree

.github/workflows/build_and_test_maxtext.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ jobs:
115115
device_name: v6e-4
116116
image_type: ${{ matrix.image_type }}
117117
cloud_runner: linux-x86-ct6e-180-4tpu
118+
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}
118119
secrets:
119120
HF_TOKEN: ${{ secrets.HF_TOKEN }}
120121

@@ -139,6 +140,7 @@ jobs:
139140
is_scheduled_run: ${{ github.event_name == 'schedule' }}
140141
worker_group: ${{ matrix.worker_group }}
141142
total_workers: 2
143+
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}
142144

143145
maxtext_tpu_unit_tests:
144146
needs: build_and_upload_maxtext_package
@@ -158,6 +160,7 @@ jobs:
158160
tf_force_gpu_allow_growth: false
159161
container_resource_option: "--privileged"
160162
is_scheduled_run: ${{ github.event_name == 'schedule' }}
163+
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}
161164

162165
maxtext_tpu_integration_tests:
163166
needs: build_and_upload_maxtext_package
@@ -177,6 +180,7 @@ jobs:
177180
tf_force_gpu_allow_growth: false
178181
container_resource_option: "--privileged"
179182
is_scheduled_run: ${{ github.event_name == 'schedule' }}
183+
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}
180184

181185
maxtext_tpu_pathways_unit_tests:
182186
needs: build_and_upload_maxtext_package
@@ -196,6 +200,7 @@ jobs:
196200
tf_force_gpu_allow_growth: false
197201
container_resource_option: "--privileged"
198202
is_scheduled_run: ${{ github.event_name == 'schedule' }}
203+
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}
199204

200205
maxtext_tpu_pathways_integration_tests:
201206
needs: build_and_upload_maxtext_package
@@ -215,6 +220,7 @@ jobs:
215220
tf_force_gpu_allow_growth: false
216221
container_resource_option: "--privileged"
217222
is_scheduled_run: ${{ github.event_name == 'schedule' }}
223+
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}
218224

219225
maxtext_gpu_unit_tests:
220226
needs: build_and_upload_maxtext_package
@@ -235,6 +241,7 @@ jobs:
235241
tf_force_gpu_allow_growth: true
236242
container_resource_option: "--shm-size 2g --runtime=nvidia --gpus all --privileged"
237243
is_scheduled_run: ${{ github.event_name == 'schedule' }}
244+
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}
238245

239246
maxtext_gpu_integration_tests:
240247
needs: build_and_upload_maxtext_package
@@ -255,6 +262,7 @@ jobs:
255262
tf_force_gpu_allow_growth: true
256263
container_resource_option: "--shm-size 2g --runtime=nvidia --gpus all --privileged"
257264
is_scheduled_run: ${{ github.event_name == 'schedule' }}
265+
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}
258266

259267
all_tests_passed:
260268
name: All Required Tests Passed

.github/workflows/build_package.yml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,28 @@ on:
2929
cloud_runner:
3030
required: false
3131
type: string
32+
outputs:
33+
maxtext_sha:
34+
description: "MaxText short SHA used for the build"
35+
value: ${{ jobs.build_and_upload.outputs.maxtext_sha }}
3236

3337
permissions:
3438
contents: read
3539
jobs:
3640
build_and_upload:
3741
runs-on: ${{ inputs.cloud_runner != '' && inputs.cloud_runner || fromJson(format('["self-hosted", "{0}", "{1}"]', inputs.device_type, inputs.device_name)) }}
3842
container: python:3.12.3-slim-bullseye
43+
outputs:
44+
maxtext_sha: ${{ steps.vars.outputs.maxtext_sha }}
3945
steps:
40-
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
46+
- name: Checkout MaxText
47+
uses: actions/checkout@v5
48+
- name: Get metadata
49+
id: vars
50+
shell: bash
51+
run: |
52+
# MaxText SHA used to build the package
53+
echo "maxtext_sha=${GITHUB_SHA}" >> $GITHUB_OUTPUT
4154
- name: Install build tools
4255
run: |
4356
python -m pip install --upgrade pip build uv

.github/workflows/run_jupyter_notebooks.yml

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ on:
3131
cloud_runner:
3232
required: false
3333
type: string
34+
maxtext_sha:
35+
required: true
36+
type: string
3437
secrets:
3538
HF_TOKEN:
3639
required: true
@@ -43,7 +46,10 @@ jobs:
4346
container:
4447
image: gcr.io/tpu-prod-env-multipod/maxtext-unit-test-${{ inputs.device_type == 'cpu' && 'tpu' || inputs.device_type }}:${{ inputs.image_type != '' && inputs.image_type }}
4548
steps:
46-
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
49+
- name: Checkout MaxText
50+
uses: actions/checkout@v5
51+
with:
52+
ref: ${{ inputs.maxtext_sha }}
4753
- name: Download the MaxText wheel
4854
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0
4955
with:
@@ -64,7 +70,8 @@ jobs:
6470
.venv/bin/python3 -m ipykernel install --user --name maxtext_venv
6571
6672
# Install Tunix for post-training notebooks
67-
uv pip install git+https://github.com/google/tunix
73+
git clone https://github.com/google/tunix
74+
uv pip install ./tunix
6875
6976
# Install vllm for post-training notebooks
7077
git clone https://github.com/vllm-project/vllm.git
@@ -95,6 +102,20 @@ jobs:
95102
96103
.venv/bin/papermill "$notebook" "$output_name" -k maxtext_venv
97104
done
105+
- name: Record Commit IDs
106+
shell: bash
107+
run: |
108+
echo "--- MaxText and Post-Training Repositories Commit IDs ---"
109+
echo "maxtext: ${GITHUB_SHA:0:7}"
110+
111+
declare -a repos=("tunix" "vllm" "tpu-inference")
112+
for repo_dir in "${repos[@]}"; do
113+
if [ -d "$repo_dir" ]; then
114+
echo "$repo_dir: $(git -C "$repo_dir" rev-parse --short HEAD)"
115+
else
116+
echo "Warning: $repo_dir directory not found."
117+
fi
118+
done
98119
- name: Upload Outputs
99120
if: always()
100121
uses: actions/upload-artifact@v4

.github/workflows/run_pathways_tests.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ on:
5050
cloud_runner:
5151
required: false
5252
type: string
53+
maxtext_sha:
54+
required: true
55+
type: string
5356

5457
permissions:
5558
contents: read
@@ -67,7 +70,10 @@ jobs:
6770
JAX_BACKEND_TARGET: "grpc://localhost:29000"
6871
options: ${{ inputs.container_resource_option }}
6972
steps:
70-
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
73+
- name: Checkout MaxText
74+
uses: actions/checkout@v5
75+
with:
76+
ref: ${{ inputs.maxtext_sha }}
7177
- name: Download the maxtext wheel
7278
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
7379
with:

.github/workflows/run_tests_against_package.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ on:
5858
required: false
5959
type: number
6060
default: 1
61+
maxtext_sha:
62+
required: true
63+
type: string
6164

6265
permissions:
6366
contents: read
@@ -74,7 +77,10 @@ jobs:
7477
ALLOW_MULTIPLE_LIBTPU_LOAD: ${{ inputs.device_type == 'cpu' && 'true' || '' }} # bypass /tmp/libtpu_lockfile check for cpu tests, which don't actually use accelerators (to allow concurrency)
7578
options: ${{ inputs.container_resource_option }}
7679
steps:
77-
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
80+
- name: Checkout MaxText
81+
uses: actions/checkout@v5
82+
with:
83+
ref: ${{ inputs.maxtext_sha }}
7884
- name: Download the maxtext wheel
7985
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
8086
with:

docs/tutorials/posttraining/rl.md

Lines changed: 88 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,39 @@
1616

1717
# Reinforcement Learning on single-host TPUs
1818

19-
This tutorial demonstrates step-by-step instructions for setting up the environment and then training the Llama3.1 8B-IT model on the GSM8K math reasoning dataset using a single host TPU-VM such as `v6e-8/v5p-8`.
20-
21-
We utilize two RL algorithms, implemented via the Tunix library, to enhance the model's reasoning capabilities:
22-
23-
* **Group Relative Policy Optimization (GRPO)**: GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.
24-
25-
* **Group Sequence Policy Optimization (GSPO)**: GSPO is an RL algorithm that improves training efficiency and performance of LLMs by using sequence-level importance ratios and operations. GSPO defines the importance ratio based on sequence likelihood and performs sequence-level clipping, rewarding, and optimization.
26-
27-
For efficient model inference and response generation during this process, we rely on the vLLM library.
19+
This tutorial provides step-by-step instructions for setting up the
20+
environment and then training the Llama3.1 8B-IT model on the GSM8K math
21+
reasoning dataset using a single host TPU-VM such as `v6e-8/v5p-8`.
22+
23+
We utilize two RL algorithms, implemented via the Tunix library, to enhance the
24+
model's reasoning capabilities:
25+
26+
- **Group Relative Policy Optimization (GRPO)**: GRPO is an RL algorithm
27+
designed to enhance the reasoning abilities of LLMs. It is a variant of
28+
Proximal Policy Optimization (PPO) that reduces memory usage by eliminating
29+
the need for a separate value function model. GRPO works by generating
30+
multiple responses for a given prompt, evaluating these responses using a
31+
reward model, and then calculating a relative advantage based on the group's
32+
performance to update the policy.
33+
34+
- **Group Sequence Policy Optimization (GSPO)**: GSPO is an RL algorithm that
35+
improves training efficiency and performance of LLMs by using sequence-level
36+
importance ratios and operations. GSPO defines the importance ratio based on
37+
sequence likelihood and performs sequence-level clipping, rewarding, and
38+
optimization.
39+
40+
For efficient model inference and response generation during this process, we
41+
rely on the vLLM library.
2842

2943
Let's get started!
3044

3145
## Create virtual environment and Install MaxText dependencies
32-
If you have already completed the [MaxText installation](../../install_maxtext.md), you can skip to the next section for post-training dependencies installations. Otherwise, please install `MaxText` using the following commands before proceeding.
46+
47+
If you have already completed the
48+
[MaxText installation](../../install_maxtext.md), you can skip to the next
49+
section for post-training dependencies installations. Otherwise, please install
50+
`MaxText` using the following commands before proceeding.
51+
3352
```bash
3453
# 1. Clone the repository
3554
git clone https://github.com/AI-Hypercomputer/maxtext.git
@@ -50,20 +69,48 @@ install_maxtext_github_deps
5069

5170
### Option 1: From PyPI releases
5271

53-
> **Caution:** RL in MaxText is currently broken with PyPI releases of post-training dependencies. We are working on fixing this and recommend following [Option 2: From Github](#option-2-from-github) in the meantime.
72+
> **Caution:** RL in MaxText is currently broken with PyPI releases of
73+
> post-training dependencies. We are working on fixing this and recommend
74+
> following [Option 2: From Github](#option-2-from-github) in the meantime.
5475
55-
Next, run the following bash script to get all the necessary installations inside the virtual environment (for e.g., `maxtext_venv`).
56-
This will take few minutes. Follow along the installation logs and look out for any issues!
76+
Next, run the following bash script to get all the necessary installations
77+
inside the virtual environment (e.g., `maxtext_venv`). This will take a few
78+
minutes. Follow the installation logs and look out for any issues!
5779

5880
```
5981
bash tools/setup/setup_post_training_requirements.sh
6082
```
6183

62-
Primarily, it installs `Tunix`, and `vllm-tpu` which is [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference) and thereby providing TPU inference for vLLM, with unified JAX and PyTorch support.
84+
Primarily, it installs `Tunix` and `vllm-tpu`, which comprises
85+
[vllm](https://github.com/vllm-project/vllm) and
86+
[tpu-inference](https://github.com/vllm-project/tpu-inference), thereby
87+
providing TPU inference for vLLM with unified JAX and PyTorch support.
6388

6489
### Option 2: From Github
6590

66-
You can also locally git clone [tunix](https://github.com/google/tunix) and install using the instructions [here](https://github.com/google/tunix?tab=readme-ov-file#installation). Similarly install [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference) from source following the instructions [here](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/#install-from-source).
91+
You can also locally git clone [tunix](https://github.com/google/tunix) and
92+
install using the instructions
93+
[here](https://github.com/google/tunix?tab=readme-ov-file#installation).
94+
Similarly install [vllm](https://github.com/vllm-project/vllm) and
95+
[tpu-inference](https://github.com/vllm-project/tpu-inference) from source
96+
following the instructions
97+
[here](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/#install-from-source).
98+
To get a set of compatible commit IDs for `maxtext`, `tunix`, `tpu-inference`,
99+
and `vllm`, follow these steps:
100+
101+
1. Navigate to the
102+
[MaxText Package Tests](https://github.com/AI-Hypercomputer/maxtext/actions/workflows/build_and_test_maxtext.yml?query=event%3Aschedule)
103+
GitHub Actions workflow.
104+
105+
1. Select the latest successful run.
106+
107+
1. Within the workflow run, find and click on the `maxtext_jupyter_notebooks (py312)` job, then expand the `run` job.
108+
109+
1. Locate the `Record Commit IDs` step. The commit SHAs for `maxtext`, `tunix`,
110+
`tpu-inference`, and `vllm` that were used in that successful run are listed
111+
in the logs of this step.
112+
113+
1. Prior to installation, ensure that the `maxtext`, `tunix`, `vllm`, and `tpu-inference` repositories are synchronized to the specific commits recorded from the CI logs. For each repository, use the following command to switch to the correct commit: `git checkout <commit_id>`.
67114

68115
## Setup environment variables
69116

@@ -86,16 +133,24 @@ export RUN_NAME=<name for this run> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
86133

87134
### Option 1: Using an existing MaxText checkpoint
88135

89-
If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.
136+
If you already have a MaxText-compatible model checkpoint, simply set the
137+
following environment variable and move on to the next section.
138+
90139
```bash
91140
export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
92141
```
93142

94143
### Option 2: Converting from a Hugging Face checkpoint
95144

96-
Otherwise, you can convert a Hugging Face checkpoint to MaxText format using the `src/MaxText/utils/ckpt_conversion/to_maxtext.py` script. This is useful if you have a pre-trained model from Hugging Face that you want to use with MaxText.
145+
Otherwise, you can convert a Hugging Face checkpoint to MaxText format using the
146+
`src/MaxText/utils/ckpt_conversion/to_maxtext.py` script. This is useful if you
147+
have a pre-trained model from Hugging Face that you want to use with MaxText.
97148

98-
First, ensure you have the necessary dependencies installed. Then, run the conversion script on a CPU machine. For large models, it is recommended to use the `--lazy_load_tensors` flag to reduce memory usage during conversion. This command will download the Hugging Face model and convert it to the MaxText format, saving it to the specified GCS bucket.
149+
First, ensure you have the necessary dependencies installed. Then, run the
150+
conversion script on a CPU machine. For large models, it is recommended to use
151+
the `--lazy_load_tensors` flag to reduce memory usage during conversion. This
152+
command will download the Hugging Face model and convert it to the MaxText
153+
format, saving it to the specified GCS bucket.
99154

100155
```bash
101156
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
@@ -117,13 +172,13 @@ python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
117172
--lazy_load_tensors=true
118173
```
119174

120-
The converted checkpoint will be saved at the following location. Set this environment variable to use it in the following GRPO/GSPO training sessions:
175+
The converted checkpoint will be saved at the following location. Set this
176+
environment variable to use it in the following GRPO/GSPO training sessions:
177+
121178
```bash
122179
export MAXTEXT_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/0/items
123180
```
124181

125-
126-
127182
## Run GRPO
128183

129184
Run the following command for GRPO:
@@ -140,10 +195,12 @@ python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
140195

141196
The overview of what this run will do is as follows:
142197

143-
1. We load a policy model and a reference model. Both are copies of the model checkpoint you specified (e.g., `Llama3.1-8b-Instruct`).
144-
2. Evaluate the policy model's performance on GSM8K math reasoning benchmark.
145-
3. Train the policy model using GRPO.
146-
4. Evaluate the policy model's performance on GSM8K math reasoning benchmark after the post-training with GRPO.
198+
1. We load a policy model and a reference model. Both are copies of the model
199+
checkpoint you specified (e.g., `Llama3.1-8b-Instruct`).
200+
1. Evaluate the policy model's performance on GSM8K math reasoning benchmark.
201+
1. Train the policy model using GRPO.
202+
1. Evaluate the policy model's performance on GSM8K math reasoning benchmark
203+
after the post-training with GRPO.
147204

148205
## Run GSPO
149206

@@ -162,8 +219,9 @@ python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
162219

163220
The overview of what this run will do is as follows:
164221

165-
1. We load a policy model and a reference model. Both are copies of the model checkpoint you specified (e.g., `Llama3.1-8b-Instruct`).
166-
2. Evaluate the policy model's performance on GSM8K math reasoning benchmark.
167-
3. Train the policy model using GSPO.
168-
4. Evaluate the policy model's performance on GSM8K math reasoning benchmark after the post-training with GSPO.
169-
222+
1. We load a policy model and a reference model. Both are copies of the model
223+
checkpoint you specified (e.g., `Llama3.1-8b-Instruct`).
224+
1. Evaluate the policy model's performance on GSM8K math reasoning benchmark.
225+
1. Train the policy model using GSPO.
226+
1. Evaluate the policy model's performance on GSM8K math reasoning benchmark
227+
after the post-training with GSPO.

0 commit comments

Comments
 (0)