# Reinforcement Learning on single-host TPUs

This tutorial provides step-by-step instructions for setting up the environment and then training the Llama 3.1 8B-IT model on the GSM8K math reasoning dataset using a single-host TPU VM such as `v6e-8` or `v5p-8`.

We utilize two RL algorithms, implemented via the Tunix library, to enhance the model's reasoning capabilities:

- **Group Relative Policy Optimization (GRPO)**: GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group's performance to update the policy.
- **Group Sequence Policy Optimization (GSPO)**: GSPO is an RL algorithm that improves the training efficiency and performance of LLMs by using sequence-level importance ratios and operations. GSPO defines the importance ratio based on sequence likelihood and performs sequence-level clipping, rewarding, and optimization. The core quantity behind each objective is sketched after this list.
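
For reference, here is the core quantity behind each objective, following the respective papers. GRPO's group-relative advantage for response $i$ in a group of $G$ sampled responses with rewards $r_1, \dots, r_G$ is

$$
\hat{A}_i = \frac{r_i - \operatorname{mean}(\{r_1, \dots, r_G\})}{\operatorname{std}(\{r_1, \dots, r_G\})},
$$

while GSPO swaps token-level ratios for a length-normalized, sequence-level importance ratio for response $y_i$ given prompt $x$:

$$
s_i(\theta) = \left( \frac{\pi_\theta(y_i \mid x)}{\pi_{\theta_{\mathrm{old}}}(y_i \mid x)} \right)^{1/|y_i|}.
$$

These are only the defining quantities; the full clipped objectives appear in the GRPO and GSPO papers.
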
For efficient model inference and response generation during this process, we rely on the vLLM library.

Let's get started!

## Create virtual environment and Install MaxText dependencies
If you have already completed the [MaxText installation](../../install_maxtext.md), you can skip to the next section to install the post-training dependencies. Otherwise, please install `MaxText` using the following commands before proceeding.

> **Caution:** RL in MaxText is currently broken with PyPI releases of post-training dependencies. We are working on fixing this and recommend following [Option 2: From Github](#option-2-from-github) in the meantime.

Next, run the following bash script to get all the necessary installations inside the virtual environment (e.g., `maxtext_venv`). This will take a few minutes. Follow the installation logs and look out for any issues!

Primarily, it installs `Tunix` and `vllm-tpu`, which comprises [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference), thereby providing TPU inference for vLLM with unified JAX and PyTorch support.
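
If you just want to see the shape of the PyPI route, here is a minimal sketch. The distribution names are assumptions (`google-tunix` for Tunix; `vllm-tpu` as named above), and per the caution earlier this route is currently broken, so prefer the repo's install script or Option 2:

```bash
# Illustrative only -- the package names here are assumptions, the repo's
# install script is the source of truth, and PyPI releases are currently
# broken for RL in MaxText (see the caution above).
pip install google-tunix vllm-tpu
```
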
### Option 2: From Github
You can also git clone [tunix](https://github.com/google/tunix) locally and install it using the instructions [here](https://github.com/google/tunix?tab=readme-ov-file#installation). Similarly, install [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference) from source following the instructions [here](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/#install-from-source).
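
As a starting point, the clone step might look like the sketch below; the actual build and install of each project should follow its linked instructions:

```bash
# Clone the source repos side by side; install each one per its own docs.
git clone https://github.com/google/tunix.git
git clone https://github.com/vllm-project/vllm.git
git clone https://github.com/vllm-project/tpu-inference.git
```
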
1. Within the workflow run, find and click on the `maxtext_jupyter_notebooks (py312)` job, then expand the `run` job.
1. Locate the `Record Commit IDs` step. The commit SHAs for `maxtext`, `tunix`, `tpu-inference`, and `vllm` that were used in that successful run are listed in the logs of this step.
1. Prior to installation, ensure that the `maxtext`, `tunix`, `vllm`, and `tpu-inference` repositories are synchronized to the specific commits recorded in the CI logs. For each repository, run `git checkout <commit_id>` to switch to the correct commit, as in the sketch below.
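
For example, assuming the repositories sit side by side and using placeholder commit IDs (the real values come from the CI logs):

```bash
# Pin each repo to the commit recorded in the CI logs.
# The <..._commit_id> values are placeholders, not real commits.
(cd maxtext && git checkout <maxtext_commit_id>)
(cd tunix && git checkout <tunix_commit_id>)
(cd vllm && git checkout <vllm_commit_id>)
(cd tpu-inference && git checkout <tpu_inference_commit_id>)
```
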
## Set up environment variables
### Option 1: Using an existing MaxText checkpoint
If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

```bash
export MAXTEXT_CKPT_PATH=<gcs path for MaxText checkpoint>  # e.g., gs://my-bucket/my-model-checkpoint/0/items
```
### Option 2: Converting from a Hugging Face checkpoint

Otherwise, you can convert a Hugging Face checkpoint to MaxText format using the `src/MaxText/utils/ckpt_conversion/to_maxtext.py` script. This is useful if you have a pre-trained model from Hugging Face that you want to use with MaxText.

First, ensure you have the necessary dependencies installed. Then, run the conversion script on a CPU machine. For large models, it is recommended to use the `--lazy_load_tensors` flag to reduce memory usage during conversion. This command will download the Hugging Face model and convert it to the MaxText format, saving it to the specified GCS bucket.
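
As a sketch only, the invocation takes roughly the following shape. The script path and `--lazy_load_tensors` come from this tutorial; every other argument name and value is a hypothetical placeholder, so check the script's `--help` for the actual interface:

```bash
# Hypothetical sketch -- only the script path and --lazy_load_tensors are
# confirmed above; the remaining arguments are illustrative placeholders.
python3 src/MaxText/utils/ckpt_conversion/to_maxtext.py \
  model_name=<model name> \
  hf_access_token=<your HF token> \
  base_output_directory=<gcs output path> \
  --lazy_load_tensors
```
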
The converted checkpoint will be saved at the following location. Set this environment variable to use it in the following GRPO/GSPO training sessions:
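
A placeholder sketch of that export, mirroring the earlier example path (the actual location comes from the conversion script's output):

```bash
# Assumed path shape -- confirm the real location from the conversion output.
export MAXTEXT_CKPT_PATH=<converted checkpoint path>  # e.g., gs://my-bucket/my-model-checkpoint/0/items
```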