Commit f7a8266

Update README with model support

Update test scripts and model README; update flags for the scripts' decode command
1 parent 72e96f5 commit f7a8266

4 files changed: 204 additions & 60 deletions

File tree

README.md

Lines changed: 2 additions & 1 deletion
@@ -41,7 +41,8 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies
 
 ## 🔥 Latest news 🔥
 
-* \[February 27, 2026\] New MaxText structure! MaxText has been restructured according to [RESTRUCTURE.md](https://github.com/AI-Hypercomputer/maxtext/blob/1b9e38aa0a19b6018feb3aed757406126b6953a1/RESTRUCTURE.md). Please feel free to share your thoughts and feedback.
+* \[March 5, 2026\] [Qwen3-Next](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md) is now supported.
+* \[February 27, 2026\] New MaxText structure! MaxText has been restructured according to [RESTRUCTURE.md](https://github.com/AI-Hypercomputer/maxtext/blob/1b9e38aa0a19b6018feb3aed757406126b6953a1/RESTRUCTURE.md). Please feel free to share your thoughts and feedback.
 * \[December 22, 2025\] [Muon optimizer](https://kellerjordan.github.io/posts/muon) is now supported.
 * \[December 10, 2025\] DeepSeek V3.1 is now supported. Use existing configs for [DeepSeek V3 671B](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek3-671b.yml) and load in the V3.1 checkpoint to use the model.
 * \[December 9, 2025\] [New RL and SFT Notebook tutorials](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/examples) are available.
Lines changed: 39 additions & 38 deletions
@@ -1,8 +1,11 @@
 #!/bin/bash
 
-# This script validates a pre-converted MaxText checkpoint against its original
-# HuggingFace counterpart to ensure numerical correctness.
+# This file is documentation for how to get started with Qwen3 Next.
 
+# This file runs Step 1 on CPU.
+# 1. Convert the HuggingFace checkpoint (bf16) to a MaxText-compatible checkpoint (bf16).
+#    Scanned format is better for training; unscanned format is better for decoding.
+# 2. Run logit check, pre-training, fine-tuning, and decoding.
 # ---
 # Example Usage:
 #
@@ -17,43 +20,41 @@
 
 set -ex
 
-# --- Configuration & Input Validation ---
+export MODEL_NAME='qwen3-next-80b-a3b'
+export TOKENIZER_PATH='Qwen/Qwen3-Next-80B-A3B-Instruct'
 
-if [ -z "${MAXTEXT_CHECKPOINT_PATH}" ]; then
-  echo "ERROR: The MAXTEXT_CHECKPOINT_PATH environment variable is not set."
-  echo "Please set it to the full GCS path of the pre-converted MaxText checkpoint weights."
-  exit 1
-fi
+# Install torch for checkpoint conversion and forward_pass_logit_checker.py.
+python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
 
-# Set a default for the HF model path if it's not provided by the user
-if [ -z "${HF_MODEL_PATH}" ]; then
-  export HF_MODEL_PATH="Qwen/Qwen3-Next-80B-A3B-Instruct"
-  echo "HF_MODEL_PATH is not set, using default: ${HF_MODEL_PATH}"
+# Ensure HF_TOKEN is set.
+if [ -z "${HF_TOKEN}" ]; then
+  echo "Error: HF_TOKEN environment variable is not set. Please export your Hugging Face token."
+  echo "Example: export HF_TOKEN=hf_..."
+  exit 1
 fi
 
-# Install dependencies required for the logit checker.
-python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
-
-# --- Run the Forward Pass Logit Checker ---
-
-echo "Validating MaxText checkpoint at ${MAXTEXT_CHECKPOINT_PATH}"
-echo "Against original HF model: ${HF_MODEL_PATH}"
-
-# This command runs the core validation logic.
-JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
-  tokenizer_type=huggingface \
-  tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/qwen3-tokenizer \
-  megablox=False \
-  sparse_matmul=False \
-  load_parameters_path=${MAXTEXT_CHECKPOINT_PATH} \
-  model_name=qwen3-next-80b-a3b \
-  checkpoint_storage_concurrent_gb=1024 \
-  skip_jax_distributed_system=True \
-  dtype=float32 \
-  weight_dtype=float32 \
-  matmul_precision=highest \
-  --hf_model_path=${HF_MODEL_PATH} \
-  --max_kl_div=0.03 \
-  --run_hf_model=True
-
-echo "Validation complete."
+if [ -z "${BASE_OUTPUT_PATH}" ]; then
+  # Non-Googlers, please point `BASE_OUTPUT_PATH` to a GCS bucket that you own; this script uses internal buckets for testing.
+  # This bucket will store all the files generated by MaxText during a run.
+  export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
+  echo "BASE_OUTPUT_PATH is not set; using default"
+fi
+BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
+echo "using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
+
+# 1.1 Convert the checkpoint to `scanned` format, more suitable for training.
+JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
+  model_name=qwen3-next-80b-a3b \
+  base_output_directory=${BASE_OUTPUT_PATH}/scanned \
+  hf_access_token=${HF_TOKEN} \
+  scan_layers=true \
+  use_multimodal=false
+
+# 1.2 Convert the checkpoint to `unscanned` format, more suitable for decoding.
+JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
+  model_name=qwen3-next-80b-a3b \
+  base_output_directory=${BASE_OUTPUT_PATH}/unscanned \
+  hf_access_token=${HF_TOKEN} \
+  scan_layers=false \
+  use_multimodal=false
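The two conversion calls in this script differ only in `scan_layers`. As a rough illustration of the distinction (a toy Python sketch with made-up names and shapes, not MaxText's actual Orbax checkpoint layout), a scanned checkpoint stacks every layer's weights along a leading layer axis so a single scanned loop can consume them, while an unscanned checkpoint keeps one entry per layer, which suits layer-by-layer decoding:

```python
# Toy sketch of the scanned-vs-unscanned distinction; names and shapes are
# illustrative only, not MaxText's real checkpoint structure.
NUM_LAYERS = 4
HIDDEN = 2

# Unscanned: one weight entry per decoder layer (convenient for decoding,
# where layers are applied one at a time).
unscanned = {
    f"layer_{i}": [[float(i)] * HIDDEN for _ in range(HIDDEN)]
    for i in range(NUM_LAYERS)
}

def to_scanned(ckpt):
    """Stack per-layer weights along a new leading 'layers' axis."""
    return {"layers": [ckpt[f"layer_{i}"] for i in range(NUM_LAYERS)]}

def to_unscanned(ckpt):
    """Split the stacked weights back into per-layer entries."""
    return {f"layer_{i}": w for i, w in enumerate(ckpt["layers"])}

scanned = to_scanned(unscanned)
assert len(scanned["layers"]) == NUM_LAYERS  # leading axis = layer count
assert to_unscanned(scanned) == unscanned    # lossless round trip
```

The round trip being lossless is why the script can produce both formats from the same source weights.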
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+#!/bin/bash
+
+# This file is documentation for how to get started with Qwen3 Next.
+
+# This file runs Step 2 on v5p-128 on a daily basis.
+# 1. Convert the HuggingFace checkpoint (bf16) to a MaxText-compatible checkpoint (bf16).
+#    Scanned format is better for training; unscanned format is better for decoding.
+# 2. Run logit check, pre-training, fine-tuning, and decoding.
+
+# The golden logits can be generated by:
+# python3 -m tests.assets.logits_generation.generate_hf_golden_logits --model-id=Qwen/Qwen3-Next-80B-A3B-Instruct --output-path=golden_data_qwen3-next-80b-a3b.jsonl --prompts='I love to' --hf-model-path=$local_bf16_path --trust-remote-code=False --hf-load-dtype=bfloat16
+
+set -ex
+
+export PYTHONPATH=$PYTHONPATH:$(pwd)/src
+
+export MODEL_NAME='qwen3-next-80b-a3b'
+export TOKENIZER_PATH='Qwen/Qwen3-Next-80B-A3B-Instruct'
+
+# Install torch for checkpoint conversion and forward_pass_logit_checker.py.
+python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+
+# e.g., $HOME/maxtext/src/maxtext
+export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}"
+
+if [ -z "${BASE_OUTPUT_PATH}" ]; then
+  # Non-Googlers, please point `BASE_OUTPUT_PATH` to a GCS bucket that you own; this script uses internal buckets for testing.
+  # This bucket will store all the files generated by MaxText during a run.
+  export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
+  echo "BASE_OUTPUT_PATH is not set; using default"
+fi
+BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
+echo "using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
+
+# Step 2:
+# Define the checkpoint paths so they are easier to use in the `train.py` and `decode.py` commands:
+# export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/scanned/0/items
+# export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items
+# Use hard-coded golden checkpoints rather than the checkpoints generated by Step 1, since Step 1 is not part of the daily test.
+SCANNED_CKPT_PATH=gs://maxtext-model-checkpoints/qwen3-next-80b-a3b/scanned/0/items
+UNSCANNED_CKPT_PATH=gs://maxtext-model-checkpoints/qwen3-next-80b-a3b/unscanned/0/items
+# Non-Googlers, please point `DATASET_PATH` to the GCS bucket where you have your training data.
+export DATASET_PATH=gs://maxtext-dataset
+
+# Test whether the forward-pass logits match the golden logits.
+# Default golden_logits_path=/deps/tests/assets/golden_logits/golden_data_${MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl
+GOLDEN_LOGITS_DISK_LOCATION="/deps/tests/assets/golden_logits/golden_data_${MODEL_NAME}.jsonl"
+if [ ! -f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then
+  GOLDEN_LOGITS_PATH="gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl"
+  GOLDEN_LOGITS_DISK_LOCATION=/tmp/golden_data.jsonl
+  gcloud storage cp ${GOLDEN_LOGITS_PATH} ${GOLDEN_LOGITS_DISK_LOCATION}
+fi
+
+python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=True ici_fsdp_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --atol=1.5 --rtol=1.5 --max_kl_div=0.1
+
+# Run pre-training - tokamax_gmm implementation.
+python3 -m maxtext.trainers.pre_train.train ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=tokamax_gmm_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 steps=5 max_target_length=1024
+
+# Run fine-tuning - tokamax_gmm implementation.
+python3 -m maxtext.trainers.pre_train.train ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=tokamax_gmm_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 steps=5 max_target_length=1024 checkpoint_storage_concurrent_gb=1024
+
+# Run decoding - tokamax_gmm implementation.
+# Note: decode requires the Hugging Face access token for the tokenizer even if the model is not gated.
+python3 -m maxtext.inference.decode ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=64 max_target_length=512 ici_fsdp_parallelism=1 ici_tensor_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is "
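The logit check in this script gates on `--max_kl_div=0.1`. A minimal sketch of that style of acceptance test in pure Python (hypothetical helper names; the real `forward_pass_logit_checker` compares per-token distributions between MaxText outputs and the golden file):

```python
import math

def softmax(logits):
    """Convert a logit vector to a probability distribution."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p_logits, q_logits):
    """KL(P || Q) in nats between the distributions induced by two logit vectors."""
    p, q = softmax(p_logits), softmax(q_logits)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

golden = [2.0, 1.0, 0.1]     # stand-in for one token's golden (HF) logits
candidate = [2.1, 0.9, 0.1]  # stand-in for the converted model's logits

assert kl_divergence(golden, golden) == 0.0    # identical logits diverge by zero
assert kl_divergence(golden, candidate) < 0.1  # would pass a --max_kl_div=0.1 gate
```

A small KL divergence per token indicates the converted checkpoint induces nearly the same next-token distribution as the reference.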

tests/end_to_end/tpu/qwen/next/run_qwen3_next.md

Lines changed: 98 additions & 21 deletions
@@ -7,6 +7,31 @@ For more details on the architecture, see the [Qwen3 Technical Blog](https://qwe
 
 * * * * *
 
+Pre-Training
+---------------------
+You can train from scratch to generate a new checkpoint. An example command to run pre-training with Qwen3-Next on a v5p-64 slice:
+
+```sh
+python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
+  run_name=q3_next_pre_training \
+  per_device_batch_size=1 \
+  enable_checkpointing=false \
+  model_name=qwen3-next-80b-a3b \
+  ici_fsdp_parallelism=-1 \
+  steps=5 \
+  max_target_length=1024 \
+  async_checkpointing=false \
+  tokenizer_type=huggingface \
+  tokenizer_path=src/maxtext/assets/tokenizers/qwen3-tokenizer \
+  attention=flash \
+  dtype=bfloat16 \
+  weight_dtype=bfloat16 \
+  megablox=False \
+  sparse_matmul=False \
+  dataset_type=synthetic
+```
+
 Checkpoint Conversion
 ---------------------
 
@@ -22,18 +47,20 @@ To get started, you first need a MaxText-compatible checkpoint.
 
 2. **Convert the Checkpoint**: Run the conversion command below to convert the downloaded Hugging Face weights into the Orbax format required by MaxText.
 
 ```
-python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_qwen3_next_scanned \
-  --base_model_path /path/to/qwen3_next_hf_checkpoint \
-  --maxtext_model_path gs://your-gcs-bucket/qwen3_next_maxtext_ckpt \
-  --model_size qwen3-next-80b-a3b
+JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
+  model_name=qwen3-next-80b-a3b \
+  base_output_directory=gs://your-gcs-bucket/qwen3_next_maxtext_ckpt \
+  hf_access_token=${HF_TOKEN} \
+  scan_layers=true \
+  use_multimodal=false  # Set scan_layers=false for an unscanned checkpoint
 ```
 
 * * * * *
 
-Pre-training and Fine-tuning
+Fine-tuning
 ----------------------------
 
-After converting the checkpoint, you can use it for fine-tuning or start a pre-training run from scratch. The command below is an example for fine-tuning on a v5p-512 slice. To pre-train, simply remove the `load_parameters_path` argument.
+After converting the checkpoint, you can use it for fine-tuning. The command below is an example for fine-tuning on a v5p-64 slice.
 
 ```
 python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
@@ -43,40 +70,90 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
   run_name=qwen3_next_finetuning \
   per_device_batch_size=1 \
   model_name=qwen3-next-80b-a3b \
-  steps=500 \
-  max_target_length=8192 \
-  ici_fsdp_parallelism=256 \
+  steps=30 \
+  max_target_length=4096 \
+  ici_fsdp_parallelism=-1 \
   tokenizer_type=huggingface \
   tokenizer_path=src/maxtext/assets/tokenizers/qwen3-tokenizer
+```
+
+## Decoding
+An example command to run decoding with Qwen3-Next on v5p-64, using an unscanned checkpoint for fast decoding:
 
+```sh
+python3 -m maxtext.inference.decode src/maxtext/configs/base.yml \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
+  load_parameters_path=${CONVERTED_CHECKPOINT} \
+  run_name=q3-next-decode \
+  per_device_batch_size=1 \
+  enable_checkpointing=false \
+  model_name=qwen3-next-80b-a3b \
+  max_prefill_predict_length=64 \
+  max_target_length=1024 \
+  tokenizer_type=huggingface \
+  tokenizer_path=src/maxtext/assets/tokenizers/qwen3-tokenizer \
+  attention=dot_product \
+  dtype=bfloat16 \
+  weight_dtype=bfloat16 \
+  megablox=False \
+  sparse_matmul=False \
+  ici_tensor_parallelism=1 \
+  ici_fsdp_parallelism=1 \
+  ici_expert_parallelism=-1 \
+  prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is " \
+  scan_layers=False
 ```
 
 * * * * *
 
 Correctness Validation
 ----------------------
 
-To verify that the MaxText implementation is numerically equivalent to the original Hugging Face model, you can run the end-to-end test scripts. These scripts automate the logit comparison test for each model.
+We perform two primary checks:
 
-Before running, you must set the `MAXTEXT_CHECKPOINT_PATH` environment variable. You can also optionally set `HF_MODEL_PATH` to point to a local copy of the Hugging Face model.
+* **Logit Comparison**: We compare the logits generated by our implementation against those from a HuggingFace implementation for a set of given prompts.
+* **MMLU Score Validation**: We validate the MMLU score against established benchmarks.
 
-### Qwen3-Next-80B-A3B
-
-Bash
+An example command to generate golden logits from HuggingFace for Qwen3-Next:
 
+```sh
+python3 -m tests.assets.logits_generation.generate_hf_golden_logits \
+  --model-id=Qwen/Qwen3-Next-80B-A3B-Instruct \
+  --output-path=golden_Qwen3_Next.jsonl \
+  --prompts='I love to;Today is a;What is the'
```
-# Set the required path to your converted MaxText checkpoint
-export MAXTEXT_CHECKPOINT_PATH=gs://your-gcs-bucket/qwen3-next-80b-a3b_maxtext_ckpt/0/items/
 
-# (Optional) Set the path to your local Hugging Face checkpoint
-# export HF_MODEL_PATH=/path/to/local/qwen3-next-80b-a3b_hf_checkpoint
+You should see logs like the following:
+
+```
+...
+File is stored locally at golden_Qwen3_Next.jsonl.
+```
 
-# Execute the validation script
-bash tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh
+Run the command below to compare logits between HuggingFace and MaxText:
 
+```sh
+python3 -m tests.utils.forward_pass_logit_checker \
+  src/maxtext/configs/base.yml \
+  tokenizer_type=huggingface \
+  tokenizer_path=Qwen/Qwen3-Next-80B-A3B-Instruct \
+  load_parameters_path=${CONVERTED_CHECKPOINT} \
+  run_name=forward_pass_test_qwen3_next \
+  per_device_batch_size=1 \
+  model_name=qwen3-next-80b-a3b \
+  max_prefill_predict_length=4 \
+  max_target_length=4 \
+  scan_layers=false \
+  sparse_matmul=False \
+  dtype=float32 \
+  activations_in_float32=true \
+  matmul_precision=high \
+  --max_kl_div=2e-4 \
+  --golden_logits_path=${PWD}/golden_Qwen3_Next.jsonl
+```
 
+To run MMLU benchmarks and validate the model's performance, follow the instructions provided [here](../../../benchmarks/api_server/README.md).
+
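The golden-logits files passed around above (`golden_data_*.jsonl`, `golden_Qwen3_Next.jsonl`) are JSONL: one JSON record per prompt. A toy round-trip with a hypothetical record layout (the field names and token IDs below are illustrative; the actual schema written by `generate_hf_golden_logits` may differ):

```python
import json
import os
import tempfile

# Hypothetical record layout: one JSON object per prompt line.
records = [
    {"prompt": "I love to", "tokens": [1, 2, 3], "logits": [[0.1, 0.9], [0.4, 0.6], [0.7, 0.3]]},
    {"prompt": "Today is a", "tokens": [4, 5, 6], "logits": [[0.2, 0.8], [0.5, 0.5], [0.6, 0.4]]},
]

path = os.path.join(tempfile.mkdtemp(), "golden_data.jsonl")
with open(path, "w", encoding="utf-8") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")  # JSONL: one record per line

with open(path, encoding="utf-8") as f:
    loaded = [json.loads(line) for line in f]

assert loaded == records  # lossless round trip
```

JSONL keeps each prompt's data independently parseable, so a checker can stream records without loading the whole file.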
## Supported MoE Strategies
80158

81159
This model implementation supports both **Token Dropping** and **Dropless** strategies for Mixture of Experts routing. Take a look at the MaxText [documentation](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/moe_configuration.md) on MoE configs and flags to set based on desired strategy.
