
Commit b6c32b6

added scanned ckpt script
* add debug statements
* Conversion script ran without failing
* test verify orbax hf tensors
* Add unscanned conversion script for qwen3 next
* Move gating op to after sharding optimizations
* added zero centered rmsnorm
* Add layer by layer comparison script
* Remove debug files
* Remove zero centered rms norm logic
* Remove changes from forward pass logit checker
* Remove sow debug line
* Fix qkvz split in gated delta net and fix normalization after decoder layers
* Run linter and modify ckpt conversion config
* remove scanned script since it is not working yet
* move qwen3 next unscanned conversion script to utils folder
* Remove rms norm after decoder block for qwen3 next
* Add scanned conversion script for qwen3 next
* Added qwen3 next conversion test script
* Resolved gemini review comments
* Ran pyink for indentation errors
* Added readme for qwen3 next
* typo in qwen3 next readme
* Reformatted unscanned script
* Formatted scripts again
* Undo changes in decoders.py
* Formatted function with long line length
* fix linter issues
* Revise gemini-review comment
* Add back change to pyconfig after rebase
* Resolved pr comments
* Added moe strategies section to qwen3 next readme
* resolved comments in scripts
* Dynamically get batch_size and seq_len
* Add logic to decouple tuple when using scanned
* Resolve pr comments
* Add train compile test for qwen3-next
* Update train_compile test for qwen3-next
* Moved checks to types.py from pyconfig_deprecated.py
* Resolved comment for qwen3 next readme
* Ran pyink formatter
* Remove sparse_matmul test
1 parent d5ea751 commit b6c32b6

13 files changed: 1036 additions & 58 deletions

end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
#!/bin/bash

# This script validates a pre-converted MaxText checkpoint against its original
# HuggingFace counterpart to ensure numerical correctness.

# ---
# Example Usage:
#
# # (Required) Path to the converted MaxText checkpoint
# export MAXTEXT_CHECKPOINT_PATH=gs://path/to/converted_ckpt/0/items/
#
# # (Optional) Override the default HF model
# export HF_MODEL_PATH=MyCustom/Qwen3-variant
#
# bash end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh
# ---

set -ex

# --- Configuration & Input Validation ---

if [ -z "${MAXTEXT_CHECKPOINT_PATH}" ]; then
  echo "ERROR: The MAXTEXT_CHECKPOINT_PATH environment variable is not set."
  echo "Please set it to the full GCS path of the pre-converted MaxText checkpoint weights."
  exit 1
fi

# Set a default for the HF model path if it's not provided by the user
if [ -z "${HF_MODEL_PATH}" ]; then
  export HF_MODEL_PATH="Qwen/Qwen3-Next-80B-A3B-Instruct"
  echo "HF_MODEL_PATH is not set, using default: ${HF_MODEL_PATH}"
fi

# Install dependencies required for the logit checker.
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# --- Run the Forward Pass Logit Checker ---

echo "Validating MaxText checkpoint at ${MAXTEXT_CHECKPOINT_PATH}"
echo "Against original HF model: ${HF_MODEL_PATH}"

# This command runs the core validation logic.
JAX_PLATFORMS=cpu python3 -m MaxText.tests.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml \
  tokenizer_type=huggingface \
  tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/qwen3-tokenizer \
  megablox=False \
  sparse_matmul=False \
  load_parameters_path=${MAXTEXT_CHECKPOINT_PATH} \
  model_name=qwen3-next-80b-a3b \
  checkpoint_storage_concurrent_gb=1024 \
  skip_jax_distributed_system=True \
  dtype=float32 \
  weight_dtype=float32 \
  matmul_precision=highest \
  --hf_model_path=${HF_MODEL_PATH} \
  --max_kl_div=0.03 \
  --run_hf_model=True

echo "Validation complete."
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
Qwen3 Next
==========

Qwen3-Next is Alibaba's 80B-parameter Mixture-of-Experts (MoE) model (activating only 3B parameters per token) that features a novel **hybrid attention** architecture combining Gated DeltaNet (linear attention) and Gated Attention (full attention) for massive context scaling. This documentation covers the integration of **Qwen3-Next-80B-A3B** into MaxText.

For more details on the architecture, see the [Qwen3 Technical Blog](https://qwen.ai/blog?id=4074cca80393150c248e508aa62983f9cb7d27cd&from=research.latest-advancements-list).

* * * * *

Checkpoint Conversion
---------------------

To get started, you first need a MaxText-compatible checkpoint.

1. **Download the Model**: Download the official model from Hugging Face. You can use `hf_transfer` for a fast download.

```
# Example for Qwen3-Next-80B-A3B-Instruct
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download Qwen/Qwen3-Next-80B-A3B-Instruct --local-dir /path/to/qwen3_next_hf_checkpoint
```

2. **Convert the Checkpoint**: Run the `convert_qwen3_next_scanned.py` script to convert the downloaded Hugging Face weights into the Orbax format required by MaxText.

```
python3 -m MaxText.utils.ckpt_scripts.convert_qwen3_next_scanned \
  --base_model_path /path/to/qwen3_next_hf_checkpoint \
  --maxtext_model_path gs://your-gcs-bucket/qwen3_next_maxtext_ckpt \
  --model_size qwen3-next-80b-a3b
```
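
To sanity-check the converted checkpoint before training, you can list its parameter tree with Orbax. A minimal sketch; `ocp.PyTreeCheckpointer` is the standard Orbax API, but the exact tree layout under `0/items` is an assumption:

```
import jax
import orbax.checkpoint as ocp

# Path produced by the conversion step above (the `items` dir holds the PyTree).
ckpt_path = "gs://your-gcs-bucket/qwen3_next_maxtext_ckpt/0/items"

# Reads only the checkpoint structure, not the 80B weights themselves.
metadata = ocp.PyTreeCheckpointer().metadata(ckpt_path)

# Print the flattened parameter names and shapes.
for path, leaf in jax.tree_util.tree_flatten_with_path(metadata)[0]:
  print(jax.tree_util.keystr(path), getattr(leaf, "shape", None))
```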

* * * * *

Pre-training and Fine-tuning
----------------------------

After converting the checkpoint, you can use it for fine-tuning or start a pre-training run from scratch. The command below is an example for fine-tuning on a v5p-512 slice. To pre-train, simply remove the `load_parameters_path` argument.

```
python3 -m MaxText.train src/MaxText/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
  dataset_path=${DATASET_PATH} \
  load_parameters_path=gs://your-gcs-bucket/qwen3_next_maxtext_ckpt/0/items \
  run_name=qwen3_next_finetuning \
  per_device_batch_size=1 \
  model_name=qwen3-next-80b-a3b \
  steps=500 \
  max_target_length=8192 \
  ici_fsdp_parallelism=256 \
  tokenizer_type=huggingface \
  tokenizer_path=src/MaxText/assets/qwen3-tokenizer
```
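
As a quick sanity check on those settings (assuming a v5p-512 slice exposes 256 chips to JAX, so `ici_fsdp_parallelism=256` shards FSDP across every device):

```
# Hypothetical back-of-envelope check, not part of MaxText:
num_devices = 256                    # assumed: v5p-512 -> 256 chips / JAX devices
per_device_batch_size = 1
max_target_length = 8192

global_batch = per_device_batch_size * num_devices   # 256 sequences per step
tokens_per_step = global_batch * max_target_length
print(global_batch, tokens_per_step)                 # 256 2097152 (~2.1M tokens/step)
```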

* * * * *

Correctness Validation
----------------------

To verify that the MaxText implementation is numerically equivalent to the original Hugging Face model, you can run the end-to-end test scripts. These scripts automate the logit comparison test for each model.

Before running, you must set the `MAXTEXT_CHECKPOINT_PATH` environment variable. You can also optionally set `HF_MODEL_PATH` to point to a local copy of the Hugging Face model.

### Qwen3-Next-80B-A3B

```
# Set the required path to your converted MaxText checkpoint
export MAXTEXT_CHECKPOINT_PATH=gs://your-gcs-bucket/qwen3-next-80b-a3b_maxtext_ckpt/0/items/

# (Optional) Set the path to your local Hugging Face checkpoint
# export HF_MODEL_PATH=/path/to/local/qwen3-next-80b-a3b_hf_checkpoint

# Execute the validation script
bash end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh
```

## Supported MoE Strategies

This model implementation supports both **Token Dropping** and **Dropless** strategies for Mixture-of-Experts routing. See the MaxText [documentation](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/moe_configuration.md) on MoE configuration for the flags to set for each strategy; a hypothetical pair of overrides is sketched below.
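For illustration, the strategy is selected through the flags described in that doc (the `capacity_factor` values here are examples, not tuned settings):

```
# Dropless: capacity_factor=-1 disables expert capacity limits (no tokens dropped)
python3 -m MaxText.train src/MaxText/configs/base.yml model_name=qwen3-next-80b-a3b \
  sparse_matmul=False capacity_factor=-1 ...

# Token dropping: each expert keeps at most capacity_factor * (tokens / num_experts)
python3 -m MaxText.train src/MaxText/configs/base.yml model_name=qwen3-next-80b-a3b \
  sparse_matmul=False capacity_factor=1.25 ...
```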

src/MaxText/configs/models/qwen3-next-80b-a3b.yml

Lines changed: 3 additions & 0 deletions
@@ -46,3 +46,6 @@ gdn_chunk_size: 64
 # RoPE Settings
 rope_max_timescale: 10000000
 partial_rotary_factor: 0.25
+
+# General Model Settings
+enable_dropout: False

src/MaxText/configs/types.py

Lines changed: 6 additions & 6 deletions
@@ -2088,15 +2088,15 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
           f"({self.interleave_moe_layer_step})"
       )
     if self.decoder_block == DecoderBlockType.QWEN3_NEXT:
-      if self.sparse_matmul:
-        logger.warning(
-            "For Qwen3-Next, sparse_matmul must be False for now. The dense path has been verified against reference. "
-            "Forcing to False."
-        )
-        self.sparse_matmul = False
+      if int(self.gdn_num_value_heads) % int(self.gdn_num_key_heads) != 0:
+        raise ValueError("gdn_num_value_heads must be divisible by gdn_num_key_heads")
       rotary_dim = int(self.head_dim * self.partial_rotary_factor)
       if rotary_dim % 2 != 0:
         raise ValueError(f"Calculated rotary dimension ({rotary_dim}) must be a multiple of 2.")
+    else:
+      if self.partial_rotary_factor is not None and self.partial_rotary_factor != 1.0:
+        raise ValueError("`partial_rotary_factor` is only effective when `decoder_block` is set to 'qwen3_next'.")
+
     tokenizer_path = getattr(self, "tokenizer_path", None)
     if (
         tokenizer_path
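
For intuition, here is how those checks play out with illustrative Qwen3-Next-style values (`partial_rotary_factor=0.25` comes from this commit's yml; the `head_dim` and GDN head counts below are assumptions, not read from this diff):

```
head_dim = 256                   # assumed illustrative value
partial_rotary_factor = 0.25     # from qwen3-next-80b-a3b.yml
gdn_num_value_heads = 32         # assumed illustrative values
gdn_num_key_heads = 16

assert gdn_num_value_heads % gdn_num_key_heads == 0  # 2 value heads per key head
rotary_dim = int(head_dim * partial_rotary_factor)   # 64: RoPE covers a quarter of each head
assert rotary_dim % 2 == 0                           # RoPE rotates pairs of dims, so must be even
```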

src/MaxText/layers/attentions.py

Lines changed: 3 additions & 3 deletions
@@ -1104,9 +1104,6 @@ def __call__(
         bidirectional_mask,
         self.sinks,
     )
-    if self.is_qwen3_next:
-      out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
-      out = out * jax.nn.sigmoid(gate)
     if model_mode == MODEL_MODE_PREFILL:
       out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names)
     elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
@@ -1115,6 +1112,9 @@ def __call__(
       out = self._maybe_shard_with_logical(out, self.out_axis_names)
     else:
       out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)
+    if self.is_qwen3_next:
+      out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
+      out = out * jax.nn.sigmoid(gate)
     out = self.out_projection(out, out_sharding=out_sharding)
     out = checkpoint_name(out, "out_proj")
     return out, kv_cache
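
The relocated block applies Qwen3-Next's sigmoid output gate after the output resharding, so the elementwise multiply runs on a tensor that already carries its final logical sharding. A standalone sketch of the gating pattern itself (shapes and names are illustrative, not the `Attention` class API):

```
import jax
import jax.numpy as jnp

def gated_attention_output(attn_out: jnp.ndarray, gate: jnp.ndarray) -> jnp.ndarray:
  """attn_out: (B, S, H, D) attention output; gate: (B, S, H * D) gate projection."""
  b, s, h, d = attn_out.shape
  flat = attn_out.reshape(b, s, h * d)  # merge heads to match the gate layout
  return flat * jax.nn.sigmoid(gate)    # elementwise output gate

out = gated_attention_output(jnp.ones((1, 8, 4, 16)), jnp.zeros((1, 8, 64)))
print(out.shape)  # (1, 8, 64); sigmoid(0) = 0.5 halves each value at init
```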

src/MaxText/layers/decoders.py

Lines changed: 6 additions & 2 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-""""Module for decoder layers."""
+"""Module for decoder layers"""
 # pylint: disable=arguments-differ
 # pylint: disable=no-name-in-module
@@ -35,6 +35,7 @@
 from MaxText.sharding import create_sharding
 from MaxText.inference import page_manager
 from MaxText.layers import linears
+from MaxText.layers import normalizations
 from MaxText.layers import quantizations
 from MaxText.layers import pipeline
 from MaxText import maxtext_utils
@@ -473,7 +474,6 @@ def get_norm_layer(self, num_features: int):
         DecoderBlockType.GEMMA3,
         DecoderBlockType.QWEN3,
         DecoderBlockType.QWEN3_MOE,
-        DecoderBlockType.QWEN3_NEXT,
         DecoderBlockType.GPT_OSS,
         DecoderBlockType.SIMPLE,
         DecoderBlockType.SIMPLE_MLP,
@@ -482,6 +482,10 @@ def get_norm_layer(self, num_features: int):
       return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode)
     elif self.config.decoder_block == DecoderBlockType.GPT3:
       return functools.partial(gpt3.gpt3_layer_norm, num_features=num_features, reductions_in_fp32=False, use_bias=True)
+    elif self.config.decoder_block == DecoderBlockType.QWEN3_NEXT:
+      return functools.partial(
+          normalizations.Qwen3NextRMSNormLinen, num_features=num_features, shard_mode=self.config.shard_mode
+      )
     else:
       raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")

src/MaxText/layers/normalizations.py

Lines changed: 8 additions & 0 deletions
@@ -196,3 +196,11 @@ def l2norm(x: Array, dim: int = -1, eps: float = 1e-6) -> Array:

   inv_norm = jax.lax.rsqrt((x * x).sum(axis=dim, keepdims=True) + jnp.array(eps, dtype=x.dtype))
   return x * inv_norm
+
+
+Qwen3NextRMSNormLinen = nnx_wrappers.to_linen_class(
+    RMSNorm,
+    base_metadata_fn=variable_to_logically_partitioned,
+    scale_init=linen_initializers.zeros,
+    scale_offset=1.0,
+)
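
With `scale_init=zeros` and `scale_offset=1.0`, the learned scale is stored zero-centered and shifted by one when applied, matching Qwen3-Next's `(1 + weight)` RMSNorm convention. A minimal sketch of the math, assuming this is what `RMSNorm` computes with these arguments (not the class itself):

```
import jax.numpy as jnp

def zero_centered_rms_norm(x: jnp.ndarray, scale: jnp.ndarray, eps: float = 1e-6) -> jnp.ndarray:
  """x: (..., features); scale: (features,), all zeros at init."""
  rms = jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)
  return (x / rms) * (1.0 + scale)  # scale_offset=1.0 applied at call time

x = jnp.ones((2, 3))
print(zero_centered_rms_norm(x, jnp.zeros(3)))  # behaves as plain RMSNorm at init
```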

src/MaxText/layers/qwen3.py

Lines changed: 69 additions & 17 deletions
@@ -324,6 +324,7 @@ def __init__(self, config: Config, dtype: DType = jnp.float32, *, rngs: nnx.Rngs
     self.value_dim = self.head_v_dim * self.num_v_heads
     conv_dim = self.key_dim * 2 + self.value_dim
     conv_kernel_size = cfg.gdn_conv_kernel_dim
+    self.v_heads_per_k_head = self.num_v_heads // self.num_k_heads

     # Submodule instantiations
     self.in_proj_qkvz = linears.DenseGeneral(
@@ -381,33 +382,86 @@ def a_log_init(key, shape, dtype=jnp.float32):
     )

   def __call__(self, hidden_states: Array) -> Array:
+    # hidden_states: (B, S, E)
     cfg = self.config
+    batch, seq_len, _ = hidden_states.shape

     # =========================================================================
     # STEP A: Input Projections
     # =========================================================================
-    # hidden_states shape: (B, S, E)
-    # qkvz shape: (B, S, 2*key_dim + 2*value_dim)
+    # qkvz: (B, S, 2 * K_dim + 2 * V_dim)
     qkvz = self.in_proj_qkvz(hidden_states)
-    # ba shape: (B, S, 2*H_v)
+    # ba: (B, S, 2 * H_v)
     ba = self.in_proj_ba(hidden_states)

-    # q shape: (B, S, key_dim), k shape: (B, S, key_dim), v shape: (B, S, value_dim), z shape: (B, S, value_dim)
-    q, k, v, z = jnp.split(qkvz, [self.key_dim, 2 * self.key_dim, 2 * self.key_dim + self.value_dim], axis=-1)
-    # b shape: (B, S, H_v), a shape: (B, S, H_v)
-    b, a = jnp.split(ba, [self.num_v_heads], axis=-1)
+    # QKVZ Reshaping and Splitting
+    # Per-K_head group dim: 2 * D_k + 2 * D_v * V_per_K
+    new_shape_qkvz = (
+        batch,
+        seq_len,
+        self.num_k_heads,  # H_k
+        2 * self.head_k_dim + 2 * self.head_v_dim * self.v_heads_per_k_head,
+    )
+    # mixed_qkvz: (B, S, H_k, 2*D_k + 2*D_v*V_per_K)
+    mixed_qkvz = qkvz.reshape(new_shape_qkvz)
+
+    split_indices_qkvz = [
+        self.head_k_dim,  # D_k
+        2 * self.head_k_dim,  # 2 * D_k
+        2 * self.head_k_dim + (self.v_heads_per_k_head * self.head_v_dim),  # 2 * D_k + V_per_K * D_v
+    ]
+    # query: (B, S, H_k, D_k)
+    # key: (B, S, H_k, D_k)
+    # value_raw: (B, S, H_k, V_per_K * D_v)
+    # z_raw: (B, S, H_k, V_per_K * D_v)
+    query, key, value_raw, z_raw = jnp.split(mixed_qkvz, split_indices_qkvz, axis=3)
+
+    # value: (B, S, H_v, D_v)
+    value = value_raw.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim)
+    # z: (B, S, H_v, D_v)
+    z = z_raw.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim)
+
+    # BA Reshaping and Splitting
+    new_shape_ba = (
+        batch,
+        seq_len,
+        self.num_k_heads,  # H_k
+        2 * self.v_heads_per_k_head,
+    )
+    # mixed_ba: (B, S, H_k, 2 * V_per_K)
+    mixed_ba = ba.reshape(new_shape_ba)
+
+    split_indices_ba = [self.v_heads_per_k_head]
+    # b_raw: (B, S, H_k, V_per_K)
+    # a_raw: (B, S, H_k, V_per_K)
+    b_raw, a_raw = jnp.split(mixed_ba, split_indices_ba, axis=3)
+
+    # b: (B, S, H_v)
+    b = b_raw.reshape(batch, seq_len, self.num_v_heads)
+    # a: (B, S, H_v)
+    a = a_raw.reshape(batch, seq_len, self.num_v_heads)
+
+    # Flatten head dimensions for concatenation before conv
+    # q: (B, S, K_dim)
+    q = query.reshape(batch, seq_len, -1)
+    # k: (B, S, K_dim)
+    k = key.reshape(batch, seq_len, -1)
+    # v: (B, S, V_dim)
+    v = value.reshape(batch, seq_len, -1)

     # =========================================================================
     # STEP B: 1D Convolution
     # =========================================================================
-    # qkv shape: (B, S, conv_dim)
+    # conv_dim = 2 * K_dim + V_dim
+    # qkv: (B, S, 2 * K_dim + V_dim)
     qkv = jnp.concatenate([q, k, v], axis=-1)

     # TODO(parambole): Implement caching logic for conv_state and recurrent_state

     # Input to conv_layer should be (B, S, C)
     # qkv_conv shape: (B, S, conv_dim)
-    qkv_conv = jax.nn.silu(self.conv1d(qkv).astype(jnp.float32)).astype(cfg.dtype)
+    conv_out = self.conv1d(qkv)
+    qkv_conv = jax.nn.silu(conv_out.astype(jnp.float32)).astype(cfg.dtype)
     # q_conv shape: (B, S, key_dim), k_conv shape: (B, S, key_dim), v_conv shape: (B, S, value_dim)
     q_conv, k_conv, v_conv = jnp.split(qkv_conv, [self.key_dim, 2 * self.key_dim], axis=-1)
@@ -450,13 +504,11 @@ def __call__(self, hidden_states: Array) -> Array:
     # =========================================================================
     # STEP D: Final Output Stage
     # =========================================================================
+
     # The normalization and gating is applied per-head on the value dimension.
-    # We first reshape the `z` tensor to match the multi-head structure of `core_attn_out`.
-    # z shape from (B, S, value_dim) -> (B, S, H_v, D_v)
-    z_reshaped = z.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim)

     # Apply the norm and gate. Output shape: (B, S, H_v, D_v)
-    gated_output_reshaped = self.norm(core_attn_out, z_reshaped)
+    gated_output_reshaped = self.norm(core_attn_out, z)

     # Reshape back to a single feature dimension for the final projection.
     # Shape from (B, S, H_v, D_v) -> (B, S, value_dim)
@@ -506,9 +558,9 @@ def __init__(
     cfg = self.config

     scaling_factor = self.config.head_dim**-0.5
+    batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
+    dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)

-    inputs_q_shape = (cfg.per_device_batch_size, cfg.max_target_length, cfg.emb_dim)
-    inputs_kv_shape = (cfg.per_device_batch_size, cfg.max_target_length, cfg.emb_dim)
     self.attention = attentions.Attention(
         config=cfg,
         num_query_heads=cfg.num_query_heads,
@@ -517,8 +569,8 @@ def __init__(
         max_target_length=cfg.max_target_length,
         max_prefill_predict_length=cfg.max_prefill_predict_length,
         attention_kernel=cfg.attention,
-        inputs_q_shape=inputs_q_shape,
-        inputs_kv_shape=inputs_kv_shape,
+        inputs_q_shape=dummy_inputs_shape,
+        inputs_kv_shape=dummy_inputs_shape,
         out_axis_names=(BATCH, LENGTH_NO_EXP, EMBED),
         mesh=self.mesh,
         dtype=cfg.dtype,
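
To see why the grouped qkvz split matters, here is the fixed layout with illustrative GDN shapes (H_k=16, H_v=32, D_k=D_v=128 are assumptions mirroring the 80B config, not values read from this diff): a flat split of the fused projection would interleave heads incorrectly, while grouping by key head first recovers q, k, v, and z per group.

```
import numpy as np

B, S, H_k, H_v, D_k, D_v = 2, 4, 16, 32, 128, 128   # assumed illustrative shapes
v_per_k = H_v // H_k                                 # 2 value heads per key head

key_dim, value_dim = H_k * D_k, H_v * D_v
qkvz = np.arange(B * S * (2 * key_dim + 2 * value_dim), dtype=np.float32)
qkvz = qkvz.reshape(B, S, -1)                        # stand-in for in_proj_qkvz output

# Group by key head first: each group holds [q | k | v-group | z-group].
mixed = qkvz.reshape(B, S, H_k, 2 * D_k + 2 * D_v * v_per_k)
q, k, v_raw, z_raw = np.split(mixed, [D_k, 2 * D_k, 2 * D_k + v_per_k * D_v], axis=3)

assert q.shape == (B, S, H_k, D_k) and k.shape == (B, S, H_k, D_k)
assert v_raw.reshape(B, S, H_v, D_v).shape == (B, S, H_v, D_v)
assert z_raw.reshape(B, S, H_v, D_v).shape == (B, S, H_v, D_v)
```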
