Commit 9be4a6a

Move src/MaxText/sft to src/maxtext/trainers/post_train/sft
1 parent 8204907 commit 9be4a6a

13 files changed

Lines changed: 284 additions & 208 deletions
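
For callers, the move is a pure rename of the module path, in both imports and `python -m` invocations. A compact summary drawn from the diffs below; the trailing `...` stands for the config overrides shown in each file and is not filled in here:

```python
# Before this commit:
#   import:  from MaxText.sft import sft_trainer
#   CLI:     python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml ...
#
# After this commit:
#   import:  from maxtext.trainers.post_train.sft import train_sft
#   CLI:     python3 -m maxtext.trainers.post_train.sft.train_sft src/MaxText/configs/sft.yml ...
from maxtext.trainers.post_train.sft import train_sft  # new module location
```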

codecov.yml

Lines changed: 2 additions & 1 deletion
```diff
@@ -39,7 +39,8 @@ ignore:
   - "src/MaxText/inference"
   - "src/MaxText/inference_mlperf"
   - "src/MaxText/scratch_code"
-  - "src/MaxText/distillation" # code moved to src/MaxText/trainers/post_train/distillation
+  - "src/MaxText/distillation" # code moved to src/maxtext/trainers/post_train/distillation
+  - "src/MaxText/sft" # code moved to src/maxtext/trainers/post_train/sft
 
 
 flags:
```

docs/tutorials/posttraining/sft.md

Lines changed: 7 additions & 1 deletion
````diff
@@ -15,6 +15,7 @@
 -->
 
 # SFT on single-host TPUs
+
 Supervised fine-tuning (SFT) is a process where a pre-trained large language model is fine-tuned on a labeled dataset to adapt the model to perform better on specific tasks.
 
 This tutorial demonstrates step-by-step instructions for setting up the environment and then training the model on a Hugging Face dataset using SFT.
@@ -64,16 +65,19 @@ export TRAIN_DATA_COLUMNS=<data columns to train on> # e.g., ['messages']
 ```
 
 ## Get your model checkpoint
+
 This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.
 
 ### Option 1: Using an existing MaxText checkpoint
+
 If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.
 
 ```sh
 export PRE_TRAINED_MODEL_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
 ```
 
 ### Option 2: Converting a Hugging Face checkpoint
+
 If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.
 
 1. **Set the Output Path:** First, define where the converted MaxText checkpoint will be saved. For example:
@@ -101,10 +105,11 @@ export PRE_TRAINED_MODEL_CKPT_PATH=${PRE_TRAINED_MODEL_CKPT_DIRECTORY}/0/items
 ```
 
 ## Run SFT on Hugging Face Dataset
+
 Now you are ready to run SFT using the following command:
 
 ```sh
-python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml \
+python3 -m maxtext.trainers.post_train.sft.train_sft src/MaxText/configs/sft.yml \
   run_name=${RUN_NAME} \
   base_output_directory=${BASE_OUTPUT_DIRECTORY} \
   model_name=${PRE_TRAINED_MODEL} \
@@ -118,4 +123,5 @@ python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml \
   train_data_columns=${TRAIN_DATA_COLUMNS} \
   profiler=xplane
 ```
+
 Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`.
````
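
The tutorial's entry point changes from `MaxText.sft.sft_trainer` to `maxtext.trainers.post_train.sft.train_sft`. A hedged sanity check, assuming the maxtext package from this commit is installed, that resolves the new module without executing it:

```python
# Locate the renamed entry point without importing (and thus running) it.
import importlib.util

spec = importlib.util.find_spec("maxtext.trainers.post_train.sft.train_sft")
print(spec.origin if spec else "module not found")
# Expected to end with: maxtext/trainers/post_train/sft/train_sft.py
```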

docs/tutorials/posttraining/sft_on_multi_host.md

Lines changed: 27 additions & 5 deletions
````diff
@@ -15,6 +15,7 @@
 -->
 
 # SFT on multi-host TPUs
+
 Supervised fine-tuning (SFT) is a process where a pre-trained large language model is fine-tuned on a labeled dataset to adapt the model to perform better on specific tasks.
 
 This tutorial demonstrates step-by-step instructions for setting up the multi-host TPU environment and then training the model on the Hugging Face dataset using SFT. In this tutorial we use a multi-host TPU such as `v6e-256`.
@@ -24,16 +25,20 @@ We use [Tunix](https://github.com/google/tunix), a JAX-based library designed fo
 Let's get started!
 
 ## 1. Build and upload MaxText Docker image
+
 This section guides you through cloning the MaxText repository, building MaxText Docker image with dependencies, and uploading the docker image to your project's Artifact Registry.
 
 ### 1.1. Clone the MaxText repository
+
 ```bash
 git clone https://github.com/google/maxtext.git
 cd maxtext
 ```
 
 ### 1.2. Build MaxText Docker image
+
 Before building the Docker image, authenticate to [Google Artifact Registry](https://docs.cloud.google.com/artifact-registry/docs/docker/authentication#gcloud-helper) for permission to push your images and other access.
+
 ```bash
 # Authenticate your user account for gcloud CLI access
 gcloud auth login
@@ -43,26 +48,34 @@ gcloud auth application-default login
 gcloud auth configure-docker
 docker run hello-world
 ```
+
 Then run the following command to create a local Docker image named `maxtext_base_image`. This build process takes approximately 10 to 15 minutes.
+
 ```bash
 bash dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training
 ```
 
 ### 1.3. Upload the Docker image to Artifact Registry
+
 > **Note:** You will need the [**Artifact Registry Writer**](https://docs.cloud.google.com/artifact-registry/docs/access-control#permissions) role to push Docker images to your project's Artifact Registry and to allow the cluster to pull them during workload execution. If you don't have this permission, contact your project administrator to grant you this role through "Google Cloud Console -> IAM -> Grant access".
+
 ```bash
 export DOCKER_IMAGE_NAME=<Docker Image Name>
 bash dependencies/scripts/docker_upload_runner.sh CLOUD_IMAGE_NAME=$DOCKER_IMAGE_NAME
 ```
+
 The `docker_upload_runner.sh` script uploads your Docker image to Artifact Registry.
 
 ## 2. Install XPK
-Install XPK by following the instructions in the [official documentation](https://github.com/AI-Hypercomputer/xpk/blob/main/docs/installation.md).
+
+Install XPK by following the instructions in the [official documentation](https://github.com/AI-Hypercomputer/xpk/blob/main/docs/installation.md).
 
 ## 3. Create GKE cluster
+
 Use a pathways ready GKE cluster as described [here](https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster).
 
 ## 4. Environment configuration
+
 ```bash
 # -- Google Cloud Configuration --
 export PROJECT=<Google Cloud Project ID>
@@ -91,19 +104,24 @@ export TRAIN_DATA_COLUMNS=<Data Columns to Train on> # e.g., ['messages']
 ```
 
 ## 5. Get MaxText model checkpoint
+
 This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.
 
 ### Option 1: Using an existing MaxText checkpoint
+
 If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.
 
 ```bash
 export MODEL_CHECKPOINT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
 ```
+
 **Note:** Make sure that `MODEL_CHECKPOINT_PATH` has the checkpoints created using the correct storage flags:
-* **For SFT with McJAX:** `checkpoint_storage_use_zarr3=True` and `checkpoint_storage_use_ocdbt=True`.
-* **For SFT with Pathways:** `checkpoint_storage_use_zarr3=False` and `checkpoint_storage_use_ocdbt=False`.
+
+- **For SFT with McJAX:** `checkpoint_storage_use_zarr3=True` and `checkpoint_storage_use_ocdbt=True`.
+- **For SFT with Pathways:** `checkpoint_storage_use_zarr3=False` and `checkpoint_storage_use_ocdbt=False`.
 
 ### Option 2: Converting a Hugging Face checkpoint
+
 If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.
 
 1. **Set the Output Path:** First, define where the converted MaxText checkpoint will be saved. For example:
@@ -137,9 +155,11 @@ export MODEL_CHECKPOINT_PATH=${MODEL_CHECKPOINT_DIRECTORY}/0/items
 ```
 
 ## 6. Submit workload on GKE cluster
+
 This section provides the command to run SFT on a GKE cluster.
 
 ### 6.1. SFT with Multi-Controller JAX (McJAX)
+
 ```bash
 xpk workload create \
   --cluster=${CLUSTER_NAME} \
@@ -149,11 +169,13 @@ xpk workload create \
   --workload=${WORKLOAD_NAME} \
   --tpu-type=${TPU_TYPE} \
   --num-slices=${TPU_SLICE} \
-  --command "python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=$TOKENIZER_PATH per_device_batch_size=1 steps=$STEPS profiler=xplane hf_path=$DATASET_NAME train_split=$TRAIN_SPLIT train_data_columns=$TRAIN_DATA_COLUMNS"
+  --command "python3 -m maxtext.trainers.post_train.sft.train_sft src/MaxText/configs/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=$TOKENIZER_PATH per_device_batch_size=1 steps=$STEPS profiler=xplane hf_path=$DATASET_NAME train_split=$TRAIN_SPLIT train_data_columns=$TRAIN_DATA_COLUMNS"
 ```
+
 Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.
 
 ### 6.2. SFT with Pathways
+
 ```bash
 xpk workload create-pathways \
   --cluster=${CLUSTER_NAME} \
@@ -163,7 +185,7 @@ xpk workload create-pathways \
   --workload=${WORKLOAD_NAME} \
   --tpu-type=${TPU_TYPE} \
   --num-slices=${TPU_SLICE} \
-  --command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=$TOKENIZER_PATH per_device_batch_size=1 steps=$STEPS profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True"
+  --command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft src/MaxText/configs/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=$TOKENIZER_PATH per_device_batch_size=1 steps=$STEPS profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True"
 ```
 
 Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.
````

end_to_end/tpu/llama3.1/8b/run_sft.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -57,7 +57,7 @@ fi
 echo "Running fine-tuning on checkpoint: ${PRE_TRAINED_MODEL_CKPT_PATH}"
 
 # Run Supervised Fine-Tuning on MaxText checkpoint using HuggingFaceH4/ultrachat_200k dataset
-python3 -m MaxText.sft.sft_trainer "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/sft.yml \
+python3 -m maxtext.trainers.post_train.sft.train_sft "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/sft.yml \
   run_name=${RUN_NAME} base_output_directory=${BASE_OUTPUT_DIRECTORY}/${PRE_TRAINED_MODEL} \
   model_name=${PRE_TRAINED_MODEL} load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH} \
   hf_access_token=$HF_TOKEN tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER} \
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -38,7 +38,7 @@ Repository = "https://github.com/AI-Hypercomputer/maxtext.git"
 allow-direct-references = true
 
 [tool.hatch.build.targets.wheel]
-packages = ["src/MaxText", "src/install_maxtext_extra_deps"]
+packages = ["src/MaxText", "src/maxtext", "src/install_maxtext_extra_deps"]
 
 [tool.hatch.build.targets.wheel.hooks.custom]
 path = "build_hooks.py"
```
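
During the transition the wheel ships both the legacy `src/MaxText` tree and the new `src/maxtext` tree. A minimal sketch, assuming a wheel built from this `pyproject.toml` is installed, to confirm both top-level packages resolve:

```python
# Check that both package roots listed in [tool.hatch.build.targets.wheel]
# are present in the installed environment.
import importlib.util

for pkg in ("MaxText", "maxtext"):
    print(pkg, "->", "found" if importlib.util.find_spec(pkg) else "missing")
```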

src/MaxText/examples/sft_llama3_demo.ipynb

Lines changed: 3 additions & 2 deletions
```diff
@@ -149,7 +149,7 @@
 "import sys\n",
 "import MaxText\n",
 "from MaxText import pyconfig\n",
-"from MaxText.sft.sft_trainer import train as sft_train\n",
+"from maxtext.trainers.post_train.sft import train_sft\n",
 "import jax\n",
 "from huggingface_hub import login\n",
 "\n",
@@ -173,6 +173,7 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
+"outputs": [],
 "source": [
 "if IN_COLAB:\n",
 "    HF_TOKEN = userdata.get(\"HF_TOKEN\")\n",
@@ -312,7 +313,7 @@
 "print(\"=\" * 60)\n",
 "\n",
 "try:\n",
-"    trainer, mesh = sft_train(config)\n",
+"    trainer, mesh = train_sft.train(config)\n",
 "\n",
 "    print(\"\\n\" + \"=\" * 60)\n",
 "    print(\"✅ Training Completed Successfully!\")\n",
```

src/MaxText/examples/sft_qwen3_demo.ipynb

Lines changed: 3 additions & 3 deletions
```diff
@@ -201,7 +201,7 @@
 "from MaxText import pyconfig\n",
 "from MaxText.examples.sft_train_and_evaluate import evaluate_model, get_test_dataset\n",
 "from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter\n",
-"from MaxText.sft import sft_trainer\n",
+"from maxtext.trainers.post_train.sft import train_sft\n",
 "\n",
 "# Suppress vLLM logging with a severity level below ERROR\n",
 "os.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n",
@@ -451,7 +451,7 @@
 },
 "outputs": [],
 "source": [
-"trainer, mesh = sft_trainer.setup_trainer_state(config)"
+"trainer, mesh = train_sft.setup_trainer_state(config)"
 ]
 },
 {
@@ -545,7 +545,7 @@
 "outputs": [],
 "source": [
 "print(\"Starting SFT Training...\")\n",
-"trainer = sft_trainer.train_model(config, trainer, mesh)\n",
+"trainer = train_sft.train_model(config, trainer, mesh)\n",
 "print(\"SFT Training Complete!\")"
 ]
 },
```

src/MaxText/examples/sft_train_and_evaluate.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -92,7 +92,7 @@
 from MaxText import pyconfig
 from MaxText.input_pipeline import instruction_data_processing
 from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
-from MaxText.sft import sft_trainer
+from maxtext.trainers.post_train.sft import train_sft
 
 # Suppress vLLM logging with a severity level below ERROR
 os.environ["VLLM_LOGGING_LEVEL"] = "ERROR"
@@ -330,7 +330,7 @@ def train_and_evaluate(config):
   test_dataset = get_test_dataset(config, tokenizer)
   test_dataset = test_dataset[:NUM_TEST_SAMPLES]
   test_dataset = test_dataset.to_iter_dataset().batch(BATCH_SIZE, drop_remainder=True)
-  trainer, mesh = sft_trainer.setup_trainer_state(config)
+  trainer, mesh = train_sft.setup_trainer_state(config)
   vllm_rollout = create_vllm_rollout(config, trainer.model, mesh, tokenizer)
 
   # 1. Pre-SFT Evaluation
@@ -340,7 +340,7 @@ def train_and_evaluate(config):
 
   # 2. SFT Training
   max_logging.log("Starting SFT training...")
-  trainer = sft_trainer.train_model(config, trainer, mesh)
+  trainer = train_sft.train_model(config, trainer, mesh)
 
   # 3. Post-SFT Evaluation
   max_logging.log("Running Post-SFT evaluation...")
```
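
The updated call sites show two ways to drive the relocated trainer: the llama3 notebook uses the one-shot `train_sft.train(config)`, while the qwen3 notebook and this script split setup from training so evaluation can run before and after. A side-by-side sketch (construction of `config` via `pyconfig` elided):

```python
from maxtext.trainers.post_train.sft import train_sft

# One-shot path, as in sft_llama3_demo.ipynb: build state and train in one call.
trainer, mesh = train_sft.train(config)  # config: a pyconfig-built config (elided)

# Two-step path, as in sft_qwen3_demo.ipynb and sft_train_and_evaluate.py:
# the trainer and mesh are available for pre-SFT evaluation or vLLM rollout
# setup before the training loop runs.
trainer, mesh = train_sft.setup_trainer_state(config)
trainer = train_sft.train_model(config, trainer, mesh)
```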
