# Checkpoint conversion utilities

This guide provides instructions for using the [scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/utils/ckpt_conversion) that convert model checkpoints bidirectionally between Hugging Face and MaxText formats.

## Supported models

The following models are supported:

| Model Family | Sizes | HF → Orbax (scan) | HF → Orbax (unscan) | Orbax (scan) → HF | Orbax (unscan) → HF |
| :---------------------- | :--------------------- | :---------------: | :-----------------: | :---------------: | :-----------------: |
| **Gemma2** | 2B, 9B, 27B | ✓ | ✓ | ✓ | ✓ |
| **Gemma3** (Multimodal) | 4B, 12B, 27B | - | ✓ | - | ✓ |
| **Llama3.1** | 8B, 70B, 405B | ✓ | ✓ | ✓ | ✓ |
| **Qwen3** | 0.6B, 4B, 8B, 14B, 32B | ✓ | ✓ | ✓ | ✓ |
| **Qwen3 MoE** | 30B, 235B, 480B | ✓ | ✓ | ✓ | ✓ |
| **Mixtral** | 8x7B, 8x22B | ✓ | ✓ | ✓ | ✓ |
| **GPT-OSS** | 20B, 120B | ✓ | ✓ | ✓ | ✓ |
| **DeepSeek3** | 671B | - | - | ✓ | - |

## Prerequisites

- Hugging Face model loading requires PyTorch.
- Hugging Face model checkpoints require local disk space.
  - Model files are always downloaded to a local disk cache before being loaded into memory (for details, see the Hugging Face [docs](https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference)). The default local storage path for Hugging Face models is `$HOME/.cache/huggingface/hub`.

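To estimate whether you have enough disk space before downloading, the default cache location and free space can be checked with a short stdlib sketch (the `HF_HOME` override shown here follows Hugging Face's documented cache behavior; this helper itself is not part of MaxText):

```python
import os
import shutil

def hf_cache_dir() -> str:
    """Resolve the Hugging Face hub cache directory.

    Mirrors the documented defaults: HF_HOME (if set) overrides
    ~/.cache/huggingface; the hub cache lives in its "hub" subdirectory.
    """
    hf_home = os.environ.get(
        "HF_HOME",
        os.path.join(os.path.expanduser("~"), ".cache", "huggingface"),
    )
    return os.path.join(hf_home, "hub")

def free_disk_gib(path: str) -> float:
    """Free space in GiB on the filesystem containing `path`."""
    return shutil.disk_usage(path).free / 2**30

print("HF hub cache:", hf_cache_dir())
print(f"Free disk: {free_disk_gib(os.path.expanduser('~')):.1f} GiB")
```
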
## Hugging Face to MaxText

Use the `to_maxtext.py` script to convert a Hugging Face model into a MaxText checkpoint. The script automatically downloads the specified model from the Hugging Face Hub, performs the conversion, and saves the converted checkpoint to the given output directory.

*For complete examples, see the test scripts at [`end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh`](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh) and [`end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh`](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh).*

### Usage

First, make sure a Python 3 virtual environment for MaxText is set up and activated.

```bash
export VENV_NAME=<your virtual env name> # e.g., maxtext_venv
pip install uv
uv venv --python 3.12 --seed $VENV_NAME
source $VENV_NAME/bin/activate
```

Second, ensure you have the necessary dependencies installed (PyTorch for the conversion script).

```bash
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
```

Third, set the following environment variables for the conversion script:

```bash
# -- Model configuration --
export HF_MODEL=<Hugging Face model to be converted to MaxText> # e.g., 'llama3.1-8b-Instruct'
export HF_TOKEN=<Hugging Face access token> # your token to access gated HF repos

# -- MaxText configuration --
export MODEL_CHECKPOINT_DIRECTORY=<output directory to store the converted checkpoint> # e.g., gs://my-bucket/my-checkpoint-directory

# -- Storage and format options --
export USE_ZARR3=<flag to use zarr3> # Set to True to use the zarr3 format (recommended for McJAX); set to False for Pathways.
export USE_OCDBT=<flag to use ocdbt> # Set to True to use the OCDBT format (recommended for McJAX); set to False for Pathways.

export LAZY_LOAD_TENSORS=<flag to lazy load> # True to load tensors lazily, False to load them eagerly.
```

Finally, run the command below to perform the conversion:

```bash
python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
  model_name=${HF_MODEL} \
  hf_access_token=${HF_TOKEN} \
  base_output_directory=${MODEL_CHECKPOINT_DIRECTORY} \
  scan_layers=True \
  use_multimodal=false \
  hardware=cpu \
  skip_jax_distributed_system=true \
  checkpoint_storage_use_zarr3=${USE_ZARR3} \
  checkpoint_storage_use_ocdbt=${USE_OCDBT} \
  --lazy_load_tensors=${LAZY_LOAD_TENSORS}
```

**Key arguments:**

- `model_name`: The model identifier, which should be defined in `src/MaxText/utils/ckpt_conversion/utils/utils.py`.
- `scan_layers`: Indicates whether the output checkpoint is [scanned](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/checkpoints.md) (`scan_layers=true`) or unscanned (`scan_layers=false`).
- `use_multimodal`: Indicates whether multimodality is used; important for Gemma3.
- `hf_access_token`: Your Hugging Face token.
- `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local. If not set, the default output directory is `MaxText/tmp`.
- `hardware=cpu`: Run the conversion script on a CPU machine.
- `checkpoint_storage_use_zarr3`: Set to True to use the zarr3 format (recommended for McJAX); set to False for Pathways.
- `checkpoint_storage_use_ocdbt`: Set to True to use the OCDBT format (recommended for McJAX); set to False for Pathways.
- `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on demand to minimize RAM usage; recommended for large models. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes.
- `--hf_model_path` (optional): Specifies a local directory containing the model weights. If unspecified, the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py#L58-L85) is used (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.

The above command downloads the Hugging Face model to the local machine, converts it to the MaxText format, and saves it to `${MODEL_CHECKPOINT_DIRECTORY}/0/items`.

## MaxText to Hugging Face

Use the `to_huggingface.py` script to convert a MaxText checkpoint into the Hugging Face format. This is useful for sharing your models or integrating them with the Hugging Face ecosystem.

*For a complete example, see the test script at [`end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh`](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh).*

### Usage

The following command converts a MaxText checkpoint and saves it locally, to GCS, or uploads it directly to the Hugging Face Hub.

```bash
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
  model_name=<MODEL_NAME> \
  load_parameters_path=<path-to-maxtext-checkpoint> \
  base_output_directory=<path-to-save-converted-checkpoint> \
  scan_layers=false \
  use_multimodal=false \
  hf_access_token=<your-hf-token> \
  weight_dtype=bfloat16
```

**Key arguments:**

- `load_parameters_path`: The path to the source MaxText Orbax checkpoint (e.g., `gs://your-bucket/maxtext-checkpoint/0/items`).
- `model_name`: The corresponding model name in the MaxText configuration (e.g., `qwen3-4b`).
- `scan_layers`: Indicates whether the source checkpoint is [scanned](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/checkpoints.md) (`scan_layers=true`) or unscanned (`scan_layers=false`).
- `hf_access_token`: Your Hugging Face token.
- `use_multimodal`: Indicates whether multimodality is used; important for Gemma3.
- `base_output_directory`: The path where the converted Hugging Face checkpoint will be stored; it can be Google Cloud Storage (GCS), the Hugging Face Hub, or local. If not set, the default output directory is `MaxText/tmp`.
- `weight_dtype`: The dtype for MaxText weights, which determines the resulting Hugging Face weight dtype. The default is `float32`; we recommend `bfloat16` to save memory and speed up conversion.

## Verifying conversion correctness

To ensure the conversion was successful, you can use the `tests/utils/forward_pass_logit_checker.py` script. It runs a forward pass on both the original and converted models and compares the output logits, and it works for both conversion directions.

### Usage

```bash
python3 -m tests.utils.forward_pass_logit_checker src/MaxText/configs/base.yml \
  tokenizer_path=assets/<tokenizer> \
  load_parameters_path=<path-to-maxtext-checkpoint> \
  model_name=<MODEL_NAME> \
  scan_layers=false \
  max_prefill_predict_length=4 \
  max_target_length=8 \
  use_multimodal=false \
  --run_hf_model=True \
  --hf_model_path=<path-to-HF-checkpoint> \
  --max_kl_div=0.015
```

**Key arguments:**

- `load_parameters_path`: The path to the source MaxText Orbax checkpoint (e.g., `gs://your-bucket/maxtext-checkpoint/0/items`).
- `model_name`: The corresponding model name in the MaxText configuration (e.g., `qwen3-4b`).
- `scan_layers`: Indicates whether the checkpoint is scanned (`scan_layers=true`) or unscanned (`scan_layers=false`).
- `use_multimodal`: Indicates whether multimodality is used.
- `--run_hf_model`: Whether to load the Hugging Face model from `--hf_model_path` and compare against it. If not set, the MaxText logits are compared with pre-saved golden logits.
- `--hf_model_path`: The path to the Hugging Face checkpoint.
- `--max_kl_div`: The maximum KL divergence tolerated during comparison.

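The KL-divergence check can be sketched in pure Python (a simplified illustration with hypothetical logits, not the actual implementation in `forward_pass_logit_checker.py`): each position's logits from both models are softmaxed, D_KL(P_golden || Q_model) is computed per token, and the maximum is compared against the tolerance.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def kl_div(p, q):
    """D_KL(P || Q) for two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical per-token logits from the "golden" and converted models.
golden_logits = [[27.25, 26.625, 26.375], [10.0, 9.5, 8.0]]
model_logits = [[27.25, 26.5, 26.375], [10.0, 9.4, 8.1]]

kls = [kl_div(softmax(g), softmax(m))
       for g, m in zip(golden_logits, model_logits)]
print(f"avg KL: {sum(kls) / len(kls):.6f}, max KL: {max(kls):.6f}")
assert max(kls) <= 0.015, "conversion check failed: logits diverge"
```

A small `--max_kl_div` such as 0.015 tolerates numerical noise (e.g., dtype casts) while still catching genuinely mis-mapped weights, which typically push KL divergence orders of magnitude higher.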
**Example successful conversion verification:**

Here is part of the output of `forward_pass_logit_checker` for gemma2-2b:

```
--- Prompt: What is the ---

--- MaxText model top 10 tokens ---
| Token ID   | Token                | Score      |
|------------|----------------------|------------|
| 5830       | difference           | 27.2500    |
| 1963       | best                 | 26.6250    |
| 5316       | average              | 26.3750    |
| 2669       | change               | 26.1250    |
| 12070      | percentage           | 26.1250    |
| 1618       | value                | 25.8750    |
| 1546       | most                 | 25.7500    |
| 66202      | molar                | 25.5000    |
| 3051       | total                | 25.5000    |
| 1503       | name                 | 25.3750    |

--- HF model top 10 tokens ---
| Token ID   | Token                | Score      |
|------------|----------------------|------------|
| 5830       | difference           | 27.2500    |
| 1963       | best                 | 26.6250    |
| 5316       | average              | 26.3750    |
| 12070      | percentage           | 26.1250    |
| 2669       | change               | 26.1250    |
| 1618       | value                | 25.8750    |
| 1546       | most                 | 25.7500    |
| 66202      | molar                | 25.5000    |
| 3051       | total                | 25.5000    |
| 6187       | purpose              | 25.3750    |

--- Similarity Metrics of Top Tokens ---
| Metric                         | Value                |
|--------------------------------|----------------------|
| overlap_count                  | 9/10                 |
| jaccard_similarity             | 0.8181818181818182   |
| rank_agreement_percentage      | 70.0                 |

Average KL divergence per token (D_KL(P_golden || Q_model)): 0.000409

Max KL divergence for a single token in the set: 0.003497
```

______________________________________________________________________

## Adding support for new models

To extend conversion support to a new model architecture, you must define its specific parameter and configuration mappings. The conversion logic is decoupled, so you only need to modify the mapping files.

1. **Add parameter mappings**: In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py):
   - Add the parameter name mappings (`def {MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING`). This is the 1-to-1 mapping of parameter names per layer.
   - Add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer.
2. **Add Hugging Face weight shapes**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py), define the tensor shapes of the Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shapes match after the to_huggingface conversion.
3. **Register the model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py), add the new model key to `HF_IDS`.
4. **Add the transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py), add the `transformers.Config` object describing the Hugging Face model configuration (the MaxText side is defined in [`src/MaxText/configs/models`](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture.

Here is an example [PR adding support for the Gemma3 multimodal model](https://github.com/AI-Hypercomputer/maxtext/pull/1983).

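To make the shape of steps 1 concrete, here is a sketch for a hypothetical "mymodel" architecture. All names, signatures, and the layer structure below are illustrative assumptions; consult `utils/param_mapping.py` for the real conventions used by existing models.

```python
# Hypothetical mapping for a made-up "mymodel" architecture (illustration only).

def MYMODEL_MAXTEXT_TO_HF_PARAM_MAPPING(num_layers: int):
    """1-to-1 mapping from MaxText parameter names to HF parameter names."""
    mapping = {
        "params-token_embedder-embedding": "model.embed_tokens.weight",
    }
    for i in range(num_layers):
        mapping[f"params-decoder-layers_{i}-self_attention-query-kernel"] = (
            f"model.layers.{i}.self_attn.q_proj.weight"
        )
        mapping[f"params-decoder-layers_{i}-mlp-wi-kernel"] = (
            f"model.layers.{i}.mlp.gate_proj.weight"
        )
    return mapping

def MYMODEL_MAXTEXT_TO_HF_PARAM_HOOK_FN(name: str, tensor):
    """Per-tensor transformation applied during conversion, e.g. transposing
    a 2-D kernel so MaxText's (in, out) layout matches HF's (out, in)."""
    if name.endswith("-kernel"):
        # Transpose a 2-D list-of-lists (stand-in for an array transpose).
        return [list(row) for row in zip(*tensor)]
    return tensor

mapping = MYMODEL_MAXTEXT_TO_HF_PARAM_MAPPING(num_layers=2)
print(len(mapping))  # → 5 (1 embedding entry + 2 entries per layer)
```
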
## Debugging tips

If the converted checkpoint cannot be loaded and you get an error like "type <class 'jax._src.core.ShapeDtypeStruct'> is not a valid JAX type":

- **Potential cause**: The `scan_layers` flag does not match the checkpoint format.

If a converted checkpoint loads without errors but produces incorrect output, consider these common issues:

- **Symptom**: The model generates garbage or nonsensical tokens.

  - **Potential cause**: The query/key/value (Q/K/V) or output projection weights were likely reshaped incorrectly during conversion.

- **Symptom**: The model generates repetitive text sequences.

  - **Potential cause**: The layer normalization parameters may have been converted incorrectly.
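The Q/K/V reshape pitfall can be illustrated with a toy example (hypothetical shapes, pure Python standing in for an array library): flattening a weight and reshaping it with the wrong axis order produces a tensor of the correct shape but with scrambled values, which is why the checkpoint loads cleanly yet the model generates garbage.

```python
def reshape(flat, rows, cols):
    """Row-major reshape of a flat list into a rows x cols matrix."""
    assert len(flat) == rows * cols
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]

def transpose(m):
    """Swap the axes of a 2-D list-of-lists."""
    return [list(row) for row in zip(*m)]

# A toy 2x3 attention kernel.
w = [[1, 2, 3],
     [4, 5, 6]]
flat = [x for row in w for x in row]

correct = transpose(w)        # proper axis swap: same weights, new layout
wrong = reshape(flat, 3, 2)   # same 3x2 shape, but scrambled values

print(correct)  # [[1, 4], [2, 5], [3, 6]]
print(wrong)    # [[1, 2], [3, 4], [5, 6]]
```

Because both results have identical shapes, no loader can catch this class of bug; only a logit comparison like `forward_pass_logit_checker.py` will.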