
Commit 339dc7e

Merge pull request #2963 from CIeNET-International:charlesli/add_ckpt_doc
PiperOrigin-RevId: 862282930
2 parents e8cbb57 + 47c4264 commit 339dc7e

9 files changed

Lines changed: 268 additions & 321 deletions


.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
@@ -59,3 +59,4 @@ repos:
      - id: mdformat
        additional_dependencies: [mdformat-myst, mdformat-ruff]
        files: (docs/.)
+       exclude: docs/guides/checkpointing_solutions.md

docs/guides/checkpointing_solutions.md

Lines changed: 13 additions & 3 deletions
@@ -1,4 +1,5 @@
(checkpointing_solutions)=
+
# Checkpointing

::::{grid} 1 2 2 2
@@ -24,13 +25,22 @@ Handle preemption and recover training progress.
Optimize storage costs and performance with multi-tier usage.
:::
+
+:::{grid-item-card} 🔁 Checkpoint conversion utilities
+:link: checkpointing_solutions/convert_checkpoint
+:link-type: doc
+
+Convenient tools to convert between Hugging Face and MaxText checkpoints.
+:::
::::

```{toctree}
-:hidden:
-:maxdepth: 1
+---
+hidden:
+maxdepth: 1
+---
checkpointing_solutions/gcs_checkpointing.md
checkpointing_solutions/emergency_checkpointing.md
checkpointing_solutions/multi_tier_checkpointing.md
+checkpointing_solutions/convert_checkpoint.md
```
docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 238 additions & 0 deletions
@@ -0,0 +1,238 @@
# Checkpoint conversion utilities

This guide provides instructions for using the [scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/utils/ckpt_conversion) that convert model checkpoints bidirectionally between Hugging Face and MaxText formats.

## Supported models

The following models are supported:
| Model Family            | Sizes                  | HF → Orbax (scan) | HF → Orbax (unscan) | Orbax (scan) → HF | Orbax (unscan) → HF |
| :---------------------- | :--------------------- | :---------------: | :-----------------: | :---------------: | :-----------------: |
| **Gemma2**              | 2B, 9B, 27B            |         ✓         |          ✓          |         ✓         |          ✓          |
| **Gemma3** (Multimodal) | 4B, 12B, 27B           |         -         |          ✓          |         -         |          ✓          |
| **Llama3.1**            | 8B, 70B, 405B          |         ✓         |          ✓          |         ✓         |          ✓          |
| **Qwen3**               | 0.6B, 4B, 8B, 14B, 32B |         ✓         |          ✓          |         ✓         |          ✓          |
| **Qwen3 MoE**           | 30B, 235B, 480B        |         ✓         |          ✓          |         ✓         |          ✓          |
| **Mixtral**             | 8x7B, 8x22B            |         ✓         |          ✓          |         ✓         |          ✓          |
| **GPT-OSS**             | 20B, 120B              |         ✓         |          ✓          |         ✓         |          ✓          |
| **DeepSeek3**           | 671B                   |         -         |          -          |         ✓         |          -          |
## Prerequisites

- Hugging Face requires PyTorch.
- Hugging Face model checkpoints require local disk space.
  - The model files are always downloaded to a disk cache before being loaded into memory (for more information, consult the Hugging Face [docs](https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference)). The default local storage path for Hugging Face models is `$HOME/.cache/huggingface/hub`.
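If the default cache disk is too small, you can pre-fetch the weights onto a larger disk before running the converter. A minimal sketch using the `huggingface_hub` package (a dependency of `transformers`); the repository ID and path below are examples only:

```python
from huggingface_hub import snapshot_download

# Pre-fetch the model weights into a cache directory of your choice
# (example repo id and path; any model and location work).
local_dir = snapshot_download(
    repo_id="Qwen/Qwen3-4B",
    cache_dir="/mnt/large-disk/hf_cache",
)
print("Model files cached at:", local_dir)
```

The resulting local directory can later be passed to the converter via the `--hf_model_path` flag described below, avoiding a re-download.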
## Hugging Face to MaxText

Use the `to_maxtext.py` script to convert a Hugging Face model into a MaxText checkpoint. The script automatically downloads the specified model from the Hugging Face Hub, performs the conversion, and saves the converted checkpoint to the given output directory.

*For complete examples, see the test scripts at [`end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh`](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh) and [`end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh`](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh).*
### Usage

First, make sure a Python 3 virtual environment for MaxText is set up and activated.

```bash
export VENV_NAME=<your virtual env name> # e.g., maxtext_venv
pip install uv
uv venv --python 3.12 --seed $VENV_NAME
source $VENV_NAME/bin/activate
```
Second, ensure you have the necessary dependencies installed (PyTorch for the conversion script).

```bash
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
```
Third, set up the following environment variables for the conversion script:

```bash
# -- Model configuration --
export HF_MODEL=<Hugging Face model to convert to MaxText> # e.g., 'llama3.1-8b-Instruct'
export HF_TOKEN=<Hugging Face access token> # your token to access gated HF repos

# -- MaxText configuration --
export MODEL_CHECKPOINT_DIRECTORY=<output directory for the converted checkpoint> # e.g., gs://my-bucket/my-checkpoint-directory

# -- Storage and format options --
export USE_ZARR3=<Flag to use zarr3> # Set to True to use the zarr3 format (recommended for McJAX); set to False for Pathways.
export USE_OCDBT=<Flag to use ocdbt> # Set to True to use the OCDBT format (recommended for McJAX); set to False for Pathways.

export LAZY_LOAD_TENSORS=<Flag to lazy load> # True to load tensors lazily, False to load eagerly.
```
Finally, run the command below to complete the conversion:

```bash
python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
    model_name=${HF_MODEL} \
    hf_access_token=${HF_TOKEN} \
    base_output_directory=${MODEL_CHECKPOINT_DIRECTORY} \
    scan_layers=True \
    use_multimodal=false \
    hardware=cpu \
    skip_jax_distributed_system=true \
    checkpoint_storage_use_zarr3=${USE_ZARR3} \
    checkpoint_storage_use_ocdbt=${USE_OCDBT} \
    --lazy_load_tensors=${LAZY_LOAD_TENSORS}
```
**Key arguments:**

- `model_name`: The model identifier, which should be defined in [`src/MaxText/utils/ckpt_conversion/utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py).
- `scan_layers`: Indicates whether the output checkpoint is [scanned](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/checkpoints.md) (`scan_layers=true`) or unscanned (`scan_layers=false`).
- `use_multimodal`: Indicates whether multimodality is used; important for Gemma3.
- `hf_access_token`: Your Hugging Face token.
- `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local. If not set, the default output directory is `Maxtext/tmp`.
- `hardware=cpu`: Runs the conversion script on a CPU machine.
- `checkpoint_storage_use_zarr3`: Set to True to use the zarr3 format (recommended for McJAX); set to False for Pathways.
- `checkpoint_storage_use_ocdbt`: Set to True to use the OCDBT format (recommended for McJAX); set to False for Pathways.
- `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on demand to minimize RAM usage. For large models, `--lazy_load_tensors=true` is recommended to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes.
- `--hf_model_path` (optional): Specifies a local directory containing the model weights. If unspecified, the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py#L58-L85) is used (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models such as GPT-OSS or DeepSeek.

The command above downloads the Hugging Face model to the local machine, converts it to the MaxText format, and saves it to `${MODEL_CHECKPOINT_DIRECTORY}/0/items`.
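As a quick sanity check before moving on, you can confirm the Orbax files landed where you expect. A minimal sketch, assuming the destination is a GCS bucket and the third-party `gcsfs` package is installed (the bucket path is an example):

```python
import gcsfs

fs = gcsfs.GCSFileSystem()
# The converter writes the checkpoint under step 0 by default.
print(fs.ls("my-bucket/my-checkpoint-directory/0/items"))
```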
## MaxText to Hugging Face

Use the `to_huggingface.py` script to convert a MaxText checkpoint into the Hugging Face format. This is useful for sharing your models or integrating them with the Hugging Face ecosystem.

*For a complete example, see the test script at [`end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh`](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh).*
### Usage

The following command converts a MaxText checkpoint and saves it locally, to GCS, or uploads it directly to the Hugging Face Hub.

```bash
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
    model_name=<MODEL_NAME> \
    load_parameters_path=<path-to-maxtext-checkpoint> \
    base_output_directory=<path-to-save-converted-checkpoint> \
    scan_layers=false \
    use_multimodal=false \
    hf_access_token=<your-hf-token> \
    weight_dtype=bfloat16
```
**Key arguments:**

- `load_parameters_path`: The path to the source MaxText Orbax checkpoint (e.g., `gs://your-bucket/maxtext-checkpoint/0/items`).
- `model_name`: The corresponding model name in the MaxText configuration (e.g., `qwen3-4b`).
- `scan_layers`: Indicates whether the source MaxText checkpoint is [scanned](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/checkpoints.md) (`scan_layers=true`) or unscanned (`scan_layers=false`).
- `hf_access_token`: Your Hugging Face token.
- `use_multimodal`: Indicates whether multimodality is used; important for Gemma3.
- `base_output_directory`: The path where the converted Hugging Face checkpoint will be stored; it can be Google Cloud Storage (GCS), the Hugging Face Hub, or local. If not set, the default output directory is `Maxtext/tmp`.
- `weight_dtype`: The dtype for MaxText weights; it determines the resulting Hugging Face weight dtype. The default is `float32`. We recommend `bfloat16` to save memory and speed up conversion.
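After converting, you can smoke-test the output by loading it with `transformers`. A minimal sketch, assuming the checkpoint was written to a local directory (the path is an example):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_dir = "/tmp/maxtext-to-hf/qwen3-4b"  # example output directory

tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model = AutoModelForCausalLM.from_pretrained(ckpt_dir, torch_dtype="auto")

# Generate a few tokens to confirm the weights load and run end to end.
inputs = tokenizer("What is the", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=8)[0]))
```

For a more rigorous check, use the logit comparison described in the next section.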
## Verifying conversion correctness

To ensure the conversion was successful, you can use the `tests/utils/forward_pass_logit_checker.py` script. It runs a forward pass on both the original and converted models and compares the output logits; it works for both conversion directions.

### Usage

```bash
python3 -m tests.utils.forward_pass_logit_checker src/MaxText/configs/base.yml \
    tokenizer_path=assets/<tokenizer> \
    load_parameters_path=<path-to-maxtext-checkpoint> \
    model_name=<MODEL_NAME> \
    scan_layers=false \
    max_prefill_predict_length=4 \
    max_target_length=8 \
    use_multimodal=false \
    --run_hf_model=True \
    --hf_model_path=<path-to-HF-checkpoint> \
    --max_kl_div=0.015
```
**Key arguments:**

- `load_parameters_path`: The path to the source MaxText Orbax checkpoint (e.g., `gs://your-bucket/maxtext-checkpoint/0/items`).
- `model_name`: The corresponding model name in the MaxText configuration (e.g., `qwen3-4b`).
- `scan_layers`: Indicates whether the MaxText checkpoint is scanned (`scan_layers=true`) or unscanned (`scan_layers=false`).
- `use_multimodal`: Indicates whether multimodality is used.
- `--run_hf_model`: Indicates whether to load the Hugging Face model from `hf_model_path`. If not set, the script compares the MaxText logits with pre-saved golden logits.
- `--hf_model_path`: The path to the Hugging Face checkpoint.
- `--max_kl_div`: The maximum KL divergence tolerated during comparison.
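For reference, the quantity thresholded by `--max_kl_div` is the KL divergence between the two models' next-token distributions. A minimal numpy sketch of the metric (illustrative; not the script's exact implementation):

```python
import numpy as np

def kl_divergence(golden_logits, model_logits):
    """D_KL(P_golden || Q_model) over the vocabulary axis."""
    p = np.exp(golden_logits - golden_logits.max())
    p /= p.sum()
    q = np.exp(model_logits - model_logits.max())
    q /= q.sum()
    return float(np.sum(p * (np.log(p) - np.log(q))))

rng = np.random.default_rng(0)
golden = rng.normal(size=128)                 # toy "golden" logits
noisy = golden + 0.01 * rng.normal(size=128)  # near-identical model logits
print(kl_divergence(golden, noisy))           # tiny value, well under 0.015
```

A well-converted checkpoint should yield per-token divergences close to zero, as in the example output below.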
**Example successful conversion verification:**

Here is part of the output of `forward_pass_logit_checker` for gemma2-2b:

```
--- Prompt: What is the ---

--- MaxText model top 10 tokens ---
| Token ID   | Token                | Score      |
|------------|----------------------|------------|
| 5830       | difference           | 27.2500    |
| 1963       | best                 | 26.6250    |
| 5316       | average              | 26.3750    |
| 2669       | change               | 26.1250    |
| 12070      | percentage           | 26.1250    |
| 1618       | value                | 25.8750    |
| 1546       | most                 | 25.7500    |
| 66202      | molar                | 25.5000    |
| 3051       | total                | 25.5000    |
| 1503       | name                 | 25.3750    |

--- HF model top 10 tokens ---
| Token ID   | Token                | Score      |
|------------|----------------------|------------|
| 5830       | difference           | 27.2500    |
| 1963       | best                 | 26.6250    |
| 5316       | average              | 26.3750    |
| 12070      | percentage           | 26.1250    |
| 2669       | change               | 26.1250    |
| 1618       | value                | 25.8750    |
| 1546       | most                 | 25.7500    |
| 66202      | molar                | 25.5000    |
| 3051       | total                | 25.5000    |
| 6187       | purpose              | 25.3750    |

--- Similarity Metrics of Top Tokens ---
| Metric                         | Value                |
|--------------------------------|----------------------|
| overlap_count                  | 9/10                 |
| jaccard_similarity             | 0.8181818181818182   |
| rank_agreement_percentage      | 70.0                 |

Average KL divergence per token (D_KL(P_golden || Q_model)): 0.000409

Max KL divergence for a single token in the set: 0.003497
```
______________________________________________________________________
## Adding support for new models

To extend conversion support to a new model architecture, you must define its parameter and configuration mappings. The conversion logic is decoupled, so you only need to modify the mapping files.

1. **Add parameter mappings**:

   - In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py), add the parameter name mappings (`def {MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING`). These are the one-to-one mappings of parameter names per layer (see the sketch after this list).
   - In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer.

2. **Add Hugging Face weight shapes**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py), define the tensor shapes of the Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shapes match after the to_huggingface conversion.

3. **Register the model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py), add the new model key to `HF_IDS`.

4. **Add the transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py), add the `transformers.Config` object describing the Hugging Face model configuration (the MaxText side is defined in [`src/MaxText/configs/models`](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture.
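For orientation, here is a sketch of the general shape of the mapping and `hook_fn` entries for a hypothetical toy model. Every name, parameter path, and transform below is an illustrative assumption, not one of the repository's actual entries:

```python
# Hypothetical toy-model mapping; names and paths are illustrative only.

def TOYMODEL_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False):
    """One-to-one parameter-name mapping, MaxText -> Hugging Face."""
    mapping = {
        "params-token_embedder-embedding": "model.embed_tokens.weight",
        "params-decoder-decoder_norm-scale": "model.norm.weight",
    }
    for i in range(config["num_hidden_layers"]):
        mapping[f"params-decoder-layers_{i}-self_attention-query-kernel"] = (
            f"model.layers.{i}.self_attn.q_proj.weight"
        )
    return mapping

def TOYMODEL_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False):
    """Per-parameter transforms applied during conversion (e.g., transposes)."""
    def transpose_kernel(x):
        # Kernels are often stored transposed relative to the HF layout.
        return x.T
    return {
        f"params-decoder-layers_{i}-self_attention-query-kernel": transpose_kernel
        for i in range(config["num_hidden_layers"])
    }
```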
Here is an example [PR adding support for the gemma3 multimodal model](https://github.com/AI-Hypercomputer/maxtext/pull/1983).
## Debugging tips

If the converted checkpoint cannot be loaded and you get an error like "type `<class 'jax._src.core.ShapeDtypeStruct'>` is not a valid JAX type":

- **Potential cause**: The `scan_layers` flag does not match the checkpoint (scanned vs. unscanned).

If a converted checkpoint loads without errors but produces incorrect output, consider these common issues:

- **Symptom**: The model generates garbage or nonsensical tokens.

  - **Potential cause**: The query/key/value (Q/K/V) or output projection weights were likely reshaped incorrectly during conversion (see the toy example below).

- **Symptom**: The model generates repetitive text sequences.

  - **Potential cause**: The layer normalization parameters may have been converted incorrectly.
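To see why a Q/K/V reshape bug produces shape-correct but garbage weights, consider this toy numpy example: two different reshapes yield the same final shape, yet the per-head contents are scrambled.

```python
import numpy as np

hidden, num_heads, head_dim = 8, 2, 4
w = np.arange(hidden * num_heads * head_dim).reshape(hidden, num_heads * head_dim)

# Correct: split the fused axis as (num_heads, head_dim).
correct = w.reshape(hidden, num_heads, head_dim)

# Incorrect: split as (head_dim, num_heads) then transpose. The final shape
# matches (hidden, num_heads, head_dim), but the per-head weights differ.
wrong = w.reshape(hidden, head_dim, num_heads).transpose(0, 2, 1)

print(correct.shape == wrong.shape)    # True
print(np.array_equal(correct, wrong))  # False: silently scrambled weights
```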

docs/tutorials/posttraining/full_finetuning.md

Lines changed: 7 additions & 23 deletions
@@ -39,6 +39,7 @@ source $VENV_NAME/bin/activate
uv pip install -e .[tpu] --resolution=lowest
install_maxtext_github_deps
```
+

## Setup environment variables

```sh
@@ -53,40 +54,23 @@ export RUN_NAME=<name for this run> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
```

## Hugging Face checkpoint to MaxText checkpoint
+
This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.

### Option 1: Using an existing MaxText checkpoint
+
If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

```sh
export MODEL_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
```

### Option 2: Converting a Hugging Face checkpoint
-If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.
-
-1. **Set the Output Path:** First, define where the converted MaxText checkpoint will be saved. For example:
-
-   ```sh
-   export MODEL_CKPT_DIRECTORY=${BASE_OUTPUT_DIRECTORY}/maxtext-checkpoint
-   ```

-2. **Run the Conversion Script:** Execute the following command that downloads the specified Hugging Face model and converts its weights into the MaxText format. The conversion script only supports official versions of models from Hugging Face. To see the specific models and versions currently supported for conversion, refer to the `HF_IDS` dictionary in the MaxText utility file [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py).
+Refer to the steps in [Hugging Face to MaxText](../../guides/checkpointing_solutions/convert_checkpoint.md#hugging-face-to-maxtext) to convert a Hugging Face checkpoint to MaxText. Make sure the checkpoint files are converted and saved correctly. As in Option 1, set the following environment variable and move on.

-   ```sh
-   python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu # Ensure torch is installed for the conversion script
-
-   python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
-     model_name=${MODEL_NAME} \
-     hf_access_token=${HF_TOKEN} \
-     base_output_directory=${MODEL_CKPT_DIRECTORY} \
-     scan_layers=True skip_jax_distributed_system=True
-   ```
-
-3. **Use the Converted Checkpoint:** Set the following environment variable to use the converted checkpoint:
-
-   ```sh
-   export MODEL_CKPT_PATH=${MODEL_CKPT_DIRECTORY}/0/items
+```bash
+export MODEL_CKPT_PATH=<gcs path for MaxText checkpoint> # gs://my-bucket/my-checkpoint-directory/0/items
```

## Dataset
@@ -98,7 +82,7 @@ MaxText provides examples to work with [Common Crawl](https://commoncrawl.org/).
Run these steps once per project prior to any local development or cluster experiments.

1. Create two GCS buckets in your project, one for downloading and retrieving the dataset and the other for storing the logs.
-2. Download the dataset in your gcs bucket.
+1. Download the dataset in your GCS bucket.

MaxText assumes these GCS buckets are created in the same project and that it has permissions to read and write from them.