Commit a89eb2a

Merge pull request #3123 from CIeNET-International:charlesli/471046638
PiperOrigin-RevId: 872100729
2 parents 597d16c + b13b343 commit a89eb2a

2 files changed

Lines changed: 13 additions & 2 deletions

docs/tutorials/posttraining/rl.md

Lines changed: 12 additions & 2 deletions
@@ -127,8 +127,16 @@ export HF_TOKEN=<Hugging Face access token>
 export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory

 export RUN_NAME=<name for this run> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
+
+export CHIPS_PER_VM=<the number of chips per VM> # depends on hardware, for v5p this is 4, for v6e this is 8
 ```

+For the value of `CHIPS_PER_VM` on different TPU hardware, refer the official document
+
+- [TPU v5e](https://docs.cloud.google.com/tpu/docs/v5e) (single host, chips_per_vm=8)
+- [TPU v5p](https://docs.cloud.google.com/tpu/docs/v5p) (single host, chips_per_vm=4)
+- [TPU v6e](https://docs.cloud.google.com/tpu/docs/v6e) (single host, chips_per_vm=8)
+
 ## Get your model checkpoint

 ### Option 1: Using an existing MaxText checkpoint
@@ -159,7 +167,8 @@ python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \
 load_parameters_path=${MAXTEXT_CKPT_PATH} \
 run_name=${RUN_NAME} \
 base_output_directory=${BASE_OUTPUT_DIRECTORY} \
-hf_access_token=${HF_TOKEN}
+hf_access_token=${HF_TOKEN} \
+chips_per_vm=${CHIPS_PER_VM}
 ```

 The overview of what this run will do is as follows:
@@ -183,7 +192,8 @@ python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \
 run_name=${RUN_NAME} \
 base_output_directory=${BASE_OUTPUT_DIRECTORY} \
 hf_access_token=${HF_TOKEN} \
-loss_algo=gspo-token
+loss_algo=gspo-token \
+chips_per_vm=${CHIPS_PER_VM}
 ```

 The overview of what this run will do is as follows:
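
Reviewer note: the `rl.md` change asks the user to set `CHIPS_PER_VM` by hand. A minimal sketch of deriving it from the TPU generation is shown below, using only the single-host values listed in the diff (v5e: 8, v5p: 4, v6e: 8). The helper `chips_per_vm_for` and the `TPU_TYPE` variable are hypothetical names for illustration and are not part of MaxText. Other topologies may use different values, so check the linked TPU documentation.

```shell
# Hypothetical helper (not part of MaxText): map a TPU generation to the
# single-host chips-per-VM value listed in the tutorial diff above.
chips_per_vm_for() {
  case "$1" in
    v5e|v6e) echo 8 ;;
    v5p)     echo 4 ;;
    *)       echo "unknown TPU type: $1" >&2; return 1 ;;
  esac
}

# TPU_TYPE is an assumed variable name; set it to match your hardware.
TPU_TYPE="${TPU_TYPE:-v5p}"
CHIPS_PER_VM="$(chips_per_vm_for "${TPU_TYPE}")"
export CHIPS_PER_VM
echo "CHIPS_PER_VM=${CHIPS_PER_VM}"
```

The exported value can then be passed to the training command as `chips_per_vm=${CHIPS_PER_VM}`, as the updated tutorial commands do.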

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
@@ -1809,6 +1809,7 @@ class MaxTextConfig(
 # Reinforcement Learning
 RLHardware,
 VLLM,
+RL,
 RLDataset,
 RLEvaluation,
 Reward,
