
Commit 9edcdba

Merge pull request #2672 from AI-Hypercomputer:rbierneni-qwen3-next-ckpt-conversion
PiperOrigin-RevId: 845423149
2 parents 948fd36 + b6c32b6 commit 9edcdba

13 files changed

Lines changed: 1036 additions & 58 deletions

end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@

#!/bin/bash

# This script validates a pre-converted MaxText checkpoint against its original
# HuggingFace counterpart to ensure numerical correctness.

# ---
# Example Usage:
#
# # (Required) Path to the converted MaxText checkpoint
# export MAXTEXT_CHECKPOINT_PATH=gs://path/to/converted_ckpt/0/items/
#
# # (Optional) Override the default HF model
# export HF_MODEL_PATH=MyCustom/Qwen3-variant
#
# bash end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh
# ---

set -ex

# --- Configuration & Input Validation ---

if [ -z "${MAXTEXT_CHECKPOINT_PATH}" ]; then
  echo "ERROR: The MAXTEXT_CHECKPOINT_PATH environment variable is not set."
  echo "Please set it to the full GCS path of the pre-converted MaxText checkpoint weights."
  exit 1
fi

# Set a default for the HF model path if it's not provided by the user
if [ -z "${HF_MODEL_PATH}" ]; then
  export HF_MODEL_PATH="Qwen/Qwen3-Next-80B-A3B-Instruct"
  echo "HF_MODEL_PATH is not set, using default: ${HF_MODEL_PATH}"
fi

# Install dependencies required for the logit checker.
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# --- Run the Forward Pass Logit Checker ---

echo "Validating MaxText checkpoint at ${MAXTEXT_CHECKPOINT_PATH}"
echo "Against original HF model: ${HF_MODEL_PATH}"

# This command runs the core validation logic.
JAX_PLATFORMS=cpu python3 -m MaxText.tests.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml \
  tokenizer_type=huggingface \
  tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/qwen3-tokenizer \
  megablox=False \
  sparse_matmul=False \
  load_parameters_path=${MAXTEXT_CHECKPOINT_PATH} \
  model_name=qwen3-next-80b-a3b \
  checkpoint_storage_concurrent_gb=1024 \
  skip_jax_distributed_system=True \
  dtype=float32 \
  weight_dtype=float32 \
  matmul_precision=highest \
  --hf_model_path=${HF_MODEL_PATH} \
  --max_kl_div=0.03 \
  --run_hf_model=True

echo "Validation complete."
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@

Qwen3 Next
==========

Qwen3-Next is Alibaba's 80B-parameter Mixture-of-Experts (MoE) model (activating only 3B parameters per token). It features a novel **hybrid attention** architecture that combines Gated DeltaNet (linear attention) and Gated Attention (full attention) for massive context scaling. This documentation covers the integration of **Qwen3-Next-80B-A3B** into MaxText.

For more details on the architecture, see the [Qwen3 Technical Blog](https://qwen.ai/blog?id=4074cca80393150c248e508aa62983f9cb7d27cd&from=research.latest-advancements-list).
* * * * *

Checkpoint Conversion
---------------------

To get started, you first need a MaxText-compatible checkpoint.

1. **Download the Model**: Download the official model from Hugging Face. For a faster download, you can enable the `hf_transfer` backend for `huggingface-cli`.

```
# Example for Qwen3-Next-80B-A3B-Instruct
pip install huggingface_hub hf_transfer
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download Qwen/Qwen3-Next-80B-A3B-Instruct --local-dir /path/to/qwen3_next_hf_checkpoint
```
2. **Convert the Checkpoint**: Run the `convert_qwen3_next_scanned.py` script to convert the downloaded Hugging Face weights into the Orbax format required by MaxText.

```
python3 -m MaxText.utils.ckpt_scripts.convert_qwen3_next_scanned \
    --base_model_path /path/to/qwen3_next_hf_checkpoint \
    --maxtext_model_path gs://your-gcs-bucket/qwen3_next_maxtext_ckpt \
    --model_size qwen3-next-80b-a3b
```
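The "scanned" in the script name refers to MaxText's layer-stacked parameter layout: matching weights from every decoder layer are stacked along a new leading layer axis so the decoder can be driven by `jax.lax.scan`. The snippet below is a rough illustration of that stacking idea only, with invented shapes; the real script also handles key renaming, per-tensor transposes, and the Orbax save.

```
import numpy as np

# Hypothetical per-layer HF tensors: one (emb_dim, mlp_dim) matrix per decoder layer.
num_layers, emb_dim, mlp_dim = 4, 8, 16
hf_layer_weights = [np.full((emb_dim, mlp_dim), i, dtype=np.float32) for i in range(num_layers)]

# Scanned layout: stack along a leading `layer` axis, scanned over axis 0.
scanned = np.stack(hf_layer_weights, axis=0)
print(scanned.shape)  # (4, 8, 16)
```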
* * * * *

Pre-training and Fine-tuning
----------------------------

After converting the checkpoint, you can use it for fine-tuning or start a pre-training run from scratch. The command below is an example of fine-tuning on a v5p-512 slice (256 chips, which is why `ici_fsdp_parallelism=256`). To pre-train instead, simply remove the `load_parameters_path` argument.

```
python3 -m MaxText.train src/MaxText/configs/base.yml \
    base_output_directory=${BASE_OUTPUT_DIRECTORY} \
    dataset_path=${DATASET_PATH} \
    load_parameters_path=gs://your-gcs-bucket/qwen3_next_maxtext_ckpt/0/items \
    run_name=qwen3_next_finetuning \
    per_device_batch_size=1 \
    model_name=qwen3-next-80b-a3b \
    steps=500 \
    max_target_length=8192 \
    ici_fsdp_parallelism=256 \
    tokenizer_type=huggingface \
    tokenizer_path=src/MaxText/assets/qwen3-tokenizer
```
* * * * *

Correctness Validation
----------------------

To verify that the MaxText implementation is numerically equivalent to the original Hugging Face model, you can run the end-to-end test scripts. These scripts automate the logit comparison test for each model.

Before running, you must set the `MAXTEXT_CHECKPOINT_PATH` environment variable. You can also optionally set `HF_MODEL_PATH` to point to a local copy of the Hugging Face model.

### Qwen3-Next-80B-A3B

```
# Set the required path to your converted MaxText checkpoint
export MAXTEXT_CHECKPOINT_PATH=gs://your-gcs-bucket/qwen3-next-80b-a3b_maxtext_ckpt/0/items/

# (Optional) Set the path to your local Hugging Face checkpoint
# export HF_MODEL_PATH=/path/to/local/qwen3-next-80b-a3b_hf_checkpoint

# Execute the validation script
bash end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh
```
Supported MoE Strategies
------------------------

This model implementation supports both **Token Dropping** and **Dropless** strategies for Mixture-of-Experts routing. See the MaxText [documentation](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/moe_configuration.md) on MoE configuration for the flags to set for each strategy; a brief illustration follows below.
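As a hedged illustration of the difference (the linked doc is authoritative for flag names and defaults), the strategy is typically selected through the MoE matmul and capacity flags; the capacity value below is just an example:

```
# Dropless routing: sparse/grouped matmuls, no per-expert capacity limit
sparse_matmul=True megablox=True capacity_factor=-1

# Token-dropping routing: dense matmuls with a per-expert capacity limit (example value)
sparse_matmul=False capacity_factor=1.25
```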

src/MaxText/configs/models/qwen3-next-80b-a3b.yml

Lines changed: 3 additions & 0 deletions
@@ -46,3 +46,6 @@ gdn_chunk_size: 64
 # RoPE Settings
 rope_max_timescale: 10000000
 partial_rotary_factor: 0.25
+
+# General Model Settings
+enable_dropout: False

src/MaxText/configs/types.py

Lines changed: 6 additions & 6 deletions
@@ -2094,15 +2094,15 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
           f"({self.interleave_moe_layer_step})"
       )
     if self.decoder_block == DecoderBlockType.QWEN3_NEXT:
-      if self.sparse_matmul:
-        logger.warning(
-            "For Qwen3-Next, sparse_matmul must be False for now. The dense path has been verified against reference. "
-            "Forcing to False."
-        )
-        self.sparse_matmul = False
+      if int(self.gdn_num_value_heads) % int(self.gdn_num_key_heads) != 0:
+        raise ValueError("gdn_num_value_heads must be divisible by gdn_num_key_heads")
       rotary_dim = int(self.head_dim * self.partial_rotary_factor)
       if rotary_dim % 2 != 0:
         raise ValueError(f"Calculated rotary dimension ({rotary_dim}) must be a multiple of 2.")
+    else:
+      if self.partial_rotary_factor is not None and self.partial_rotary_factor != 1.0:
+        raise ValueError("`partial_rotary_factor` is only effective when `decoder_block` is set to 'qwen3_next'.")
+
     tokenizer_path = getattr(self, "tokenizer_path", None)
     if (
         tokenizer_path
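As a worked example of the new rotary check: with `partial_rotary_factor: 0.25` from the model yml above and a head dimension of 256 (an assumed value for Qwen3-Next-80B-A3B, not stated in this diff), the computation passes:

```
# head_dim is an assumption for illustration; partial_rotary_factor comes from the yml.
head_dim, partial_rotary_factor = 256, 0.25
rotary_dim = int(head_dim * partial_rotary_factor)
assert rotary_dim % 2 == 0  # rotary_dim == 64, so only a quarter of each head gets RoPE
```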

src/MaxText/layers/attentions.py

Lines changed: 3 additions & 3 deletions
@@ -1104,9 +1104,6 @@ def __call__(
         bidirectional_mask,
         self.sinks,
     )
-    if self.is_qwen3_next:
-      out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
-      out = out * jax.nn.sigmoid(gate)
     if model_mode == MODEL_MODE_PREFILL:
       out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names)
     elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:

@@ -1115,6 +1112,9 @@ def __call__(
       out = self._maybe_shard_with_logical(out, self.out_axis_names)
     else:
       out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)
+    if self.is_qwen3_next:
+      out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
+      out = out * jax.nn.sigmoid(gate)
     out = self.out_projection(out, out_sharding=out_sharding)
     out = checkpoint_name(out, "out_proj")
     return out, kv_cache
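The functional effect of this move: for Qwen3-Next, the sigmoid output gate now applies after the mode-dependent sharding annotations, immediately before `out_projection`, instead of before them. A self-contained sketch of the gating itself (shapes and names assumed for illustration, not the module's real signature):

```
import jax
import jax.numpy as jnp

def apply_output_gate(attn_out, gate):
  # attn_out: (batch, seq, num_heads, head_dim); gate: (batch, seq, num_heads * head_dim)
  b, s, h, d = attn_out.shape
  out = attn_out.reshape(b, s, h * d)
  return out * jax.nn.sigmoid(gate)  # element-wise gate on the flattened heads

out = apply_output_gate(jnp.ones((2, 4, 8, 16)), jnp.zeros((2, 4, 128)))
print(out.shape)  # (2, 4, 128); sigmoid(0) = 0.5, so every element is halved
```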

src/MaxText/layers/decoders.py

Lines changed: 6 additions & 2 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-""""Module for decoder layers."""
+"""Module for decoder layers"""
 # pylint: disable=arguments-differ
 # pylint: disable=no-name-in-module

@@ -35,6 +35,7 @@
 from MaxText.sharding import create_sharding
 from MaxText.inference import page_manager
 from MaxText.layers import linears
+from MaxText.layers import normalizations
 from MaxText.layers import quantizations
 from MaxText.layers import pipeline
 from MaxText import maxtext_utils

@@ -473,7 +474,6 @@ def get_norm_layer(self, num_features: int):
         DecoderBlockType.GEMMA3,
         DecoderBlockType.QWEN3,
         DecoderBlockType.QWEN3_MOE,
-        DecoderBlockType.QWEN3_NEXT,
         DecoderBlockType.GPT_OSS,
         DecoderBlockType.SIMPLE,
         DecoderBlockType.SIMPLE_MLP,

@@ -482,6 +482,10 @@ def get_norm_layer(self, num_features: int):
       return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode)
     elif self.config.decoder_block == DecoderBlockType.GPT3:
       return functools.partial(gpt3.gpt3_layer_norm, num_features=num_features, reductions_in_fp32=False, use_bias=True)
+    elif self.config.decoder_block == DecoderBlockType.QWEN3_NEXT:
+      return functools.partial(
+          normalizations.Qwen3NextRMSNormLinen, num_features=num_features, shard_mode=self.config.shard_mode
+      )
     else:
       raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")

src/MaxText/layers/normalizations.py

Lines changed: 8 additions & 0 deletions
@@ -196,3 +196,11 @@ def l2norm(x: Array, dim: int = -1, eps: float = 1e-6) -> Array:

   inv_norm = jax.lax.rsqrt((x * x).sum(axis=dim, keepdims=True) + jnp.array(eps, dtype=x.dtype))
   return x * inv_norm
+
+
+Qwen3NextRMSNormLinen = nnx_wrappers.to_linen_class(
+    RMSNorm,
+    base_metadata_fn=variable_to_logically_partitioned,
+    scale_init=linen_initializers.zeros,
+    scale_offset=1.0,
+)
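`scale_init=zeros` together with `scale_offset=1.0` gives a zero-centered gain: the stored scale parameter starts at 0 and the effective multiplier is `scale + 1.0`. A minimal sketch of the effective computation, assuming the wrapped `RMSNorm` follows the standard form and adds `scale_offset` to the learned scale:

```
import jax
import jax.numpy as jnp

def qwen3_next_rms_norm(x, scale, eps=1e-6):
  # scale is stored zero-initialized; the +1.0 offset makes the initial gain identity.
  var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)
  normed = x.astype(jnp.float32) * jax.lax.rsqrt(var + eps)
  return (normed * (scale + 1.0)).astype(x.dtype)

x = jnp.ones((2, 4))
print(qwen3_next_rms_norm(x, scale=jnp.zeros((4,))))  # identity-scale output at init
```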

src/MaxText/layers/qwen3.py

Lines changed: 69 additions & 17 deletions
@@ -324,6 +324,7 @@ def __init__(self, config: Config, dtype: DType = jnp.float32, *, rngs: nnx.Rngs
     self.value_dim = self.head_v_dim * self.num_v_heads
     conv_dim = self.key_dim * 2 + self.value_dim
     conv_kernel_size = cfg.gdn_conv_kernel_dim
+    self.v_heads_per_k_head = self.num_v_heads // self.num_k_heads

     # Submodule instantiations
     self.in_proj_qkvz = linears.DenseGeneral(

@@ -381,33 +382,86 @@ def a_log_init(key, shape, dtype=jnp.float32):
     )

   def __call__(self, hidden_states: Array) -> Array:
+    # hidden_states: (B, S, E)
     cfg = self.config
+    batch, seq_len, _ = hidden_states.shape

     # =========================================================================
     # STEP A: Input Projections
     # =========================================================================
-    # hidden_states shape: (B, S, E)
-    # qkvz shape: (B, S, 2*key_dim + 2*value_dim)
+    # qkvz: (B, S, 2 * K_dim + 2 * V_dim)
     qkvz = self.in_proj_qkvz(hidden_states)
-    # ba shape: (B, S, 2*H_v)
+    # ba: (B, S, 2 * H_v)
     ba = self.in_proj_ba(hidden_states)

-    # q shape: (B, S, key_dim), k shape: (B, S, key_dim), v shape: (B, S, value_dim), z shape: (B, S, value_dim)
-    q, k, v, z = jnp.split(qkvz, [self.key_dim, 2 * self.key_dim, 2 * self.key_dim + self.value_dim], axis=-1)
-    # b shape: (B, S, H_v), a shape: (B, S, H_v)
-    b, a = jnp.split(ba, [self.num_v_heads], axis=-1)
+    # QKVZ Reshaping and Splitting
+    # Per-K_head group dim: 2 * D_k + 2 * D_v * V_per_K
+    new_shape_qkvz = (
+        batch,
+        seq_len,
+        self.num_k_heads,  # H_k
+        2 * self.head_k_dim + 2 * self.head_v_dim * self.v_heads_per_k_head,
+    )
+    # mixed_qkvz: (B, S, H_k, 2*D_k + 2*D_v*V_per_K)
+    mixed_qkvz = qkvz.reshape(new_shape_qkvz)
+
+    split_indices_qkvz = [
+        self.head_k_dim,  # D_k
+        2 * self.head_k_dim,  # 2 * D_k
+        2 * self.head_k_dim + (self.v_heads_per_k_head * self.head_v_dim),  # 2 * D_k + V_per_K * D_v
+    ]
+    # query: (B, S, H_k, D_k)
+    # key: (B, S, H_k, D_k)
+    # value_raw: (B, S, H_k, V_per_K * D_v)
+    # z_raw: (B, S, H_k, V_per_K * D_v)
+    query, key, value_raw, z_raw = jnp.split(mixed_qkvz, split_indices_qkvz, axis=3)
+
+    # value: (B, S, H_v, D_v)
+    value = value_raw.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim)
+    # z: (B, S, H_v, D_v)
+    z = z_raw.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim)
+
+    # BA Reshaping and Splitting
+    new_shape_ba = (
+        batch,
+        seq_len,
+        self.num_k_heads,  # H_k
+        2 * self.v_heads_per_k_head,
+    )
+    # mixed_ba: (B, S, H_k, 2 * V_per_K)
+    mixed_ba = ba.reshape(new_shape_ba)
+
+    split_indices_ba = [self.v_heads_per_k_head]
+    # b_raw: (B, S, H_k, V_per_K)
+    # a_raw: (B, S, H_k, V_per_K)
+    b_raw, a_raw = jnp.split(mixed_ba, split_indices_ba, axis=3)
+
+    # b: (B, S, H_v)
+    b = b_raw.reshape(batch, seq_len, self.num_v_heads)
+    # a: (B, S, H_v)
+    a = a_raw.reshape(batch, seq_len, self.num_v_heads)
+
+    # Flatten head dimensions for concatenation before conv
+    # q: (B, S, K_dim)
+    q = query.reshape(batch, seq_len, -1)
+    # k: (B, S, K_dim)
+    k = key.reshape(batch, seq_len, -1)
+    # v: (B, S, V_dim)
+    v = value.reshape(batch, seq_len, -1)

     # =========================================================================
     # STEP B: 1D Convolution
     # =========================================================================
-    # qkv shape: (B, S, conv_dim)
+    # conv_dim = 2 * K_dim + V_dim
+    # qkv: (B, S, 2 * K_dim + V_dim)
     qkv = jnp.concatenate([q, k, v], axis=-1)

     # TODO(parambole): Implement caching logic for conv_state and recurrent_state

     # Input to conv_layer should be (B, S, C)
     # qkv_conv shape: (B, S, conv_dim)
-    qkv_conv = jax.nn.silu(self.conv1d(qkv).astype(jnp.float32)).astype(cfg.dtype)
+    conv_out = self.conv1d(qkv)
+    qkv_conv = jax.nn.silu(conv_out.astype(jnp.float32)).astype(cfg.dtype)
     # q_conv shape: (B, S, key_dim), k_conv shape: (B, S, key_dim), v_conv shape: (B, S, value_dim)
     q_conv, k_conv, v_conv = jnp.split(qkv_conv, [self.key_dim, 2 * self.key_dim], axis=-1)

@@ -450,13 +504,11 @@ def __call__(self, hidden_states: Array) -> Array:
     # =========================================================================
     # STEP D: Final Output Stage
     # =========================================================================
+
     # The normalization and gating is applied per-head on the value dimension.
-    # We first reshape the `z` tensor to match the multi-head structure of `core_attn_out`.
-    # z shape from (B, S, value_dim) -> (B, S, H_v, D_v)
-    z_reshaped = z.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim)

     # Apply the norm and gate. Output shape: (B, S, H_v, D_v)
-    gated_output_reshaped = self.norm(core_attn_out, z_reshaped)
+    gated_output_reshaped = self.norm(core_attn_out, z)

     # Reshape back to a single feature dimension for the final projection.
     # Shape from (B, S, H_v, D_v) -> (B, S, value_dim)

@@ -506,9 +558,9 @@ def __init__(
     cfg = self.config

     scaling_factor = self.config.head_dim**-0.5
+    batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
+    dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)

-    inputs_q_shape = (cfg.per_device_batch_size, cfg.max_target_length, cfg.emb_dim)
-    inputs_kv_shape = (cfg.per_device_batch_size, cfg.max_target_length, cfg.emb_dim)
     self.attention = attentions.Attention(
         config=cfg,
         num_query_heads=cfg.num_query_heads,

@@ -517,8 +569,8 @@ def __init__(
         max_target_length=cfg.max_target_length,
         max_prefill_predict_length=cfg.max_prefill_predict_length,
         attention_kernel=cfg.attention,
-        inputs_q_shape=inputs_q_shape,
-        inputs_kv_shape=inputs_kv_shape,
+        inputs_q_shape=dummy_inputs_shape,
+        inputs_kv_shape=dummy_inputs_shape,
         out_axis_names=(BATCH, LENGTH_NO_EXP, EMBED),
         mesh=self.mesh,
         dtype=cfg.dtype,
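The new splitting logic mirrors the Hugging Face reference layout: the fused `in_proj_qkvz` output is grouped per key head, and each of the `H_k` groups packs `[q (D_k), k (D_k), v (V_per_K * D_v), z (V_per_K * D_v)]`. A toy-shape sketch of the reshape-and-split (all dimensions invented for illustration):

```
import jax.numpy as jnp

# Toy dims: H_k=2 key heads, D_k=4, H_v=6 value heads, D_v=8, so V_per_K=3.
B, S, H_k, D_k, H_v, D_v = 1, 3, 2, 4, 6, 8
V_per_K = H_v // H_k

group = 2 * D_k + 2 * D_v * V_per_K  # packed width per key head
qkvz = jnp.zeros((B, S, H_k * group))  # stand-in for the fused projection output

mixed = qkvz.reshape(B, S, H_k, group)
q, k, v_raw, z_raw = jnp.split(mixed, [D_k, 2 * D_k, 2 * D_k + V_per_K * D_v], axis=3)
v = v_raw.reshape(B, S, H_v, D_v)
z = z_raw.reshape(B, S, H_v, D_v)
print(q.shape, k.shape, v.shape, z.shape)
# (1, 3, 2, 4) (1, 3, 2, 4) (1, 3, 6, 8) (1, 3, 6, 8)
```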
