@@ -127,8 +127,30 @@ export HF_TOKEN=<Hugging Face access token>
 export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
 
 export RUN_NAME=<name for this run> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
+
+export CHIPS_PER_VM=<the number of chips per VM> # depends on the hardware; for v5p this is 4, for v6e this is 8
 ```
 
+For the value of `CHIPS_PER_VM` on different TPU hardware, refer to the official documentation:
+
+- [TPU v5e](https://docs.cloud.google.com/tpu/docs/v5e) (single host, `chips_per_vm=8`)
+- [TPU v5p](https://docs.cloud.google.com/tpu/docs/v5p) (single host, `chips_per_vm=4`)
+- [TPU v6e](https://docs.cloud.google.com/tpu/docs/v6e) (single host, `chips_per_vm=8`)
+
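+If you prefer not to hard-code the value, you can derive it from the TPU generation. The snippet below is only a sketch, not part of the MaxText scripts: it assumes a `TPU_VERSION` variable (e.g. `v5p`) that you set yourself, and uses the per-chip counts from the hardware docs above.
+
+```bash
+# Hypothetical helper: map the TPU generation to chips per VM.
+# Values come from the hardware docs above (v5p: 4; v5e and v6e: 8).
+case "${TPU_VERSION}" in
+  v5p) export CHIPS_PER_VM=4 ;;
+  v5e|v6e) export CHIPS_PER_VM=8 ;;
+  *) echo "Unknown TPU version: ${TPU_VERSION}" >&2 ;;
+esac
+```
+
+As a sanity check on a single-host TPU VM with JAX installed, `python3 -c "import jax; print(jax.local_device_count())"` should print the same number, since these TPU generations expose one JAX device per chip.
+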
 ## Get your model checkpoint
 
 ### Option 1: Using an existing MaxText checkpoint
@@ -159,7 +167,8 @@ python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \
   load_parameters_path=${MAXTEXT_CKPT_PATH} \
   run_name=${RUN_NAME} \
   base_output_directory=${BASE_OUTPUT_DIRECTORY} \
-  hf_access_token=${HF_TOKEN}
+  hf_access_token=${HF_TOKEN} \
+  chips_per_vm=${CHIPS_PER_VM}
 ```
 
 The overview of what this run will do is as follows:
@@ -183,7 +192,8 @@ python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \
   run_name=${RUN_NAME} \
   base_output_directory=${BASE_OUTPUT_DIRECTORY} \
   hf_access_token=${HF_TOKEN} \
-  loss_algo=gspo-token
+  loss_algo=gspo-token \
+  chips_per_vm=${CHIPS_PER_VM}
 ```
 
 The overview of what this run will do is as follows: